Reward models refactor (#3142)

* feat(rewards): add RewardModelConfig and PreTrainedRewardModel base classes

* refactor(rewards): migrate Classifier from policies/sac/reward_model/ to rewards/classifier/

* refactor(rewards): migrate SARM from policies/sarm/ to rewards/sarm/

* refactor(rewards): add rewards/factory.py and remove reward model code from policies/factory.py

* refactor(rewards): update imports and delete old reward model locations

* test(rewards): add reward model tests and update existing test imports

* fix(rewards): restore full Classifier and SARM implementations

* test(rewards): restore missing CUDA and mixed precision classifier processor tests

* refactor(lerobot_train.py): remove rabc specific configuration and replace it with a generic samplerweight class in lerobot_train

* refactor(lerobot_train.py): add missing sampling weight script

* linter + missing files

* add testing for sampl weighter

* revert some useless changes, improve typing

* update docs

* add automatic detection of the progress path

* remove type exp

* improve comment

* fix: move rabc.py to rewards/sarm/ and update import paths

* refactor(imports): update reward model imports to new module structure

* refactor(imports): update reward model imports to reflect new module structure

* refactor(imports): conditionally import pandas based on availability

* feat(configs): add reward_model field to TrainPipelineConfig and Hub fields to RewardModelConfig

* refactor(policies): remove reward model branches from policy factory and __init__

* refactor(rewards): expand __init__ facade and fix SARMConfig __post_init__ crash

* feat(train): route reward model training through rewards/factory instead of policies/factory

* refactor(train): streamline reward model training logic

* fix(rewards): ensure FileNotFoundError is raised for missing config_file

* refactor(train): update __get_path_fields__ to include reward_model for config loading

* refactor(classifier): remove redundant input normalization in predict_reward method

* fix(train): raise ValueError for non-trainable reward models in train function

* refactor(pretrained_rm): add model card template

* refactor(tests): reward models

* refactor(sarm): update reset method and remove unused action prediction methods

* refactor(wandb): differentiate tags for reward model and policy training in cfg_to_group function

* fix(train): raise ValueError for PEFT usage in reward model training

* refactor(rewards): enhance RewardModelConfig with device handling and delta indices properties

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
This commit is contained in:
Khalil Meftah
2026-04-28 17:56:24 +02:00
committed by GitHub
parent 03ee50e08f
commit 8a3d64033f
37 changed files with 2091 additions and 381 deletions
@@ -1,153 +0,0 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput
from lerobot.utils.constants import OBS_IMAGE, REWARD
from tests.utils import skip_if_package_missing
def test_classifier_output():
output = ClassifierOutput(
logits=torch.tensor([1, 2, 3]),
probabilities=torch.tensor([0.1, 0.2, 0.3]),
hidden_states=None,
)
assert (
f"{output}"
== "ClassifierOutput(logits=tensor([1, 2, 3]), probabilities=tensor([0.1000, 0.2000, 0.3000]), hidden_states=None)"
)
@skip_if_package_missing("transformers")
@pytest.mark.skip(
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_binary_classifier_with_default_params():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
config = RewardClassifierConfig()
config.input_features = {
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
}
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
"REWARD": NormalizationMode.IDENTITY,
}
config.num_cameras = 1
classifier = Classifier(config)
batch_size = 10
input = {
OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)),
REWARD: torch.randint(low=0, high=2, size=(batch_size,)).float(),
}
images, labels = classifier.extract_images_and_labels(input)
assert len(images) == 1
assert images[0].shape == torch.Size([batch_size, 3, 128, 128])
assert labels.shape == torch.Size([batch_size])
output = classifier.predict(images)
assert output is not None
assert output.logits.size() == torch.Size([batch_size])
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
assert output.probabilities.shape == torch.Size([batch_size])
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
assert output.hidden_states.shape == torch.Size([batch_size, 256])
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
@skip_if_package_missing("transformers")
@pytest.mark.skip(
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_multiclass_classifier():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
num_classes = 5
config = RewardClassifierConfig()
config.input_features = {
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)),
}
config.num_cameras = 1
config.num_classes = num_classes
classifier = Classifier(config)
batch_size = 10
input = {
OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)),
REWARD: torch.rand((batch_size, num_classes)),
}
images, labels = classifier.extract_images_and_labels(input)
assert len(images) == 1
assert images[0].shape == torch.Size([batch_size, 3, 128, 128])
assert labels.shape == torch.Size([batch_size, num_classes])
output = classifier.predict(images)
assert output is not None
assert output.logits.shape == torch.Size([batch_size, num_classes])
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
assert output.probabilities.shape == torch.Size([batch_size, num_classes])
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
assert output.hidden_states.shape == torch.Size([batch_size, 256])
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
@skip_if_package_missing("transformers")
@pytest.mark.skip(
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_default_device():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
config = RewardClassifierConfig()
assert config.device == "cpu"
classifier = Classifier(config)
for p in classifier.parameters():
assert p.device == torch.device("cpu")
@skip_if_package_missing("transformers")
@pytest.mark.skip(
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_explicit_device_setup():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
config = RewardClassifierConfig(device="cpu")
assert config.device == "cpu"
classifier = Classifier(config)
for p in classifier.parameters():
assert p.device == torch.device("cpu")
-694
View File
@@ -1,694 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
pytest.importorskip("faker")
from unittest.mock import MagicMock, patch
import numpy as np
import pandas as pd
import pytest
import torch
from lerobot.types import TransitionKey
class MockDatasetMeta:
"""Mock dataset metadata for testing processor."""
def __init__(self, episodes: list[dict]):
self._episodes = episodes
@property
def episodes(self):
"""Return episodes as a mock object with to_pandas() method."""
mock = MagicMock()
mock.__len__ = lambda s: len(self._episodes)
mock.__getitem__ = lambda s, idx: self._episodes[idx]
mock.to_pandas = lambda: pd.DataFrame(self._episodes)
return mock
class MockConfig:
"""Mock SARMConfig for testing processor methods."""
def __init__(
self,
n_obs_steps: int = 8,
max_rewind_steps: int = 4,
frame_gap: int = 30,
sparse_subtask_names: list = None,
sparse_temporal_proportions: list = None,
dense_subtask_names: list = None,
dense_temporal_proportions: list = None,
image_key: str = "observation.images.top",
state_key: str = "observation.state",
max_state_dim: int = 32,
device: str = None,
rewind_probability: float = 0.8,
language_perturbation_probability: float = 0.2,
annotation_mode: str = "dual",
clip_batch_size: int = 64,
text_dim: int = 512,
):
self.n_obs_steps = n_obs_steps
self.max_rewind_steps = max_rewind_steps
self.frame_gap = frame_gap
self.sparse_subtask_names = sparse_subtask_names or ["task"]
self.sparse_temporal_proportions = sparse_temporal_proportions or [1.0]
self.dense_subtask_names = dense_subtask_names
self.dense_temporal_proportions = dense_temporal_proportions
self.uses_dual_heads = annotation_mode in ["dense_only", "dual"]
self.image_key = image_key
self.state_key = state_key
self.max_state_dim = max_state_dim
self.device = device
self.rewind_probability = rewind_probability
self.language_perturbation_probability = language_perturbation_probability
self.annotation_mode = annotation_mode
self.clip_batch_size = clip_batch_size
self.text_dim = text_dim
# Compute observation delta indices (same as config: bidirectional)
half_steps = self.n_obs_steps // 2
past_deltas = [-self.frame_gap * i for i in range(half_steps, 0, -1)]
future_deltas = [self.frame_gap * i for i in range(1, half_steps + 1)]
obs_deltas = past_deltas + [0] + future_deltas
rewind_deltas = [-self.frame_gap * (i + 1) for i in range(self.max_rewind_steps)]
self.observation_delta_indices = obs_deltas + rewind_deltas
@property
def num_frames(self) -> int:
return 1 + self.n_obs_steps + self.max_rewind_steps
class TestSARMEncodingProcessorStepEndToEnd:
"""End-to-end test for SARMEncodingProcessorStep with dummy batch data."""
@pytest.fixture
def mock_clip_model(self):
"""Mock CLIP model to avoid loading real weights."""
with (
patch("lerobot.policies.sarm.processor_sarm.CLIPModel") as mock_model_cls,
patch("lerobot.policies.sarm.processor_sarm.CLIPProcessor") as mock_processor_cls,
):
# Mock the CLIP model - return embeddings based on input batch size
mock_model = MagicMock()
def get_image_features_side_effect(**kwargs):
pixel_values = kwargs.get("pixel_values")
batch_size = pixel_values.shape[0] if pixel_values is not None else 1
return torch.randn(batch_size, 512)
mock_model.get_image_features.side_effect = get_image_features_side_effect
mock_model.get_text_features.return_value = torch.randn(1, 512)
mock_model.to.return_value = mock_model
mock_model_cls.from_pretrained.return_value = mock_model
# Mock the CLIP processor - return tensors based on input images
mock_processor = MagicMock()
def processor_side_effect(images=None, **kwargs):
num_images = len(images) if images is not None else 1
return {
"pixel_values": torch.randn(num_images, 3, 224, 224),
}
mock_processor.side_effect = processor_side_effect
# Mock tokenizer for text encoding
mock_processor.tokenizer.return_value = {
"input_ids": torch.ones(1, 77, dtype=torch.long),
"attention_mask": torch.ones(1, 77, dtype=torch.long),
}
mock_processor_cls.from_pretrained.return_value = mock_processor
yield mock_model, mock_processor
@pytest.fixture
def processor_with_mocks(self, mock_clip_model):
"""Create a processor with mocked CLIP and dataset metadata for dual mode."""
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
# Dual mode config with both sparse and dense annotations
config = MockConfig(
n_obs_steps=8,
max_rewind_steps=4,
frame_gap=30,
rewind_probability=0.0, # Disable for deterministic test
language_perturbation_probability=0.0, # Disable for deterministic test
annotation_mode="dual",
sparse_subtask_names=["reach", "grasp", "lift"],
sparse_temporal_proportions=[0.3, 0.4, 0.3],
dense_subtask_names=["approach", "contact", "close_gripper", "lift_up"],
dense_temporal_proportions=[0.25, 0.25, 0.25, 0.25],
)
# Create mock dataset metadata with one episode of 300 frames
# Include annotation columns for dual mode
episodes = [
{
"dataset_from_index": 0,
"dataset_to_index": 300,
"task": "pick up the cube",
"sparse_subtask_names": ["reach", "grasp", "lift"],
"sparse_subtask_start_frames": [0, 90, 210],
"sparse_subtask_end_frames": [90, 210, 300],
"dense_subtask_names": ["approach", "contact", "close_gripper", "lift_up"],
"dense_subtask_start_frames": [0, 75, 150, 225],
"dense_subtask_end_frames": [75, 150, 225, 300],
}
]
dataset_meta = MockDatasetMeta(episodes)
processor = SARMEncodingProcessorStep(
config=config,
dataset_meta=dataset_meta,
)
processor.train(True) # Use train() method, not direct assignment
return processor, config
def test_call_with_single_frame_batch(self, processor_with_mocks):
"""Test processor __call__ with a single-frame batch."""
processor, config = processor_with_mocks
# Create dummy input transition
batch_size = 1
num_frames = config.num_frames # 13 frames (9 obs + 4 rewind)
# Image: (T, C, H, W) format as expected by processor
dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
# State: (T, D) format
dummy_state = np.random.rand(num_frames, 6).astype(np.float32)
transition = {
TransitionKey.OBSERVATION: {
config.image_key: dummy_image,
config.state_key: dummy_state,
},
TransitionKey.COMPLEMENTARY_DATA: {
"index": 150, # Middle of episode
"episode_index": 0,
"task": "pick up the cube",
},
}
# Run processor
result = processor(transition)
# Verify output structure
obs = result[TransitionKey.OBSERVATION]
# Check video features exist and have correct shape
assert "video_features" in obs
video_features = obs["video_features"]
assert video_features.shape[0] == batch_size
assert video_features.shape[1] == num_frames
assert video_features.shape[2] == 512 # CLIP embedding dim
# Check state features exist and have correct shape
assert "state_features" in obs
state_features = obs["state_features"]
assert state_features.shape[0] == batch_size
assert state_features.shape[1] == num_frames
assert state_features.shape[2] == config.max_state_dim # Padded to max_state_dim
# Check text features exist and have correct shape
assert "text_features" in obs
text_features = obs["text_features"]
assert text_features.shape[0] == batch_size
assert text_features.shape[1] == 512 # CLIP embedding dim
# Check lengths tensor
assert "lengths" in obs
lengths = obs["lengths"]
assert lengths.shape[0] == batch_size
assert lengths.dtype == torch.int32
# Check sparse_targets exist
assert "sparse_targets" in obs
sparse_targets = obs["sparse_targets"]
assert sparse_targets.shape == (batch_size, num_frames)
# All targets should be in [0, max_stages] range (stage.tau format)
assert (sparse_targets >= 0).all()
# Check dense_targets exist (for dual mode)
assert "dense_targets" in obs
dense_targets = obs["dense_targets"]
assert dense_targets.shape == (batch_size, num_frames)
assert (dense_targets >= 0).all()
def test_call_with_batched_input(self, mock_clip_model):
"""Test processor __call__ with a batched input (multiple frames) in dual mode."""
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
config = MockConfig(
n_obs_steps=8,
max_rewind_steps=4,
frame_gap=30,
rewind_probability=0.0,
language_perturbation_probability=0.0,
annotation_mode="dual",
sparse_subtask_names=["reach", "grasp"],
sparse_temporal_proportions=[0.5, 0.5],
dense_subtask_names=["step1", "step2", "step3"],
dense_temporal_proportions=[0.33, 0.34, 0.33],
)
# Two episodes with different lengths, each with sparse+dense annotations
episodes = [
{
"dataset_from_index": 0,
"dataset_to_index": 200,
"task": "task A",
"sparse_subtask_names": ["reach", "grasp"],
"sparse_subtask_start_frames": [0, 100],
"sparse_subtask_end_frames": [100, 200],
"dense_subtask_names": ["step1", "step2", "step3"],
"dense_subtask_start_frames": [0, 66, 133],
"dense_subtask_end_frames": [66, 133, 200],
},
{
"dataset_from_index": 200,
"dataset_to_index": 500,
"task": "task B",
"sparse_subtask_names": ["reach", "grasp"],
"sparse_subtask_start_frames": [200, 350],
"sparse_subtask_end_frames": [350, 500],
"dense_subtask_names": ["step1", "step2", "step3"],
"dense_subtask_start_frames": [200, 300, 400],
"dense_subtask_end_frames": [300, 400, 500],
},
]
dataset_meta = MockDatasetMeta(episodes)
processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
processor.train(True)
batch_size = 2
num_frames = config.num_frames
# Image: (B, T, C, H, W) format
dummy_image = np.random.rand(batch_size, num_frames, 3, 224, 224).astype(np.float32)
dummy_state = np.random.rand(batch_size, num_frames, 6).astype(np.float32)
transition = {
TransitionKey.OBSERVATION: {
config.image_key: dummy_image,
config.state_key: dummy_state,
},
TransitionKey.COMPLEMENTARY_DATA: {
"index": np.array([100, 350]), # One frame from each episode
"episode_index": np.array([0, 1]),
"task": ["task A", "task B"],
},
}
result = processor(transition)
obs = result[TransitionKey.OBSERVATION]
# Verify batch dimension is preserved for all outputs
assert obs["video_features"].shape[0] == batch_size
assert obs["state_features"].shape[0] == batch_size
assert obs["lengths"].shape[0] == batch_size
assert obs["sparse_targets"].shape[0] == batch_size
assert obs["dense_targets"].shape[0] == batch_size # Dual mode has dense targets
def test_targets_increase_with_progress(self, mock_clip_model):
"""Test that both sparse and dense targets increase as frame index progresses."""
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
config = MockConfig(
n_obs_steps=8,
max_rewind_steps=4,
frame_gap=30,
rewind_probability=0.0,
language_perturbation_probability=0.0,
annotation_mode="dual",
sparse_subtask_names=["phase1", "phase2"],
sparse_temporal_proportions=[0.5, 0.5],
dense_subtask_names=["a", "b", "c", "d"],
dense_temporal_proportions=[0.25, 0.25, 0.25, 0.25],
)
episodes = [
{
"dataset_from_index": 0,
"dataset_to_index": 300,
"task": "test task",
"sparse_subtask_names": ["phase1", "phase2"],
"sparse_subtask_start_frames": [0, 150],
"sparse_subtask_end_frames": [150, 300],
"dense_subtask_names": ["a", "b", "c", "d"],
"dense_subtask_start_frames": [0, 75, 150, 225],
"dense_subtask_end_frames": [75, 150, 225, 300],
}
]
dataset_meta = MockDatasetMeta(episodes)
processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
processor.train(True)
num_frames = config.num_frames
# Test at early, middle, and late points in episode
frame_indices = [30, 150, 270]
sparse_center_targets = []
dense_center_targets = []
for frame_idx in frame_indices:
dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
dummy_state = np.random.rand(num_frames, 6).astype(np.float32)
transition = {
TransitionKey.OBSERVATION: {
config.image_key: dummy_image,
config.state_key: dummy_state,
},
TransitionKey.COMPLEMENTARY_DATA: {
"index": frame_idx,
"episode_index": 0,
"task": "test task",
},
}
result = processor(transition)
obs = result[TransitionKey.OBSERVATION]
# Get target at center frame (index 4 in 9-frame observation window)
sparse_center_targets.append(obs["sparse_targets"][0, 4].item())
dense_center_targets.append(obs["dense_targets"][0, 4].item())
# Both sparse and dense targets should increase with frame index
assert sparse_center_targets[0] < sparse_center_targets[2], (
f"Early sparse target ({sparse_center_targets[0]}) should be < late ({sparse_center_targets[2]})"
)
assert dense_center_targets[0] < dense_center_targets[2], (
f"Early dense target ({dense_center_targets[0]}) should be < late ({dense_center_targets[2]})"
)
def test_progress_labels_exact_values(self, mock_clip_model):
"""Test that progress labels (stage.tau) are computed correctly for known positions."""
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
# Simple setup: 2 sparse stages, 4 dense stages, 100 frame episode
config = MockConfig(
n_obs_steps=8,
max_rewind_steps=4,
frame_gap=10, # Smaller gap for easier calculation
rewind_probability=0.0,
language_perturbation_probability=0.0,
annotation_mode="dual",
sparse_subtask_names=["A", "B"],
sparse_temporal_proportions=[0.5, 0.5],
dense_subtask_names=["d1", "d2", "d3", "d4"],
dense_temporal_proportions=[0.25, 0.25, 0.25, 0.25],
)
# Episode: frames 0-99, sparse stages at [0-49], [50-99]
# Dense stages at [0-24], [25-49], [50-74], [75-99]
episodes = [
{
"dataset_from_index": 0,
"dataset_to_index": 100,
"task": "test",
"sparse_subtask_names": ["A", "B"],
"sparse_subtask_start_frames": [0, 50],
"sparse_subtask_end_frames": [50, 100],
"dense_subtask_names": ["d1", "d2", "d3", "d4"],
"dense_subtask_start_frames": [0, 25, 50, 75],
"dense_subtask_end_frames": [25, 50, 75, 100],
}
]
dataset_meta = MockDatasetMeta(episodes)
processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
processor.train(True)
num_frames = config.num_frames
# Test at frame 50 (center of episode)
# With frame_gap=10, n_obs_steps=8:
# obs indices around frame 50: [10, 20, 30, 40, 50, 60, 70, 80, 90] (9 frames)
dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
dummy_state = np.random.rand(num_frames, 6).astype(np.float32)
transition = {
TransitionKey.OBSERVATION: {
config.image_key: dummy_image,
config.state_key: dummy_state,
},
TransitionKey.COMPLEMENTARY_DATA: {
"index": 50,
"episode_index": 0,
"task": "test",
},
}
result = processor(transition)
obs = result[TransitionKey.OBSERVATION]
sparse_targets = obs["sparse_targets"][0] # (13,)
dense_targets = obs["dense_targets"][0] # (13,)
# First 9 frames are observation frames, last 4 are rewind placeholders (zeros when no rewind)
# Check that obs frames have non-zero targets
obs_sparse = sparse_targets[:9]
obs_dense = dense_targets[:9]
# Verify targets are monotonically increasing for observation frames
for i in range(1, 9):
assert obs_sparse[i] >= obs_sparse[i - 1], (
f"Sparse targets should be monotonic: {obs_sparse[i - 1].item():.3f} -> {obs_sparse[i].item():.3f}"
)
assert obs_dense[i] >= obs_dense[i - 1], (
f"Dense targets should be monotonic: {obs_dense[i - 1].item():.3f} -> {obs_dense[i].item():.3f}"
)
# Rewind slots should be zero when rewind is disabled
rewind_targets = sparse_targets[9:]
assert (rewind_targets == 0).all(), "Rewind slots should be zero when rewind is disabled"
# Check stage transitions: frame 50 is at boundary of sparse stage A->B
# Center frame (index 4) corresponds to actual frame 50
center_sparse = obs_sparse[4].item()
# At frame 50, sparse stage B starts, so target should be ~1.0 (stage 1 + tau 0)
assert 0.9 <= center_sparse <= 1.1, (
f"At sparse boundary, target should be ~1.0, got {center_sparse:.3f}"
)
def test_rewind_augmentation_applied(self, mock_clip_model):
"""Test that rewind augmentation correctly extends sequence and generates targets."""
import random
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
config = MockConfig(
n_obs_steps=8,
max_rewind_steps=4,
frame_gap=10,
rewind_probability=1.0, # Always apply rewind
language_perturbation_probability=0.0,
annotation_mode="dual",
sparse_subtask_names=["A", "B"],
sparse_temporal_proportions=[0.5, 0.5],
dense_subtask_names=["d1", "d2"],
dense_temporal_proportions=[0.5, 0.5],
)
episodes = [
{
"dataset_from_index": 0,
"dataset_to_index": 200,
"task": "test",
"sparse_subtask_names": ["A", "B"],
"sparse_subtask_start_frames": [0, 100],
"sparse_subtask_end_frames": [100, 200],
"dense_subtask_names": ["d1", "d2"],
"dense_subtask_start_frames": [0, 100],
"dense_subtask_end_frames": [100, 200],
}
]
dataset_meta = MockDatasetMeta(episodes)
processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
processor.train(True)
num_frames = config.num_frames # 13
# Test at frame 150 (center of bidirectional window)
# With n_obs_steps=8, half_steps=4, frame_gap=10:
# - Earliest obs frame = 150 - 4*10 = 110
# - Rewind can go back from 110 to frames like 100, 90, 80, 70
# - History available = 110 - 0 = 110, so max rewind = 110/10 = 11 (capped at 4)
dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
dummy_state = np.random.rand(num_frames, 6).astype(np.float32)
transition = {
TransitionKey.OBSERVATION: {
config.image_key: dummy_image,
config.state_key: dummy_state,
},
TransitionKey.COMPLEMENTARY_DATA: {
"index": 150,
"episode_index": 0,
"task": "test",
},
}
# Seed random for reproducibility
random.seed(42)
result = processor(transition)
obs = result[TransitionKey.OBSERVATION]
lengths = obs["lengths"][0].item()
sparse_targets = obs["sparse_targets"][0]
# With rewind_probability=1.0 and enough history, lengths should be > 9 (9 obs + some rewind)
assert lengths > 9, f"With rewind enabled, lengths should be > 9, got {lengths}"
assert lengths <= num_frames, f"Lengths should not exceed total frames {num_frames}, got {lengths}"
# Rewind targets should be non-zero for frames within valid length
n_obs_frames = 9
rewind_count = lengths - n_obs_frames
if rewind_count > 0:
# Check that rewind frames have targets
rewind_targets = sparse_targets[n_obs_frames : n_obs_frames + rewind_count]
# Rewind frames are from BEFORE the earliest obs frame (110)
# These frames (100, 90, 80, 70) are earlier in the episode
earliest_obs_target = sparse_targets[0].item() # Frame 110
# Rewind targets should be less than earliest obs (they're from earlier frames)
for i, rt in enumerate(rewind_targets):
assert rt.item() < earliest_obs_target, (
f"Rewind target {i} ({rt.item():.3f}) should be < earliest obs ({earliest_obs_target:.3f})"
)
# Rewind targets should be decreasing (going further back in time)
for i in range(1, len(rewind_targets)):
assert rewind_targets[i] <= rewind_targets[i - 1], (
f"Rewind targets should decrease: {rewind_targets[i - 1].item():.3f} -> {rewind_targets[i].item():.3f}"
)
def test_full_sequence_target_consistency(self, mock_clip_model):
"""Test that the full sequence of targets is consistent with frame positions."""
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
from lerobot.policies.sarm.sarm_utils import find_stage_and_tau
config = MockConfig(
n_obs_steps=8,
max_rewind_steps=4,
frame_gap=10,
rewind_probability=0.0,
language_perturbation_probability=0.0,
annotation_mode="dual",
sparse_subtask_names=["s1", "s2", "s3"],
sparse_temporal_proportions=[0.33, 0.34, 0.33],
dense_subtask_names=["d1", "d2"],
dense_temporal_proportions=[0.5, 0.5],
)
# 3 sparse stages: [0-33), [33-66), [66-99]
# 2 dense stages: [0-50), [50-100)
episodes = [
{
"dataset_from_index": 0,
"dataset_to_index": 100,
"task": "test",
"sparse_subtask_names": ["s1", "s2", "s3"],
"sparse_subtask_start_frames": [0, 33, 66],
"sparse_subtask_end_frames": [33, 66, 100],
"dense_subtask_names": ["d1", "d2"],
"dense_subtask_start_frames": [0, 50],
"dense_subtask_end_frames": [50, 100],
}
]
dataset_meta = MockDatasetMeta(episodes)
processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
processor.train(True)
num_frames = config.num_frames
# Test at frame 50 (middle of episode)
dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
dummy_state = np.random.rand(num_frames, 6).astype(np.float32)
transition = {
TransitionKey.OBSERVATION: {
config.image_key: dummy_image,
config.state_key: dummy_state,
},
TransitionKey.COMPLEMENTARY_DATA: {
"index": 50,
"episode_index": 0,
"task": "test",
},
}
result = processor(transition)
obs = result[TransitionKey.OBSERVATION]
sparse_targets = obs["sparse_targets"][0]
dense_targets = obs["dense_targets"][0]
# Manually compute expected targets for observation frames
# With frame_gap=10, n_obs_steps=8, center at 50:
# obs frames: [10, 20, 30, 40, 50, 60, 70, 80, 90]
expected_obs_frames = [10, 20, 30, 40, 50, 60, 70, 80, 90]
sparse_names = ["s1", "s2", "s3"]
sparse_starts = [0, 33, 66]
sparse_ends = [33, 66, 100]
sparse_props = {"s1": 0.33, "s2": 0.34, "s3": 0.33}
dense_names = ["d1", "d2"]
dense_starts = [0, 50]
dense_ends = [50, 100]
dense_props = {"d1": 0.5, "d2": 0.5}
for i, frame in enumerate(expected_obs_frames):
expected_sparse = find_stage_and_tau(
frame,
100,
sparse_names,
sparse_starts,
sparse_ends,
sparse_names,
sparse_props,
return_combined=True,
)
expected_dense = find_stage_and_tau(
frame,
100,
dense_names,
dense_starts,
dense_ends,
dense_names,
dense_props,
return_combined=True,
)
actual_sparse = sparse_targets[i].item()
actual_dense = dense_targets[i].item()
assert abs(actual_sparse - expected_sparse) < 0.01, (
f"Frame {frame}: sparse mismatch {actual_sparse:.3f} vs expected {expected_sparse:.3f}"
)
assert abs(actual_dense - expected_dense) < 0.01, (
f"Frame {frame}: dense mismatch {actual_dense:.3f} vs expected {expected_dense:.3f}"
)
-615
View File
@@ -1,615 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
import torch
from lerobot.policies.sarm.sarm_utils import (
apply_rewind_augmentation,
compute_absolute_indices,
compute_tau,
find_stage_and_tau,
normalize_stage_tau,
temporal_proportions_to_breakpoints,
)
class TestProgressLabelsWithModes:
"""End-to-end tests for progress label generation in different modes."""
def test_sparse_mode_single_stage(self):
"""Sparse mode with single stage should give linear progress."""
episode_length = 300
global_names = ["task"]
proportions = {"task": 1.0}
# Test at various frames
for frame in [0, 100, 200, 299]:
stage, tau = find_stage_and_tau(
frame, episode_length, None, None, None, global_names, proportions
)
expected_tau = frame / (episode_length - 1)
assert stage == 0
assert abs(tau - expected_tau) < 1e-5
def test_sparse_mode_multi_stage(self):
"""Sparse mode with multiple stages."""
global_names = ["reach", "grasp", "lift", "place"]
proportions = {"reach": 0.2, "grasp": 0.2, "lift": 0.3, "place": 0.3}
subtask_names = ["reach", "grasp", "lift", "place"]
subtask_starts = [0, 60, 120, 210]
subtask_ends = [59, 119, 209, 299]
# Check stages are correctly identified
stage_at_30, _ = find_stage_and_tau(
30, 300, subtask_names, subtask_starts, subtask_ends, global_names, proportions
)
assert stage_at_30 == 0
stage_at_90, _ = find_stage_and_tau(
90, 300, subtask_names, subtask_starts, subtask_ends, global_names, proportions
)
assert stage_at_90 == 1
stage_at_150, _ = find_stage_and_tau(
150, 300, subtask_names, subtask_starts, subtask_ends, global_names, proportions
)
assert stage_at_150 == 2
def test_dense_mode_more_stages(self):
"""Dense mode should work with more fine-grained stages."""
global_names = ["a", "b", "c", "d", "e", "f", "g", "h"]
proportions = dict.fromkeys(global_names, 1 / 8)
subtask_names = global_names
subtask_starts = [i * 50 for i in range(8)]
subtask_ends = [(i + 1) * 50 - 1 for i in range(8)]
# Each stage should occupy 50 frames
for stage_idx in range(8):
mid_frame = stage_idx * 50 + 25
stage, _ = find_stage_and_tau(
mid_frame, 400, subtask_names, subtask_starts, subtask_ends, global_names, proportions
)
assert stage == stage_idx
class TestComputeAbsoluteIndices:
"""Tests for compute_absolute_indices (bidirectional sampling)."""
def test_no_clamping_when_in_middle(self):
"""When frame is in middle of episode, no clamping should occur."""
frame_idx = 300
ep_start = 0
ep_end = 1000
n_obs_steps = 8
frame_gap = 30
indices, out_of_bounds = compute_absolute_indices(frame_idx, ep_start, ep_end, n_obs_steps, frame_gap)
# All should be valid (no out of bounds)
assert out_of_bounds.sum() == 0
# Check bidirectional indices: [-120, -90, -60, -30, 0, 30, 60, 90, 120] from center
half_steps = n_obs_steps // 2
expected = (
[frame_idx - frame_gap * i for i in range(half_steps, 0, -1)]
+ [frame_idx]
+ [frame_idx + frame_gap * i for i in range(1, half_steps + 1)]
)
assert indices.tolist() == expected
# Center frame (index 4) should be frame_idx
assert indices[half_steps] == frame_idx
def test_clamping_at_episode_start(self):
"""Early frames should be clamped to episode start."""
frame_idx = 50 # Not enough history for full past window
ep_start = 0
ep_end = 1000
n_obs_steps = 8
frame_gap = 30
indices, out_of_bounds = compute_absolute_indices(frame_idx, ep_start, ep_end, n_obs_steps, frame_gap)
# Some past frames should be clamped (out_of_bounds = 1)
assert out_of_bounds.sum() > 0
# All indices should be >= ep_start
assert (indices >= ep_start).all()
# Center index should be frame_idx
half_steps = n_obs_steps // 2
assert indices[half_steps] == frame_idx
def test_clamping_at_episode_end(self):
"""Late frames should be clamped to episode end."""
frame_idx = 950 # Not enough future for full window
ep_start = 0
ep_end = 1000
n_obs_steps = 8
frame_gap = 30
indices, out_of_bounds = compute_absolute_indices(frame_idx, ep_start, ep_end, n_obs_steps, frame_gap)
# Some future frames should be clamped
assert out_of_bounds.sum() > 0
# All indices should be < ep_end
assert (indices < ep_end).all()
# Center index should be frame_idx
half_steps = n_obs_steps // 2
assert indices[half_steps] == frame_idx
def test_sequence_is_monotonic(self):
"""Frame indices should be monotonically increasing."""
for frame_idx in [50, 100, 300, 950]:
indices, _ = compute_absolute_indices(frame_idx, 0, 1000, 8, 30)
# Check monotonic (non-decreasing due to clamping)
diffs = indices[1:] - indices[:-1]
assert (diffs >= 0).all(), f"Non-monotonic at frame {frame_idx}"
class TestComputeTau:
"""Tests for compute_tau (within-subtask progress).
Formula: τ_t = (t - s_k) / (e_k - s_k) ∈ [0, 1]
"""
def test_at_start(self):
"""τ should be 0 at subtask start."""
tau = compute_tau(current_frame=10, subtask_start=10, subtask_end=50)
assert tau == 0.0
def test_at_end(self):
"""τ should be 1 at subtask end."""
tau = compute_tau(current_frame=50, subtask_start=10, subtask_end=50)
assert tau == 1.0
def test_at_middle(self):
"""τ should be 0.5 at subtask midpoint."""
tau = compute_tau(current_frame=30, subtask_start=10, subtask_end=50)
assert abs(tau - 0.5) < 1e-6
def test_quarter_progress(self):
"""Test τ at 25% through subtask."""
tau = compute_tau(current_frame=20, subtask_start=0, subtask_end=80)
assert abs(tau - 0.25) < 1e-6
def test_zero_duration_subtask(self):
"""τ should be 1.0 for zero-duration subtask."""
tau = compute_tau(current_frame=10, subtask_start=10, subtask_end=10)
assert tau == 1.0
def test_clamps_below_zero(self):
"""τ should be clamped to 0 if frame is before subtask."""
tau = compute_tau(current_frame=5, subtask_start=10, subtask_end=50)
assert tau == 0.0
def test_clamps_above_one(self):
"""τ should be clamped to 1 if frame is after subtask."""
tau = compute_tau(current_frame=60, subtask_start=10, subtask_end=50)
assert tau == 1.0
def test_float_inputs(self):
"""Test with float frame indices (from interpolation)."""
tau = compute_tau(current_frame=25.5, subtask_start=10.0, subtask_end=50.0)
expected = (25.5 - 10.0) / (50.0 - 10.0)
assert abs(tau - expected) < 1e-6
class TestFindStageAndTau:
"""Tests for find_stage_and_tau logic.
This function is the core of progress label computation. It determines
which stage a frame belongs to and the within-stage progress (tau).
"""
def test_single_stage_mode_linear_progress(self):
"""Single-stage mode should give linear progress from 0 to 1."""
episode_length = 100
# Frame 0 -> tau = 0
stage, tau = find_stage_and_tau(0, episode_length, None, None, None, ["task"], {"task": 1.0})
assert stage == 0
assert abs(tau - 0.0) < 1e-6
# Frame 50 -> tau = 0.505 (50/99)
stage, tau = find_stage_and_tau(50, episode_length, None, None, None, ["task"], {"task": 1.0})
assert stage == 0
assert abs(tau - 50 / 99) < 1e-6
# Frame 99 -> tau = 1.0
stage, tau = find_stage_and_tau(99, episode_length, None, None, None, ["task"], {"task": 1.0})
assert stage == 0
assert abs(tau - 1.0) < 1e-6
def test_multi_stage_within_subtask(self):
"""Test finding stage when frame is within a subtask."""
global_names = ["reach", "grasp", "lift"]
proportions = {"reach": 0.3, "grasp": 0.2, "lift": 0.5}
subtask_names = ["reach", "grasp", "lift"]
subtask_starts = [0, 30, 50]
subtask_ends = [29, 49, 99]
# Frame 15 in "reach" stage (index 0)
stage, tau = find_stage_and_tau(
15, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
)
assert stage == 0
assert abs(tau - 15 / 29) < 1e-6
# Frame 40 in "grasp" stage (index 1)
stage, tau = find_stage_and_tau(
40, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
)
assert stage == 1
# tau = (40 - 30) / (49 - 30) = 10/19
assert abs(tau - 10 / 19) < 1e-6
# Frame 75 in "lift" stage (index 2)
stage, tau = find_stage_and_tau(
75, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
)
assert stage == 2
# tau = (75 - 50) / (99 - 50) = 25/49
assert abs(tau - 25 / 49) < 1e-6
def test_frame_at_subtask_boundaries(self):
"""Test frames exactly at subtask boundaries."""
global_names = ["a", "b"]
proportions = {"a": 0.5, "b": 0.5}
subtask_names = ["a", "b"]
subtask_starts = [0, 50]
subtask_ends = [49, 99]
# Frame at start of first subtask
stage, tau = find_stage_and_tau(
0, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
)
assert stage == 0
assert tau == 0.0
# Frame at end of first subtask
stage, tau = find_stage_and_tau(
49, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
)
assert stage == 0
assert tau == 1.0
# Frame at start of second subtask
stage, tau = find_stage_and_tau(
50, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
)
assert stage == 1
assert tau == 0.0
def test_frame_after_last_subtask(self):
"""Frames after last subtask should return last stage with high tau."""
global_names = ["a", "b"]
proportions = {"a": 0.5, "b": 0.5}
subtask_names = ["a", "b"]
subtask_starts = [0, 30]
subtask_ends = [29, 59]
# Frame 80 is after last subtask
stage, tau = find_stage_and_tau(
80, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
)
assert stage == 1 # Last stage
assert tau == 0.999 # Nearly complete
class TestEndToEndProgressLabeling:
"""End-to-end tests for progress label computation using normalize_stage_tau."""
def test_consistent_semantic_meaning(self):
"""Test that same subtask completion maps to same progress across trajectories.
This is the key semantic property: "end of subtask 1" should always
mean the same progress value regardless of trajectory speed.
"""
proportions = [0.3, 0.5, 0.2]
# Fast trajectory: subtask 1 ends at frame 30 (of 100)
tau_fast = compute_tau(30, 0, 30) # = 1.0
y_fast = normalize_stage_tau(0 + tau_fast, temporal_proportions=proportions)
# Slow trajectory: subtask 1 ends at frame 90 (of 300)
tau_slow = compute_tau(90, 0, 90) # = 1.0
y_slow = normalize_stage_tau(0 + tau_slow, temporal_proportions=proportions)
# Both should map to same progress (0.3 = end of subtask 1)
assert abs(y_fast - y_slow) < 1e-6
assert abs(y_fast - 0.3) < 1e-6
def test_monotonic_within_subtask(self):
"""Test that progress is monotonically increasing within a subtask."""
proportions = [0.4, 0.6]
prev_y = -1
for tau in np.linspace(0, 1, 11):
y = normalize_stage_tau(0 + tau, temporal_proportions=proportions)
assert y > prev_y or (tau == 0 and y == 0)
prev_y = y
def test_continuous_across_subtasks(self):
"""Test that progress is continuous at subtask boundaries."""
proportions = [0.3, 0.5, 0.2]
# End of subtask 0 (stage=0, tau=1.0) -> stage.tau = 1.0
y_end_0 = normalize_stage_tau(0 + 1.0, temporal_proportions=proportions)
# Start of subtask 1 (stage=1, tau=0.0) -> stage.tau = 1.0
y_start_1 = normalize_stage_tau(1 + 0.0, temporal_proportions=proportions)
# Should be equal (P_1 = 0.3)
assert abs(y_end_0 - y_start_1) < 1e-6
# End of subtask 1 (stage=1, tau=1.0) -> stage.tau = 2.0
y_end_1 = normalize_stage_tau(1 + 1.0, temporal_proportions=proportions)
# Start of subtask 2 (stage=2, tau=0.0) -> stage.tau = 2.0
y_start_2 = normalize_stage_tau(2 + 0.0, temporal_proportions=proportions)
# Should be equal (P_2 = 0.8)
assert abs(y_end_1 - y_start_2) < 1e-6
class TestTemporalProportionsToBreakpoints:
"""Tests for temporal_proportions_to_breakpoints.
Converts temporal proportions to cumulative breakpoints for normalization.
Example: [0.3, 0.5, 0.2] -> [0.0, 0.3, 0.8, 1.0]
"""
def test_basic_conversion(self):
"""Test basic conversion from proportions to breakpoints."""
proportions = [0.3, 0.5, 0.2]
breakpoints = temporal_proportions_to_breakpoints(proportions)
assert breakpoints is not None
assert len(breakpoints) == 4
assert breakpoints[0] == 0.0
assert abs(breakpoints[1] - 0.3) < 1e-6
assert abs(breakpoints[2] - 0.8) < 1e-6
assert breakpoints[3] == 1.0
def test_dict_input(self):
"""Test with dict input."""
proportions = {"a": 0.25, "b": 0.25, "c": 0.5}
breakpoints = temporal_proportions_to_breakpoints(proportions)
assert breakpoints is not None
assert len(breakpoints) == 4
assert breakpoints[0] == 0.0
assert breakpoints[-1] == 1.0
def test_dict_with_subtask_names_order(self):
"""Test that subtask_names determines order for dict input."""
proportions = {"c": 0.5, "a": 0.2, "b": 0.3} # Dict order
subtask_names = ["a", "b", "c"] # Different order
breakpoints = temporal_proportions_to_breakpoints(proportions, subtask_names)
# Breakpoints should follow subtask_names order: a=0.2, b=0.3, c=0.5
assert abs(breakpoints[1] - 0.2) < 1e-6 # a
assert abs(breakpoints[2] - 0.5) < 1e-6 # a + b = 0.5
assert breakpoints[3] == 1.0 # a + b + c = 1.0
def test_uniform_proportions(self):
"""Test with uniform proportions."""
proportions = [0.25, 0.25, 0.25, 0.25]
breakpoints = temporal_proportions_to_breakpoints(proportions)
expected = [0.0, 0.25, 0.5, 0.75, 1.0]
for i, (bp, exp) in enumerate(zip(breakpoints, expected, strict=True)):
assert abs(bp - exp) < 1e-6, f"Breakpoint {i} mismatch"
def test_none_input(self):
"""Test that None input returns None."""
result = temporal_proportions_to_breakpoints(None)
assert result is None
def test_normalization(self):
"""Test that non-normalized proportions are normalized."""
# Proportions sum to 2.0, not 1.0
proportions = [0.6, 1.0, 0.4]
breakpoints = temporal_proportions_to_breakpoints(proportions)
# Should be normalized: [0.3, 0.5, 0.2] -> [0, 0.3, 0.8, 1.0]
assert breakpoints[-1] == 1.0
assert abs(breakpoints[1] - 0.3) < 1e-6
class TestNormalizeStageTau:
"""Tests for normalize_stage_tau.
Normalizes stage+tau values to [0, 1] using breakpoints.
"""
def test_linear_fallback(self):
"""Test linear normalization when only num_stages is provided."""
# 4 stages, linear: [0, 0.25, 0.5, 0.75, 1.0]
# Stage 0 start
assert normalize_stage_tau(0.0, num_stages=4) == 0.0
# Stage 0 end / Stage 1 start
assert abs(normalize_stage_tau(1.0, num_stages=4) - 0.25) < 1e-6
# Stage 1 middle
assert abs(normalize_stage_tau(1.5, num_stages=4) - 0.375) < 1e-6
# Stage 3 end
assert normalize_stage_tau(4.0, num_stages=4) == 1.0
def test_with_custom_breakpoints(self):
"""Test with custom breakpoints."""
# Non-linear breakpoints
breakpoints = [0.0, 0.1, 0.5, 1.0] # 3 stages
# Stage 0: maps [0, 1) to [0.0, 0.1)
assert abs(normalize_stage_tau(0.5, breakpoints=breakpoints) - 0.05) < 1e-6
# Stage 1: maps [1, 2) to [0.1, 0.5)
assert abs(normalize_stage_tau(1.5, breakpoints=breakpoints) - 0.3) < 1e-6
# Stage 2: maps [2, 3) to [0.5, 1.0)
assert abs(normalize_stage_tau(2.5, breakpoints=breakpoints) - 0.75) < 1e-6
def test_with_temporal_proportions(self):
"""Test with temporal proportions (auto-computed breakpoints)."""
proportions = {"a": 0.2, "b": 0.3, "c": 0.5}
subtask_names = ["a", "b", "c"]
# Stage 0 end should map to 0.2
result = normalize_stage_tau(1.0, temporal_proportions=proportions, subtask_names=subtask_names)
assert abs(result - 0.2) < 1e-6
# Stage 1 end should map to 0.5
result = normalize_stage_tau(2.0, temporal_proportions=proportions, subtask_names=subtask_names)
assert abs(result - 0.5) < 1e-6
def test_tensor_input(self):
"""Test with tensor input."""
x = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0])
breakpoints = [0.0, 0.3, 0.8, 1.0] # 3 stages
result = normalize_stage_tau(x, breakpoints=breakpoints)
assert isinstance(result, torch.Tensor)
assert result.shape == x.shape
assert abs(result[0].item() - 0.0) < 1e-6
assert abs(result[2].item() - 0.3) < 1e-6 # End of stage 0
assert abs(result[4].item() - 0.8) < 1e-6 # End of stage 1
def test_clamping(self):
"""Test that output is clamped to [0, 1]."""
# Below 0
assert normalize_stage_tau(-0.5, num_stages=4) == 0.0
# Above num_stages
assert normalize_stage_tau(5.0, num_stages=4) == 1.0
def test_batch_tensor(self):
"""Test with batched tensor."""
x = torch.tensor([[0.0, 1.0, 2.0], [0.5, 1.5, 2.5]]) # (2, 3)
result = normalize_stage_tau(x, num_stages=3)
assert result.shape == (2, 3)
assert (result >= 0).all()
assert (result <= 1).all()
def test_requires_one_of_inputs(self):
"""Test that at least one input method is required."""
with pytest.raises(ValueError):
normalize_stage_tau(1.0)
class TestRewindAugmentation:
"""Tests for rewind augmentation logic with bidirectional observation sampling.
Rewind appends frames before the earliest observation frame, going backwards.
With bidirectional sampling centered at frame_idx:
- Earliest obs frame = frame_idx - half_steps * frame_gap
- Rewind goes backwards from that point
"""
def test_rewind_indices_go_backwards_from_earliest_obs(self):
"""Rewind indices should go backwards from earliest observation frame."""
frame_idx = 300 # Center of bidirectional window
ep_start = 0
n_obs_steps = 4 # half_steps = 2
frame_gap = 30
# Earliest obs frame = 300 - 2*30 = 240
# Rewind goes backwards: 210, 180
rewind_step, rewind_indices = apply_rewind_augmentation(
frame_idx,
ep_start,
n_obs_steps=n_obs_steps,
max_rewind_steps=2,
frame_gap=frame_gap,
rewind_step=2,
)
assert rewind_step == 2
assert len(rewind_indices) == 2
# First rewind frame is closest to obs window, second is further back
assert rewind_indices[0] == 210 # 240 - 30
assert rewind_indices[1] == 180 # 240 - 60
assert rewind_indices[0] > rewind_indices[1], "Rewind should be descending"
def test_rewind_goes_backward_through_history(self):
"""Rewind frames should go backward before the observation window."""
frame_idx = 450 # Center of bidirectional window
ep_start = 0
n_obs_steps = 8 # half_steps = 4
frame_gap = 30
# Earliest obs frame = 450 - 4*30 = 330
# Rewind from 330: [300, 270, 240]
rewind_step, rewind_indices = apply_rewind_augmentation(
frame_idx,
ep_start,
n_obs_steps=n_obs_steps,
max_rewind_steps=4,
frame_gap=frame_gap,
rewind_step=3,
)
assert rewind_step == 3
expected = [300, 270, 240] # Going backwards from 330
assert rewind_indices == expected
def test_no_rewind_when_obs_window_at_episode_start(self):
"""No rewind when observation window reaches episode start."""
frame_idx = 120 # Center of window
ep_start = 0
n_obs_steps = 8 # half_steps = 4
frame_gap = 30
# Earliest obs frame = 120 - 4*30 = 0 (at episode start)
rewind_step, rewind_indices = apply_rewind_augmentation(
frame_idx, ep_start, n_obs_steps=n_obs_steps, max_rewind_steps=4, frame_gap=frame_gap
)
# No room for rewind
assert rewind_step == 0
assert rewind_indices == []
def test_rewind_targets_are_decreasing(self):
"""Progress targets for rewind frames should be decreasing."""
# Simulate progress values
obs_progress = [0.1, 0.2, 0.3, 0.4, 0.5] # Forward progress
# Rewind reverses progress
rewind_indices = [4, 3, 2] # Go backwards through indices
rewind_progress = [obs_progress[i] for i in rewind_indices]
# Should be decreasing
for i in range(len(rewind_progress) - 1):
assert rewind_progress[i] > rewind_progress[i + 1]