Files
lerobot/tests/rewards/test_sarm_utils.py
T
Khalil Meftah 8a3d64033f Reward models refactor (#3142)
* feat(rewards): add RewardModelConfig and PreTrainedRewardModel base classes

* refactor(rewards): migrate Classifier from policies/sac/reward_model/ to rewards/classifier/

* refactor(rewards): migrate SARM from policies/sarm/ to rewards/sarm/

* refactor(rewards): add rewards/factory.py and remove reward model code from policies/factory.py

* refactor(rewards): update imports and delete old reward model locations

* test(rewards): add reward model tests and update existing test imports

* fix(rewards): restore full Classifier and SARM implementations

* test(rewards): restore missing CUDA and mixed precision classifier processor tests

* refactor(lerobot_train.py): remove rabc-specific configuration and replace it with a generic sample weighter class in lerobot_train

* refactor(lerobot_train.py): add missing sampling weight script

* fix linter issues and add missing files

* add tests for the sample weighter

* revert some useless changes, improve typing

* update docs

* add automatic detection of the progress path

* remove type exp

* improve comment

* fix: move rabc.py to rewards/sarm/ and update import paths

* refactor(imports): update reward model imports to new module structure

* refactor(imports): update reward model imports to reflect new module structure

* refactor(imports): conditionally import pandas based on availability

* feat(configs): add reward_model field to TrainPipelineConfig and Hub fields to RewardModelConfig

* refactor(policies): remove reward model branches from policy factory and __init__

* refactor(rewards): expand __init__ facade and fix SARMConfig __post_init__ crash

* feat(train): route reward model training through rewards/factory instead of policies/factory

* refactor(train): streamline reward model training logic

* fix(rewards): ensure FileNotFoundError is raised for missing config_file

* refactor(train): update __get_path_fields__ to include reward_model for config loading

* refactor(classifier): remove redundant input normalization in predict_reward method

* fix(train): raise ValueError for non-trainable reward models in train function

* refactor(pretrained_rm): add model card template

* refactor(tests): reward models

* refactor(sarm): update reset method and remove unused action prediction methods

* refactor(wandb): differentiate tags for reward model and policy training in cfg_to_group function

* fix(train): raise ValueError for PEFT usage in reward model training

* refactor(rewards): enhance RewardModelConfig with device handling and delta indices properties

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2026-04-28 17:56:24 +02:00

616 lines
22 KiB
Python

#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
import torch
from lerobot.rewards.sarm.sarm_utils import (
apply_rewind_augmentation,
compute_absolute_indices,
compute_tau,
find_stage_and_tau,
normalize_stage_tau,
temporal_proportions_to_breakpoints,
)
class TestProgressLabelsWithModes:
    """End-to-end tests for progress label generation in different modes."""

    def test_sparse_mode_single_stage(self):
        """Sparse mode with a single stage should give linear progress."""
        episode_length = 300
        global_names = ["task"]
        proportions = {"task": 1.0}

        # Without subtask annotations, tau is frame / (episode_length - 1).
        for frame in [0, 100, 200, 299]:
            stage, tau = find_stage_and_tau(
                frame, episode_length, None, None, None, global_names, proportions
            )
            expected_tau = frame / (episode_length - 1)
            assert stage == 0
            assert abs(tau - expected_tau) < 1e-5

    def test_sparse_mode_multi_stage(self):
        """Sparse mode with multiple stages: every stage, including the last, is identified."""
        global_names = ["reach", "grasp", "lift", "place"]
        proportions = {"reach": 0.2, "grasp": 0.2, "lift": 0.3, "place": 0.3}
        subtask_names = ["reach", "grasp", "lift", "place"]
        subtask_starts = [0, 60, 120, 210]
        subtask_ends = [59, 119, 209, 299]

        # Probe one frame inside each stage, including the final "place" stage.
        for frame, expected_stage in [(30, 0), (90, 1), (150, 2), (250, 3)]:
            stage, tau = find_stage_and_tau(
                frame, 300, subtask_names, subtask_starts, subtask_ends, global_names, proportions
            )
            assert stage == expected_stage, f"frame {frame}: expected stage {expected_stage}, got {stage}"
            # Within-stage progress follows tau = (t - s_k) / (e_k - s_k).
            start, end = subtask_starts[expected_stage], subtask_ends[expected_stage]
            assert abs(tau - (frame - start) / (end - start)) < 1e-6

    def test_dense_mode_more_stages(self):
        """Dense mode should work with more fine-grained stages."""
        global_names = ["a", "b", "c", "d", "e", "f", "g", "h"]
        proportions = dict.fromkeys(global_names, 1 / 8)
        subtask_names = global_names
        subtask_starts = [i * 50 for i in range(8)]
        subtask_ends = [(i + 1) * 50 - 1 for i in range(8)]

        # Each stage occupies 50 frames; probe the middle of each one.
        for stage_idx in range(8):
            mid_frame = stage_idx * 50 + 25
            stage, tau = find_stage_and_tau(
                mid_frame, 400, subtask_names, subtask_starts, subtask_ends, global_names, proportions
            )
            assert stage == stage_idx
            start, end = subtask_starts[stage_idx], subtask_ends[stage_idx]
            assert abs(tau - (mid_frame - start) / (end - start)) < 1e-6
class TestComputeAbsoluteIndices:
    """Tests for compute_absolute_indices, which samples a window centred on a frame.

    Every test uses an episode spanning [0, 1000) with 8 observation steps spaced
    30 frames apart, i.e. 4 steps on each side of the centre frame.
    """

    EP_START = 0
    EP_END = 1000
    N_OBS_STEPS = 8
    FRAME_GAP = 30

    def _window(self, frame_idx):
        """Call compute_absolute_indices with the shared episode/window settings."""
        return compute_absolute_indices(
            frame_idx, self.EP_START, self.EP_END, self.N_OBS_STEPS, self.FRAME_GAP
        )

    def test_no_clamping_when_in_middle(self):
        """A centre frame far from both episode edges needs no clamping."""
        center = 300
        indices, out_of_bounds = self._window(center)

        # Nothing should have been flagged as out of bounds.
        assert out_of_bounds.sum() == 0

        # Evenly spaced offsets -120, -90, ..., +120 around the centre.
        half = self.N_OBS_STEPS // 2
        expected = [center + k * self.FRAME_GAP for k in range(-half, half + 1)]
        assert indices.tolist() == expected

        # The middle slot holds the centre frame itself.
        assert indices[half] == center

    def test_clamping_at_episode_start(self):
        """Past frames that would fall before the episode start get clamped."""
        center = 50  # not enough history for a full past window
        indices, out_of_bounds = self._window(center)

        assert out_of_bounds.sum() > 0
        assert (indices >= self.EP_START).all()
        assert indices[self.N_OBS_STEPS // 2] == center

    def test_clamping_at_episode_end(self):
        """Future frames that would run past the episode end get clamped."""
        center = 950  # not enough future for a full window
        indices, out_of_bounds = self._window(center)

        assert out_of_bounds.sum() > 0
        assert (indices < self.EP_END).all()
        assert indices[self.N_OBS_STEPS // 2] == center

    def test_sequence_is_monotonic(self):
        """Sampled indices never decrease (clamping may produce repeats)."""
        for center in (50, 100, 300, 950):
            indices, _ = self._window(center)
            steps = indices[1:] - indices[:-1]
            assert (steps >= 0).all(), f"Non-monotonic at frame {center}"
class TestComputeTau:
    """Tests for compute_tau, the progress within a single subtask.

    Formula: τ_t = (t - s_k) / (e_k - s_k), clamped to [0, 1].
    """

    # Absolute tolerance for non-boundary float comparisons.
    TOL = 1e-6

    def test_at_start(self):
        """τ is exactly 0 on the first frame of the subtask."""
        assert compute_tau(current_frame=10, subtask_start=10, subtask_end=50) == 0.0

    def test_at_end(self):
        """τ is exactly 1 on the last frame of the subtask."""
        assert compute_tau(current_frame=50, subtask_start=10, subtask_end=50) == 1.0

    def test_at_middle(self):
        """Frame 30 sits halfway through [10, 50], so τ = 0.5."""
        result = compute_tau(current_frame=30, subtask_start=10, subtask_end=50)
        assert abs(result - 0.5) < self.TOL

    def test_quarter_progress(self):
        """Frame 20 sits a quarter of the way through [0, 80], so τ = 0.25."""
        result = compute_tau(current_frame=20, subtask_start=0, subtask_end=80)
        assert abs(result - 0.25) < self.TOL

    def test_zero_duration_subtask(self):
        """A subtask whose start equals its end reports τ = 1.0."""
        assert compute_tau(current_frame=10, subtask_start=10, subtask_end=10) == 1.0

    def test_clamps_below_zero(self):
        """Frames before the subtask start are clamped to τ = 0."""
        assert compute_tau(current_frame=5, subtask_start=10, subtask_end=50) == 0.0

    def test_clamps_above_one(self):
        """Frames after the subtask end are clamped to τ = 1."""
        assert compute_tau(current_frame=60, subtask_start=10, subtask_end=50) == 1.0

    def test_float_inputs(self):
        """Fractional frame indices (e.g. from interpolation) are supported."""
        frame, start, end = 25.5, 10.0, 50.0
        result = compute_tau(current_frame=frame, subtask_start=start, subtask_end=end)
        assert abs(result - (frame - start) / (end - start)) < self.TOL
class TestFindStageAndTau:
    """Tests for find_stage_and_tau.

    This function sits at the heart of progress-label computation: for a given
    frame it reports which stage the frame falls in and how far through that
    stage it is (tau).
    """

    # Absolute tolerance for non-boundary float comparisons.
    TOL = 1e-6

    def test_single_stage_mode_linear_progress(self):
        """With no subtask annotations, tau grows linearly from 0 to 1."""
        episode_length = 100
        # Frame 50 maps to 50 / 99 ≈ 0.505.
        for frame, expected_tau in ((0, 0.0), (50, 50 / 99), (99, 1.0)):
            stage, tau = find_stage_and_tau(
                frame, episode_length, None, None, None, ["task"], {"task": 1.0}
            )
            assert stage == 0
            assert abs(tau - expected_tau) < self.TOL

    def test_multi_stage_within_subtask(self):
        """A frame inside an annotated subtask gets that subtask's index and local tau."""
        global_names = ["reach", "grasp", "lift"]
        proportions = {"reach": 0.3, "grasp": 0.2, "lift": 0.5}
        subtask_names = ["reach", "grasp", "lift"]
        subtask_starts = [0, 30, 50]
        subtask_ends = [29, 49, 99]

        # (frame, expected stage, expected tau); tau = (frame - start) / (end - start)
        cases = [
            (15, 0, 15 / 29),  # "reach"
            (40, 1, 10 / 19),  # "grasp": (40 - 30) / (49 - 30)
            (75, 2, 25 / 49),  # "lift": (75 - 50) / (99 - 50)
        ]
        for frame, expected_stage, expected_tau in cases:
            stage, tau = find_stage_and_tau(
                frame, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
            )
            assert stage == expected_stage
            assert abs(tau - expected_tau) < self.TOL

    def test_frame_at_subtask_boundaries(self):
        """Frames sitting exactly on a subtask's first or last frame."""
        global_names = ["a", "b"]
        proportions = {"a": 0.5, "b": 0.5}
        subtask_names = ["a", "b"]
        subtask_starts = [0, 50]
        subtask_ends = [49, 99]

        # First frame of "a", last frame of "a", first frame of "b".
        for frame, expected_stage, expected_tau in ((0, 0, 0.0), (49, 0, 1.0), (50, 1, 0.0)):
            stage, tau = find_stage_and_tau(
                frame, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
            )
            assert stage == expected_stage
            assert tau == expected_tau

    def test_frame_after_last_subtask(self):
        """A frame past the final annotated subtask reports the last stage, almost complete."""
        global_names = ["a", "b"]
        proportions = {"a": 0.5, "b": 0.5}
        subtask_names = ["a", "b"]
        subtask_starts = [0, 30]
        subtask_ends = [29, 59]

        # Frame 80 lies beyond the end of subtask "b" (frame 59).
        stage, tau = find_stage_and_tau(
            80, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
        )
        assert stage == 1  # last stage
        assert tau == 0.999  # nearly complete
class TestEndToEndProgressLabeling:
    """End-to-end tests for progress label computation using normalize_stage_tau."""

    def test_consistent_semantic_meaning(self):
        """Test that same subtask completion maps to same progress across trajectories.

        This is the key semantic property: "end of subtask 1" should always
        mean the same progress value regardless of trajectory speed.
        """
        proportions = [0.3, 0.5, 0.2]

        # Fast trajectory: subtask 1 ends at frame 30 (of 100)
        tau_fast = compute_tau(30, 0, 30)  # = 1.0
        y_fast = normalize_stage_tau(0 + tau_fast, temporal_proportions=proportions)

        # Slow trajectory: subtask 1 ends at frame 90 (of 300)
        tau_slow = compute_tau(90, 0, 90)  # = 1.0
        y_slow = normalize_stage_tau(0 + tau_slow, temporal_proportions=proportions)

        # Both should map to same progress (0.3 = end of subtask 1)
        assert abs(y_fast - y_slow) < 1e-6
        assert abs(y_fast - 0.3) < 1e-6

    def test_monotonic_within_subtask(self):
        """Progress strictly increases as tau sweeps 0 -> 1 inside every subtask."""
        proportions = [0.4, 0.6]
        taus = np.linspace(0, 1, 11)

        for stage in range(len(proportions)):
            values = [normalize_stage_tau(stage + tau, temporal_proportions=proportions) for tau in taus]
            # Normalized progress must stay inside [0, 1].
            assert all(0.0 <= v <= 1.0 for v in values), f"stage {stage}: values out of range {values}"
            # Strictly increasing between consecutive samples.
            for prev, curr in zip(values[:-1], values[1:], strict=True):
                assert curr > prev, f"stage {stage}: progress not increasing ({prev} -> {curr})"

    def test_continuous_across_subtasks(self):
        """Test that progress is continuous at subtask boundaries."""
        proportions = [0.3, 0.5, 0.2]

        # End of subtask 0 (stage=0, tau=1.0) -> stage.tau = 1.0
        y_end_0 = normalize_stage_tau(0 + 1.0, temporal_proportions=proportions)
        # Start of subtask 1 (stage=1, tau=0.0) -> stage.tau = 1.0
        y_start_1 = normalize_stage_tau(1 + 0.0, temporal_proportions=proportions)
        # Both sides of the boundary agree, and sit at the first breakpoint P_1 = 0.3.
        assert abs(y_end_0 - y_start_1) < 1e-6
        assert abs(y_end_0 - 0.3) < 1e-6

        # End of subtask 1 (stage=1, tau=1.0) -> stage.tau = 2.0
        y_end_1 = normalize_stage_tau(1 + 1.0, temporal_proportions=proportions)
        # Start of subtask 2 (stage=2, tau=0.0) -> stage.tau = 2.0
        y_start_2 = normalize_stage_tau(2 + 0.0, temporal_proportions=proportions)
        # Both sides agree, and sit at the second breakpoint P_2 = 0.3 + 0.5 = 0.8.
        assert abs(y_end_1 - y_start_2) < 1e-6
        assert abs(y_end_1 - 0.8) < 1e-6
class TestTemporalProportionsToBreakpoints:
    """Tests for temporal_proportions_to_breakpoints.

    Converts temporal proportions to cumulative breakpoints for normalization.
    Example: [0.3, 0.5, 0.2] -> [0.0, 0.3, 0.8, 1.0]
    """

    def test_basic_conversion(self):
        """Test basic conversion from proportions to breakpoints."""
        proportions = [0.3, 0.5, 0.2]
        breakpoints = temporal_proportions_to_breakpoints(proportions)

        assert breakpoints is not None
        assert len(breakpoints) == 4
        assert breakpoints[0] == 0.0
        assert abs(breakpoints[1] - 0.3) < 1e-6
        assert abs(breakpoints[2] - 0.8) < 1e-6
        assert breakpoints[3] == 1.0

    def test_dict_input(self):
        """Test with dict input."""
        proportions = {"a": 0.25, "b": 0.25, "c": 0.5}
        breakpoints = temporal_proportions_to_breakpoints(proportions)

        assert breakpoints is not None
        assert len(breakpoints) == 4
        assert breakpoints[0] == 0.0
        # Interior breakpoints are the running sums 0.25 and 0.25 + 0.25.
        assert abs(breakpoints[1] - 0.25) < 1e-6
        assert abs(breakpoints[2] - 0.5) < 1e-6
        assert breakpoints[-1] == 1.0

    def test_dict_with_subtask_names_order(self):
        """Test that subtask_names determines order for dict input."""
        proportions = {"c": 0.5, "a": 0.2, "b": 0.3}  # Dict order
        subtask_names = ["a", "b", "c"]  # Different order
        breakpoints = temporal_proportions_to_breakpoints(proportions, subtask_names)

        # Guard first so a regression fails as an assertion rather than a TypeError on indexing.
        assert breakpoints is not None
        assert len(breakpoints) == 4
        assert breakpoints[0] == 0.0
        # Breakpoints should follow subtask_names order: a=0.2, b=0.3, c=0.5
        assert abs(breakpoints[1] - 0.2) < 1e-6  # a
        assert abs(breakpoints[2] - 0.5) < 1e-6  # a + b = 0.5
        assert breakpoints[3] == 1.0  # a + b + c = 1.0

    def test_uniform_proportions(self):
        """Test with uniform proportions."""
        proportions = [0.25, 0.25, 0.25, 0.25]
        breakpoints = temporal_proportions_to_breakpoints(proportions)

        expected = [0.0, 0.25, 0.5, 0.75, 1.0]
        for i, (bp, exp) in enumerate(zip(breakpoints, expected, strict=True)):
            assert abs(bp - exp) < 1e-6, f"Breakpoint {i} mismatch"

    def test_none_input(self):
        """Test that None input returns None."""
        result = temporal_proportions_to_breakpoints(None)
        assert result is None

    def test_normalization(self):
        """Test that non-normalized proportions are normalized."""
        # Proportions sum to 2.0, not 1.0
        proportions = [0.6, 1.0, 0.4]
        breakpoints = temporal_proportions_to_breakpoints(proportions)

        # Should be normalized: [0.3, 0.5, 0.2] -> [0, 0.3, 0.8, 1.0]
        assert breakpoints is not None
        assert len(breakpoints) == 4
        assert breakpoints[0] == 0.0
        assert abs(breakpoints[1] - 0.3) < 1e-6
        assert abs(breakpoints[2] - 0.8) < 1e-6
        assert breakpoints[-1] == 1.0
class TestNormalizeStageTau:
    """Tests for normalize_stage_tau.

    Normalizes stage+tau values to [0, 1] using breakpoints.
    """

    def test_linear_fallback(self):
        """Test linear normalization when only num_stages is provided."""
        # 4 stages, linear: [0, 0.25, 0.5, 0.75, 1.0]
        # Stage 0 start
        assert normalize_stage_tau(0.0, num_stages=4) == 0.0
        # Stage 0 end / Stage 1 start
        assert abs(normalize_stage_tau(1.0, num_stages=4) - 0.25) < 1e-6
        # Stage 1 middle
        assert abs(normalize_stage_tau(1.5, num_stages=4) - 0.375) < 1e-6
        # Stage 3 end
        assert normalize_stage_tau(4.0, num_stages=4) == 1.0

    def test_with_custom_breakpoints(self):
        """Test with custom breakpoints."""
        # Non-linear breakpoints
        breakpoints = [0.0, 0.1, 0.5, 1.0]  # 3 stages

        # Stage 0: maps [0, 1) to [0.0, 0.1)
        assert abs(normalize_stage_tau(0.5, breakpoints=breakpoints) - 0.05) < 1e-6
        # Stage 1: maps [1, 2) to [0.1, 0.5)
        assert abs(normalize_stage_tau(1.5, breakpoints=breakpoints) - 0.3) < 1e-6
        # Stage 2: maps [2, 3) to [0.5, 1.0)
        assert abs(normalize_stage_tau(2.5, breakpoints=breakpoints) - 0.75) < 1e-6

    def test_with_temporal_proportions(self):
        """Test with temporal proportions (auto-computed breakpoints)."""
        proportions = {"a": 0.2, "b": 0.3, "c": 0.5}
        subtask_names = ["a", "b", "c"]

        # Stage 0 end should map to 0.2
        result = normalize_stage_tau(1.0, temporal_proportions=proportions, subtask_names=subtask_names)
        assert abs(result - 0.2) < 1e-6

        # Stage 1 end should map to 0.5
        result = normalize_stage_tau(2.0, temporal_proportions=proportions, subtask_names=subtask_names)
        assert abs(result - 0.5) < 1e-6

    def test_tensor_input(self):
        """Test with tensor input, including mid-stage values."""
        x = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0])
        breakpoints = [0.0, 0.3, 0.8, 1.0]  # 3 stages

        result = normalize_stage_tau(x, breakpoints=breakpoints)

        assert isinstance(result, torch.Tensor)
        assert result.shape == x.shape
        # 0.5 -> halfway through [0.0, 0.3]; 1.5 -> halfway through [0.3, 0.8].
        expected = [0.0, 0.15, 0.3, 0.55, 0.8]
        for i, (got, exp) in enumerate(zip(result.tolist(), expected, strict=True)):
            assert abs(got - exp) < 1e-6, f"Element {i}: expected {exp}, got {got}"

    def test_clamping(self):
        """Test that output is clamped to [0, 1]."""
        # Below 0
        assert normalize_stage_tau(-0.5, num_stages=4) == 0.0
        # Above num_stages
        assert normalize_stage_tau(5.0, num_stages=4) == 1.0

    def test_batch_tensor(self):
        """Test with batched tensor: shape is preserved and values follow the linear mapping."""
        x = torch.tensor([[0.0, 1.0, 2.0], [0.5, 1.5, 2.5]])  # (2, 3)
        result = normalize_stage_tau(x, num_stages=3)

        assert result.shape == (2, 3)
        assert (result >= 0).all()
        assert (result <= 1).all()
        # With only num_stages, normalization is linear: x / num_stages (see test_linear_fallback).
        assert (result - x / 3).abs().max().item() < 1e-6

    def test_requires_one_of_inputs(self):
        """Test that at least one input method is required."""
        with pytest.raises(ValueError):
            normalize_stage_tau(1.0)
class TestRewindAugmentation:
    """Tests for rewind augmentation logic with bidirectional observation sampling.

    Rewind appends frames before the earliest observation frame, going backwards.
    With bidirectional sampling centered at frame_idx:
    - Earliest obs frame = frame_idx - half_steps * frame_gap
    - Rewind frames step backwards from that point in increments of frame_gap
    """

    def test_rewind_indices_go_backwards_from_earliest_obs(self):
        """Rewind indices should go backwards from earliest observation frame."""
        frame_idx = 300  # Center of bidirectional window
        ep_start = 0
        n_obs_steps = 4  # half_steps = 2
        frame_gap = 30

        # Earliest obs frame = 300 - 2*30 = 240
        # Rewind goes backwards: 210, 180
        rewind_step, rewind_indices = apply_rewind_augmentation(
            frame_idx,
            ep_start,
            n_obs_steps=n_obs_steps,
            max_rewind_steps=2,
            frame_gap=frame_gap,
            rewind_step=2,
        )

        assert rewind_step == 2
        assert len(rewind_indices) == 2
        # First rewind frame is closest to obs window, second is further back
        assert rewind_indices[0] == 210  # 240 - 30
        assert rewind_indices[1] == 180  # 240 - 60
        assert rewind_indices[0] > rewind_indices[1], "Rewind should be descending"

    def test_rewind_goes_backward_through_history(self):
        """Rewind frames should go backward before the observation window."""
        frame_idx = 450  # Center of bidirectional window
        ep_start = 0
        n_obs_steps = 8  # half_steps = 4
        frame_gap = 30

        # Earliest obs frame = 450 - 4*30 = 330
        # Rewind from 330: [300, 270, 240]
        rewind_step, rewind_indices = apply_rewind_augmentation(
            frame_idx,
            ep_start,
            n_obs_steps=n_obs_steps,
            max_rewind_steps=4,
            frame_gap=frame_gap,
            rewind_step=3,
        )

        assert rewind_step == 3
        expected = [300, 270, 240]  # Going backwards from 330
        assert rewind_indices == expected

    def test_no_rewind_when_obs_window_at_episode_start(self):
        """No rewind when observation window reaches episode start."""
        frame_idx = 120  # Center of window
        ep_start = 0
        n_obs_steps = 8  # half_steps = 4
        frame_gap = 30

        # Earliest obs frame = 120 - 4*30 = 0 (at episode start)
        rewind_step, rewind_indices = apply_rewind_augmentation(
            frame_idx, ep_start, n_obs_steps=n_obs_steps, max_rewind_steps=4, frame_gap=frame_gap
        )

        # No room for rewind
        assert rewind_step == 0
        assert rewind_indices == []

    def test_rewind_targets_are_decreasing(self):
        """Progress targets for rewind frames should decrease as rewind goes further back.

        Rewind frames come from apply_rewind_augmentation and are labelled with
        find_stage_and_tau, so the test exercises the real pipeline rather than a
        hard-coded list.
        """
        episode_length = 1000

        # Same setup as test_rewind_goes_backward_through_history -> frames [300, 270, 240].
        _, rewind_indices = apply_rewind_augmentation(
            450,
            0,
            n_obs_steps=8,
            max_rewind_steps=4,
            frame_gap=30,
            rewind_step=3,
        )
        assert len(rewind_indices) == 3

        # Label each rewind frame with single-stage progress (linear in the frame index).
        rewind_progress = [
            find_stage_and_tau(idx, episode_length, None, None, None, ["task"], {"task": 1.0})[1]
            for idx in rewind_indices
        ]

        # Walking further back in time must yield strictly lower progress.
        for earlier, later in zip(rewind_progress[:-1], rewind_progress[1:], strict=True):
            assert earlier > later, f"Rewind progress not decreasing: {rewind_progress}"