mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
Reward models refactor (#3142)
* feat(rewards): add RewardModelConfig and PreTrainedRewardModel base classes * refactor(rewards): migrate Classifier from policies/sac/reward_model/ to rewards/classifier/ * refactor(rewards): migrate SARM from policies/sarm/ to rewards/sarm/ * refactor(rewards): add rewards/factory.py and remove reward model code from policies/factory.py * refactor(rewards): update imports and delete old reward model locations * test(rewards): add reward model tests and update existing test imports * fix(rewards): restore full Classifier and SARM implementations * test(rewards): restore missing CUDA and mixed precision classifier processor tests * refactor(lerobot_train.py): remove rabc specific configuration and replace it with a generic samplerweight class in lerobot_train * refactor(lerobot_train.py): add missing sampling weight script * linter + missing files * add testing for sampl weighter * revert some useless changes, improve typing * update docs * add automatic detection of the progress path * remove type exp * improve comment * fix: move rabc.py to rewards/sarm/ and update import paths * refactor(imports): update reward model imports to new module structure * refactor(imports): update reward model imports to reflect new module structure * refactor(imports): conditionally import pandas based on availability * feat(configs): add reward_model field to TrainPipelineConfig and Hub fields to RewardModelConfig * refactor(policies): remove reward model branches from policy factory and __init__ * refactor(rewards): expand __init__ facade and fix SARMConfig __post_init__ crash * feat(train): route reward model training through rewards/factory instead of policies/factory * refactor(train): streamline reward model training logic * fix(rewards): ensure FileNotFoundError is raised for missing config_file * refactor(train): update __get_path_fields__ to include reward_model for config loading * refactor(classifier): remove redundant input normalization in predict_reward method * fix(train): raise ValueError for non-trainable reward models in train function * refactor(pretrained_rm): add model card template * refactor(tests): reward models * refactor(sarm): update reset method and remove unused action prediction methods * refactor(wandb): differentiate tags for reward model and policy training in cfg_to_group function * fix(train): raise ValueError for PEFT usage in reward model training * refactor(rewards): enhance RewardModelConfig with device handling and delta indices properties --------- Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
This commit is contained in:
@@ -1,153 +0,0 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput
|
||||
from lerobot.utils.constants import OBS_IMAGE, REWARD
|
||||
from tests.utils import skip_if_package_missing
|
||||
|
||||
|
||||
def test_classifier_output():
|
||||
output = ClassifierOutput(
|
||||
logits=torch.tensor([1, 2, 3]),
|
||||
probabilities=torch.tensor([0.1, 0.2, 0.3]),
|
||||
hidden_states=None,
|
||||
)
|
||||
|
||||
assert (
|
||||
f"{output}"
|
||||
== "ClassifierOutput(logits=tensor([1, 2, 3]), probabilities=tensor([0.1000, 0.2000, 0.3000]), hidden_states=None)"
|
||||
)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
@pytest.mark.skip(
|
||||
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
||||
)
|
||||
def test_binary_classifier_with_default_params():
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
|
||||
config = RewardClassifierConfig()
|
||||
config.input_features = {
|
||||
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
|
||||
}
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"REWARD": NormalizationMode.IDENTITY,
|
||||
}
|
||||
config.num_cameras = 1
|
||||
classifier = Classifier(config)
|
||||
|
||||
batch_size = 10
|
||||
|
||||
input = {
|
||||
OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)),
|
||||
REWARD: torch.randint(low=0, high=2, size=(batch_size,)).float(),
|
||||
}
|
||||
|
||||
images, labels = classifier.extract_images_and_labels(input)
|
||||
assert len(images) == 1
|
||||
assert images[0].shape == torch.Size([batch_size, 3, 128, 128])
|
||||
assert labels.shape == torch.Size([batch_size])
|
||||
|
||||
output = classifier.predict(images)
|
||||
|
||||
assert output is not None
|
||||
assert output.logits.size() == torch.Size([batch_size])
|
||||
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
|
||||
assert output.probabilities.shape == torch.Size([batch_size])
|
||||
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
|
||||
assert output.hidden_states.shape == torch.Size([batch_size, 256])
|
||||
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
@pytest.mark.skip(
|
||||
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
||||
)
|
||||
def test_multiclass_classifier():
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
|
||||
num_classes = 5
|
||||
config = RewardClassifierConfig()
|
||||
config.input_features = {
|
||||
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)),
|
||||
}
|
||||
config.num_cameras = 1
|
||||
config.num_classes = num_classes
|
||||
classifier = Classifier(config)
|
||||
|
||||
batch_size = 10
|
||||
|
||||
input = {
|
||||
OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)),
|
||||
REWARD: torch.rand((batch_size, num_classes)),
|
||||
}
|
||||
|
||||
images, labels = classifier.extract_images_and_labels(input)
|
||||
assert len(images) == 1
|
||||
assert images[0].shape == torch.Size([batch_size, 3, 128, 128])
|
||||
assert labels.shape == torch.Size([batch_size, num_classes])
|
||||
|
||||
output = classifier.predict(images)
|
||||
|
||||
assert output is not None
|
||||
assert output.logits.shape == torch.Size([batch_size, num_classes])
|
||||
assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
|
||||
assert output.probabilities.shape == torch.Size([batch_size, num_classes])
|
||||
assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
|
||||
assert output.hidden_states.shape == torch.Size([batch_size, 256])
|
||||
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
@pytest.mark.skip(
|
||||
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
||||
)
|
||||
def test_default_device():
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
|
||||
config = RewardClassifierConfig()
|
||||
assert config.device == "cpu"
|
||||
|
||||
classifier = Classifier(config)
|
||||
for p in classifier.parameters():
|
||||
assert p.device == torch.device("cpu")
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
@pytest.mark.skip(
|
||||
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
||||
)
|
||||
def test_explicit_device_setup():
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
|
||||
config = RewardClassifierConfig(device="cpu")
|
||||
assert config.device == "cpu"
|
||||
|
||||
classifier = Classifier(config)
|
||||
for p in classifier.parameters():
|
||||
assert p.device == torch.device("cpu")
|
||||
@@ -1,694 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("faker")
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.types import TransitionKey
|
||||
|
||||
|
||||
class MockDatasetMeta:
|
||||
"""Mock dataset metadata for testing processor."""
|
||||
|
||||
def __init__(self, episodes: list[dict]):
|
||||
self._episodes = episodes
|
||||
|
||||
@property
|
||||
def episodes(self):
|
||||
"""Return episodes as a mock object with to_pandas() method."""
|
||||
mock = MagicMock()
|
||||
mock.__len__ = lambda s: len(self._episodes)
|
||||
mock.__getitem__ = lambda s, idx: self._episodes[idx]
|
||||
mock.to_pandas = lambda: pd.DataFrame(self._episodes)
|
||||
return mock
|
||||
|
||||
|
||||
class MockConfig:
|
||||
"""Mock SARMConfig for testing processor methods."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_obs_steps: int = 8,
|
||||
max_rewind_steps: int = 4,
|
||||
frame_gap: int = 30,
|
||||
sparse_subtask_names: list = None,
|
||||
sparse_temporal_proportions: list = None,
|
||||
dense_subtask_names: list = None,
|
||||
dense_temporal_proportions: list = None,
|
||||
image_key: str = "observation.images.top",
|
||||
state_key: str = "observation.state",
|
||||
max_state_dim: int = 32,
|
||||
device: str = None,
|
||||
rewind_probability: float = 0.8,
|
||||
language_perturbation_probability: float = 0.2,
|
||||
annotation_mode: str = "dual",
|
||||
clip_batch_size: int = 64,
|
||||
text_dim: int = 512,
|
||||
):
|
||||
self.n_obs_steps = n_obs_steps
|
||||
self.max_rewind_steps = max_rewind_steps
|
||||
self.frame_gap = frame_gap
|
||||
self.sparse_subtask_names = sparse_subtask_names or ["task"]
|
||||
self.sparse_temporal_proportions = sparse_temporal_proportions or [1.0]
|
||||
self.dense_subtask_names = dense_subtask_names
|
||||
self.dense_temporal_proportions = dense_temporal_proportions
|
||||
self.uses_dual_heads = annotation_mode in ["dense_only", "dual"]
|
||||
self.image_key = image_key
|
||||
self.state_key = state_key
|
||||
self.max_state_dim = max_state_dim
|
||||
self.device = device
|
||||
self.rewind_probability = rewind_probability
|
||||
self.language_perturbation_probability = language_perturbation_probability
|
||||
self.annotation_mode = annotation_mode
|
||||
self.clip_batch_size = clip_batch_size
|
||||
self.text_dim = text_dim
|
||||
|
||||
# Compute observation delta indices (same as config: bidirectional)
|
||||
half_steps = self.n_obs_steps // 2
|
||||
past_deltas = [-self.frame_gap * i for i in range(half_steps, 0, -1)]
|
||||
future_deltas = [self.frame_gap * i for i in range(1, half_steps + 1)]
|
||||
obs_deltas = past_deltas + [0] + future_deltas
|
||||
rewind_deltas = [-self.frame_gap * (i + 1) for i in range(self.max_rewind_steps)]
|
||||
self.observation_delta_indices = obs_deltas + rewind_deltas
|
||||
|
||||
@property
|
||||
def num_frames(self) -> int:
|
||||
return 1 + self.n_obs_steps + self.max_rewind_steps
|
||||
|
||||
|
||||
class TestSARMEncodingProcessorStepEndToEnd:
|
||||
"""End-to-end test for SARMEncodingProcessorStep with dummy batch data."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_clip_model(self):
|
||||
"""Mock CLIP model to avoid loading real weights."""
|
||||
with (
|
||||
patch("lerobot.policies.sarm.processor_sarm.CLIPModel") as mock_model_cls,
|
||||
patch("lerobot.policies.sarm.processor_sarm.CLIPProcessor") as mock_processor_cls,
|
||||
):
|
||||
# Mock the CLIP model - return embeddings based on input batch size
|
||||
mock_model = MagicMock()
|
||||
|
||||
def get_image_features_side_effect(**kwargs):
|
||||
pixel_values = kwargs.get("pixel_values")
|
||||
batch_size = pixel_values.shape[0] if pixel_values is not None else 1
|
||||
return torch.randn(batch_size, 512)
|
||||
|
||||
mock_model.get_image_features.side_effect = get_image_features_side_effect
|
||||
mock_model.get_text_features.return_value = torch.randn(1, 512)
|
||||
mock_model.to.return_value = mock_model
|
||||
mock_model_cls.from_pretrained.return_value = mock_model
|
||||
|
||||
# Mock the CLIP processor - return tensors based on input images
|
||||
mock_processor = MagicMock()
|
||||
|
||||
def processor_side_effect(images=None, **kwargs):
|
||||
num_images = len(images) if images is not None else 1
|
||||
return {
|
||||
"pixel_values": torch.randn(num_images, 3, 224, 224),
|
||||
}
|
||||
|
||||
mock_processor.side_effect = processor_side_effect
|
||||
# Mock tokenizer for text encoding
|
||||
mock_processor.tokenizer.return_value = {
|
||||
"input_ids": torch.ones(1, 77, dtype=torch.long),
|
||||
"attention_mask": torch.ones(1, 77, dtype=torch.long),
|
||||
}
|
||||
mock_processor_cls.from_pretrained.return_value = mock_processor
|
||||
|
||||
yield mock_model, mock_processor
|
||||
|
||||
@pytest.fixture
|
||||
def processor_with_mocks(self, mock_clip_model):
|
||||
"""Create a processor with mocked CLIP and dataset metadata for dual mode."""
|
||||
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
|
||||
|
||||
# Dual mode config with both sparse and dense annotations
|
||||
config = MockConfig(
|
||||
n_obs_steps=8,
|
||||
max_rewind_steps=4,
|
||||
frame_gap=30,
|
||||
rewind_probability=0.0, # Disable for deterministic test
|
||||
language_perturbation_probability=0.0, # Disable for deterministic test
|
||||
annotation_mode="dual",
|
||||
sparse_subtask_names=["reach", "grasp", "lift"],
|
||||
sparse_temporal_proportions=[0.3, 0.4, 0.3],
|
||||
dense_subtask_names=["approach", "contact", "close_gripper", "lift_up"],
|
||||
dense_temporal_proportions=[0.25, 0.25, 0.25, 0.25],
|
||||
)
|
||||
|
||||
# Create mock dataset metadata with one episode of 300 frames
|
||||
# Include annotation columns for dual mode
|
||||
episodes = [
|
||||
{
|
||||
"dataset_from_index": 0,
|
||||
"dataset_to_index": 300,
|
||||
"task": "pick up the cube",
|
||||
"sparse_subtask_names": ["reach", "grasp", "lift"],
|
||||
"sparse_subtask_start_frames": [0, 90, 210],
|
||||
"sparse_subtask_end_frames": [90, 210, 300],
|
||||
"dense_subtask_names": ["approach", "contact", "close_gripper", "lift_up"],
|
||||
"dense_subtask_start_frames": [0, 75, 150, 225],
|
||||
"dense_subtask_end_frames": [75, 150, 225, 300],
|
||||
}
|
||||
]
|
||||
dataset_meta = MockDatasetMeta(episodes)
|
||||
|
||||
processor = SARMEncodingProcessorStep(
|
||||
config=config,
|
||||
dataset_meta=dataset_meta,
|
||||
)
|
||||
processor.train(True) # Use train() method, not direct assignment
|
||||
|
||||
return processor, config
|
||||
|
||||
def test_call_with_single_frame_batch(self, processor_with_mocks):
|
||||
"""Test processor __call__ with a single-frame batch."""
|
||||
processor, config = processor_with_mocks
|
||||
|
||||
# Create dummy input transition
|
||||
batch_size = 1
|
||||
num_frames = config.num_frames # 13 frames (9 obs + 4 rewind)
|
||||
|
||||
# Image: (T, C, H, W) format as expected by processor
|
||||
dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
|
||||
|
||||
# State: (T, D) format
|
||||
dummy_state = np.random.rand(num_frames, 6).astype(np.float32)
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
config.image_key: dummy_image,
|
||||
config.state_key: dummy_state,
|
||||
},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {
|
||||
"index": 150, # Middle of episode
|
||||
"episode_index": 0,
|
||||
"task": "pick up the cube",
|
||||
},
|
||||
}
|
||||
|
||||
# Run processor
|
||||
result = processor(transition)
|
||||
|
||||
# Verify output structure
|
||||
obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check video features exist and have correct shape
|
||||
assert "video_features" in obs
|
||||
video_features = obs["video_features"]
|
||||
assert video_features.shape[0] == batch_size
|
||||
assert video_features.shape[1] == num_frames
|
||||
assert video_features.shape[2] == 512 # CLIP embedding dim
|
||||
|
||||
# Check state features exist and have correct shape
|
||||
assert "state_features" in obs
|
||||
state_features = obs["state_features"]
|
||||
assert state_features.shape[0] == batch_size
|
||||
assert state_features.shape[1] == num_frames
|
||||
assert state_features.shape[2] == config.max_state_dim # Padded to max_state_dim
|
||||
|
||||
# Check text features exist and have correct shape
|
||||
assert "text_features" in obs
|
||||
text_features = obs["text_features"]
|
||||
assert text_features.shape[0] == batch_size
|
||||
assert text_features.shape[1] == 512 # CLIP embedding dim
|
||||
|
||||
# Check lengths tensor
|
||||
assert "lengths" in obs
|
||||
lengths = obs["lengths"]
|
||||
assert lengths.shape[0] == batch_size
|
||||
assert lengths.dtype == torch.int32
|
||||
|
||||
# Check sparse_targets exist
|
||||
assert "sparse_targets" in obs
|
||||
sparse_targets = obs["sparse_targets"]
|
||||
assert sparse_targets.shape == (batch_size, num_frames)
|
||||
# All targets should be in [0, max_stages] range (stage.tau format)
|
||||
assert (sparse_targets >= 0).all()
|
||||
|
||||
# Check dense_targets exist (for dual mode)
|
||||
assert "dense_targets" in obs
|
||||
dense_targets = obs["dense_targets"]
|
||||
assert dense_targets.shape == (batch_size, num_frames)
|
||||
assert (dense_targets >= 0).all()
|
||||
|
||||
def test_call_with_batched_input(self, mock_clip_model):
|
||||
"""Test processor __call__ with a batched input (multiple frames) in dual mode."""
|
||||
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
|
||||
|
||||
config = MockConfig(
|
||||
n_obs_steps=8,
|
||||
max_rewind_steps=4,
|
||||
frame_gap=30,
|
||||
rewind_probability=0.0,
|
||||
language_perturbation_probability=0.0,
|
||||
annotation_mode="dual",
|
||||
sparse_subtask_names=["reach", "grasp"],
|
||||
sparse_temporal_proportions=[0.5, 0.5],
|
||||
dense_subtask_names=["step1", "step2", "step3"],
|
||||
dense_temporal_proportions=[0.33, 0.34, 0.33],
|
||||
)
|
||||
|
||||
# Two episodes with different lengths, each with sparse+dense annotations
|
||||
episodes = [
|
||||
{
|
||||
"dataset_from_index": 0,
|
||||
"dataset_to_index": 200,
|
||||
"task": "task A",
|
||||
"sparse_subtask_names": ["reach", "grasp"],
|
||||
"sparse_subtask_start_frames": [0, 100],
|
||||
"sparse_subtask_end_frames": [100, 200],
|
||||
"dense_subtask_names": ["step1", "step2", "step3"],
|
||||
"dense_subtask_start_frames": [0, 66, 133],
|
||||
"dense_subtask_end_frames": [66, 133, 200],
|
||||
},
|
||||
{
|
||||
"dataset_from_index": 200,
|
||||
"dataset_to_index": 500,
|
||||
"task": "task B",
|
||||
"sparse_subtask_names": ["reach", "grasp"],
|
||||
"sparse_subtask_start_frames": [200, 350],
|
||||
"sparse_subtask_end_frames": [350, 500],
|
||||
"dense_subtask_names": ["step1", "step2", "step3"],
|
||||
"dense_subtask_start_frames": [200, 300, 400],
|
||||
"dense_subtask_end_frames": [300, 400, 500],
|
||||
},
|
||||
]
|
||||
dataset_meta = MockDatasetMeta(episodes)
|
||||
|
||||
processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
|
||||
processor.train(True)
|
||||
|
||||
batch_size = 2
|
||||
num_frames = config.num_frames
|
||||
|
||||
# Image: (B, T, C, H, W) format
|
||||
dummy_image = np.random.rand(batch_size, num_frames, 3, 224, 224).astype(np.float32)
|
||||
dummy_state = np.random.rand(batch_size, num_frames, 6).astype(np.float32)
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
config.image_key: dummy_image,
|
||||
config.state_key: dummy_state,
|
||||
},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {
|
||||
"index": np.array([100, 350]), # One frame from each episode
|
||||
"episode_index": np.array([0, 1]),
|
||||
"task": ["task A", "task B"],
|
||||
},
|
||||
}
|
||||
|
||||
result = processor(transition)
|
||||
obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Verify batch dimension is preserved for all outputs
|
||||
assert obs["video_features"].shape[0] == batch_size
|
||||
assert obs["state_features"].shape[0] == batch_size
|
||||
assert obs["lengths"].shape[0] == batch_size
|
||||
assert obs["sparse_targets"].shape[0] == batch_size
|
||||
assert obs["dense_targets"].shape[0] == batch_size # Dual mode has dense targets
|
||||
|
||||
def test_targets_increase_with_progress(self, mock_clip_model):
|
||||
"""Test that both sparse and dense targets increase as frame index progresses."""
|
||||
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
|
||||
|
||||
config = MockConfig(
|
||||
n_obs_steps=8,
|
||||
max_rewind_steps=4,
|
||||
frame_gap=30,
|
||||
rewind_probability=0.0,
|
||||
language_perturbation_probability=0.0,
|
||||
annotation_mode="dual",
|
||||
sparse_subtask_names=["phase1", "phase2"],
|
||||
sparse_temporal_proportions=[0.5, 0.5],
|
||||
dense_subtask_names=["a", "b", "c", "d"],
|
||||
dense_temporal_proportions=[0.25, 0.25, 0.25, 0.25],
|
||||
)
|
||||
|
||||
episodes = [
|
||||
{
|
||||
"dataset_from_index": 0,
|
||||
"dataset_to_index": 300,
|
||||
"task": "test task",
|
||||
"sparse_subtask_names": ["phase1", "phase2"],
|
||||
"sparse_subtask_start_frames": [0, 150],
|
||||
"sparse_subtask_end_frames": [150, 300],
|
||||
"dense_subtask_names": ["a", "b", "c", "d"],
|
||||
"dense_subtask_start_frames": [0, 75, 150, 225],
|
||||
"dense_subtask_end_frames": [75, 150, 225, 300],
|
||||
}
|
||||
]
|
||||
dataset_meta = MockDatasetMeta(episodes)
|
||||
|
||||
processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
|
||||
processor.train(True)
|
||||
|
||||
num_frames = config.num_frames
|
||||
|
||||
# Test at early, middle, and late points in episode
|
||||
frame_indices = [30, 150, 270]
|
||||
sparse_center_targets = []
|
||||
dense_center_targets = []
|
||||
|
||||
for frame_idx in frame_indices:
|
||||
dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
|
||||
dummy_state = np.random.rand(num_frames, 6).astype(np.float32)
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
config.image_key: dummy_image,
|
||||
config.state_key: dummy_state,
|
||||
},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {
|
||||
"index": frame_idx,
|
||||
"episode_index": 0,
|
||||
"task": "test task",
|
||||
},
|
||||
}
|
||||
|
||||
result = processor(transition)
|
||||
obs = result[TransitionKey.OBSERVATION]
|
||||
# Get target at center frame (index 4 in 9-frame observation window)
|
||||
sparse_center_targets.append(obs["sparse_targets"][0, 4].item())
|
||||
dense_center_targets.append(obs["dense_targets"][0, 4].item())
|
||||
|
||||
# Both sparse and dense targets should increase with frame index
|
||||
assert sparse_center_targets[0] < sparse_center_targets[2], (
|
||||
f"Early sparse target ({sparse_center_targets[0]}) should be < late ({sparse_center_targets[2]})"
|
||||
)
|
||||
assert dense_center_targets[0] < dense_center_targets[2], (
|
||||
f"Early dense target ({dense_center_targets[0]}) should be < late ({dense_center_targets[2]})"
|
||||
)
|
||||
|
||||
def test_progress_labels_exact_values(self, mock_clip_model):
|
||||
"""Test that progress labels (stage.tau) are computed correctly for known positions."""
|
||||
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
|
||||
|
||||
# Simple setup: 2 sparse stages, 4 dense stages, 100 frame episode
|
||||
config = MockConfig(
|
||||
n_obs_steps=8,
|
||||
max_rewind_steps=4,
|
||||
frame_gap=10, # Smaller gap for easier calculation
|
||||
rewind_probability=0.0,
|
||||
language_perturbation_probability=0.0,
|
||||
annotation_mode="dual",
|
||||
sparse_subtask_names=["A", "B"],
|
||||
sparse_temporal_proportions=[0.5, 0.5],
|
||||
dense_subtask_names=["d1", "d2", "d3", "d4"],
|
||||
dense_temporal_proportions=[0.25, 0.25, 0.25, 0.25],
|
||||
)
|
||||
|
||||
# Episode: frames 0-99, sparse stages at [0-49], [50-99]
|
||||
# Dense stages at [0-24], [25-49], [50-74], [75-99]
|
||||
episodes = [
|
||||
{
|
||||
"dataset_from_index": 0,
|
||||
"dataset_to_index": 100,
|
||||
"task": "test",
|
||||
"sparse_subtask_names": ["A", "B"],
|
||||
"sparse_subtask_start_frames": [0, 50],
|
||||
"sparse_subtask_end_frames": [50, 100],
|
||||
"dense_subtask_names": ["d1", "d2", "d3", "d4"],
|
||||
"dense_subtask_start_frames": [0, 25, 50, 75],
|
||||
"dense_subtask_end_frames": [25, 50, 75, 100],
|
||||
}
|
||||
]
|
||||
dataset_meta = MockDatasetMeta(episodes)
|
||||
|
||||
processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
|
||||
processor.train(True)
|
||||
|
||||
num_frames = config.num_frames
|
||||
|
||||
# Test at frame 50 (center of episode)
|
||||
# With frame_gap=10, n_obs_steps=8:
|
||||
# obs indices around frame 50: [10, 20, 30, 40, 50, 60, 70, 80, 90] (9 frames)
|
||||
dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
|
||||
dummy_state = np.random.rand(num_frames, 6).astype(np.float32)
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
config.image_key: dummy_image,
|
||||
config.state_key: dummy_state,
|
||||
},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {
|
||||
"index": 50,
|
||||
"episode_index": 0,
|
||||
"task": "test",
|
||||
},
|
||||
}
|
||||
|
||||
result = processor(transition)
|
||||
obs = result[TransitionKey.OBSERVATION]
|
||||
sparse_targets = obs["sparse_targets"][0] # (13,)
|
||||
dense_targets = obs["dense_targets"][0] # (13,)
|
||||
|
||||
# First 9 frames are observation frames, last 4 are rewind placeholders (zeros when no rewind)
|
||||
# Check that obs frames have non-zero targets
|
||||
obs_sparse = sparse_targets[:9]
|
||||
obs_dense = dense_targets[:9]
|
||||
|
||||
# Verify targets are monotonically increasing for observation frames
|
||||
for i in range(1, 9):
|
||||
assert obs_sparse[i] >= obs_sparse[i - 1], (
|
||||
f"Sparse targets should be monotonic: {obs_sparse[i - 1].item():.3f} -> {obs_sparse[i].item():.3f}"
|
||||
)
|
||||
assert obs_dense[i] >= obs_dense[i - 1], (
|
||||
f"Dense targets should be monotonic: {obs_dense[i - 1].item():.3f} -> {obs_dense[i].item():.3f}"
|
||||
)
|
||||
|
||||
# Rewind slots should be zero when rewind is disabled
|
||||
rewind_targets = sparse_targets[9:]
|
||||
assert (rewind_targets == 0).all(), "Rewind slots should be zero when rewind is disabled"
|
||||
|
||||
# Check stage transitions: frame 50 is at boundary of sparse stage A->B
|
||||
# Center frame (index 4) corresponds to actual frame 50
|
||||
center_sparse = obs_sparse[4].item()
|
||||
# At frame 50, sparse stage B starts, so target should be ~1.0 (stage 1 + tau 0)
|
||||
assert 0.9 <= center_sparse <= 1.1, (
|
||||
f"At sparse boundary, target should be ~1.0, got {center_sparse:.3f}"
|
||||
)
|
||||
|
||||
def test_rewind_augmentation_applied(self, mock_clip_model):
|
||||
"""Test that rewind augmentation correctly extends sequence and generates targets."""
|
||||
import random
|
||||
|
||||
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
|
||||
|
||||
config = MockConfig(
|
||||
n_obs_steps=8,
|
||||
max_rewind_steps=4,
|
||||
frame_gap=10,
|
||||
rewind_probability=1.0, # Always apply rewind
|
||||
language_perturbation_probability=0.0,
|
||||
annotation_mode="dual",
|
||||
sparse_subtask_names=["A", "B"],
|
||||
sparse_temporal_proportions=[0.5, 0.5],
|
||||
dense_subtask_names=["d1", "d2"],
|
||||
dense_temporal_proportions=[0.5, 0.5],
|
||||
)
|
||||
|
||||
episodes = [
|
||||
{
|
||||
"dataset_from_index": 0,
|
||||
"dataset_to_index": 200,
|
||||
"task": "test",
|
||||
"sparse_subtask_names": ["A", "B"],
|
||||
"sparse_subtask_start_frames": [0, 100],
|
||||
"sparse_subtask_end_frames": [100, 200],
|
||||
"dense_subtask_names": ["d1", "d2"],
|
||||
"dense_subtask_start_frames": [0, 100],
|
||||
"dense_subtask_end_frames": [100, 200],
|
||||
}
|
||||
]
|
||||
dataset_meta = MockDatasetMeta(episodes)
|
||||
|
||||
processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
|
||||
processor.train(True)
|
||||
|
||||
num_frames = config.num_frames # 13
|
||||
|
||||
# Test at frame 150 (center of bidirectional window)
|
||||
# With n_obs_steps=8, half_steps=4, frame_gap=10:
|
||||
# - Earliest obs frame = 150 - 4*10 = 110
|
||||
# - Rewind can go back from 110 to frames like 100, 90, 80, 70
|
||||
# - History available = 110 - 0 = 110, so max rewind = 110/10 = 11 (capped at 4)
|
||||
dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
|
||||
dummy_state = np.random.rand(num_frames, 6).astype(np.float32)
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
config.image_key: dummy_image,
|
||||
config.state_key: dummy_state,
|
||||
},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {
|
||||
"index": 150,
|
||||
"episode_index": 0,
|
||||
"task": "test",
|
||||
},
|
||||
}
|
||||
|
||||
# Seed random for reproducibility
|
||||
random.seed(42)
|
||||
result = processor(transition)
|
||||
obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
lengths = obs["lengths"][0].item()
|
||||
sparse_targets = obs["sparse_targets"][0]
|
||||
|
||||
# With rewind_probability=1.0 and enough history, lengths should be > 9 (9 obs + some rewind)
|
||||
assert lengths > 9, f"With rewind enabled, lengths should be > 9, got {lengths}"
|
||||
assert lengths <= num_frames, f"Lengths should not exceed total frames {num_frames}, got {lengths}"
|
||||
|
||||
# Rewind targets should be non-zero for frames within valid length
|
||||
n_obs_frames = 9
|
||||
rewind_count = lengths - n_obs_frames
|
||||
|
||||
if rewind_count > 0:
|
||||
# Check that rewind frames have targets
|
||||
rewind_targets = sparse_targets[n_obs_frames : n_obs_frames + rewind_count]
|
||||
# Rewind frames are from BEFORE the earliest obs frame (110)
|
||||
# These frames (100, 90, 80, 70) are earlier in the episode
|
||||
earliest_obs_target = sparse_targets[0].item() # Frame 110
|
||||
|
||||
# Rewind targets should be less than earliest obs (they're from earlier frames)
|
||||
for i, rt in enumerate(rewind_targets):
|
||||
assert rt.item() < earliest_obs_target, (
|
||||
f"Rewind target {i} ({rt.item():.3f}) should be < earliest obs ({earliest_obs_target:.3f})"
|
||||
)
|
||||
|
||||
# Rewind targets should be decreasing (going further back in time)
|
||||
for i in range(1, len(rewind_targets)):
|
||||
assert rewind_targets[i] <= rewind_targets[i - 1], (
|
||||
f"Rewind targets should decrease: {rewind_targets[i - 1].item():.3f} -> {rewind_targets[i].item():.3f}"
|
||||
)
|
||||
|
||||
def test_full_sequence_target_consistency(self, mock_clip_model):
|
||||
"""Test that the full sequence of targets is consistent with frame positions."""
|
||||
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
|
||||
from lerobot.policies.sarm.sarm_utils import find_stage_and_tau
|
||||
|
||||
config = MockConfig(
|
||||
n_obs_steps=8,
|
||||
max_rewind_steps=4,
|
||||
frame_gap=10,
|
||||
rewind_probability=0.0,
|
||||
language_perturbation_probability=0.0,
|
||||
annotation_mode="dual",
|
||||
sparse_subtask_names=["s1", "s2", "s3"],
|
||||
sparse_temporal_proportions=[0.33, 0.34, 0.33],
|
||||
dense_subtask_names=["d1", "d2"],
|
||||
dense_temporal_proportions=[0.5, 0.5],
|
||||
)
|
||||
|
||||
# 3 sparse stages: [0-33), [33-66), [66-99]
|
||||
# 2 dense stages: [0-50), [50-100)
|
||||
episodes = [
|
||||
{
|
||||
"dataset_from_index": 0,
|
||||
"dataset_to_index": 100,
|
||||
"task": "test",
|
||||
"sparse_subtask_names": ["s1", "s2", "s3"],
|
||||
"sparse_subtask_start_frames": [0, 33, 66],
|
||||
"sparse_subtask_end_frames": [33, 66, 100],
|
||||
"dense_subtask_names": ["d1", "d2"],
|
||||
"dense_subtask_start_frames": [0, 50],
|
||||
"dense_subtask_end_frames": [50, 100],
|
||||
}
|
||||
]
|
||||
dataset_meta = MockDatasetMeta(episodes)
|
||||
|
||||
processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
|
||||
processor.train(True)
|
||||
|
||||
num_frames = config.num_frames
|
||||
|
||||
# Test at frame 50 (middle of episode)
|
||||
dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
|
||||
dummy_state = np.random.rand(num_frames, 6).astype(np.float32)
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
config.image_key: dummy_image,
|
||||
config.state_key: dummy_state,
|
||||
},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {
|
||||
"index": 50,
|
||||
"episode_index": 0,
|
||||
"task": "test",
|
||||
},
|
||||
}
|
||||
|
||||
result = processor(transition)
|
||||
obs = result[TransitionKey.OBSERVATION]
|
||||
sparse_targets = obs["sparse_targets"][0]
|
||||
dense_targets = obs["dense_targets"][0]
|
||||
|
||||
# Manually compute expected targets for observation frames
|
||||
# With frame_gap=10, n_obs_steps=8, center at 50:
|
||||
# obs frames: [10, 20, 30, 40, 50, 60, 70, 80, 90]
|
||||
expected_obs_frames = [10, 20, 30, 40, 50, 60, 70, 80, 90]
|
||||
|
||||
sparse_names = ["s1", "s2", "s3"]
|
||||
sparse_starts = [0, 33, 66]
|
||||
sparse_ends = [33, 66, 100]
|
||||
sparse_props = {"s1": 0.33, "s2": 0.34, "s3": 0.33}
|
||||
|
||||
dense_names = ["d1", "d2"]
|
||||
dense_starts = [0, 50]
|
||||
dense_ends = [50, 100]
|
||||
dense_props = {"d1": 0.5, "d2": 0.5}
|
||||
|
||||
for i, frame in enumerate(expected_obs_frames):
|
||||
expected_sparse = find_stage_and_tau(
|
||||
frame,
|
||||
100,
|
||||
sparse_names,
|
||||
sparse_starts,
|
||||
sparse_ends,
|
||||
sparse_names,
|
||||
sparse_props,
|
||||
return_combined=True,
|
||||
)
|
||||
expected_dense = find_stage_and_tau(
|
||||
frame,
|
||||
100,
|
||||
dense_names,
|
||||
dense_starts,
|
||||
dense_ends,
|
||||
dense_names,
|
||||
dense_props,
|
||||
return_combined=True,
|
||||
)
|
||||
|
||||
actual_sparse = sparse_targets[i].item()
|
||||
actual_dense = dense_targets[i].item()
|
||||
|
||||
assert abs(actual_sparse - expected_sparse) < 0.01, (
|
||||
f"Frame {frame}: sparse mismatch {actual_sparse:.3f} vs expected {expected_sparse:.3f}"
|
||||
)
|
||||
assert abs(actual_dense - expected_dense) < 0.01, (
|
||||
f"Frame {frame}: dense mismatch {actual_dense:.3f} vs expected {expected_dense:.3f}"
|
||||
)
|
||||
@@ -1,615 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.policies.sarm.sarm_utils import (
|
||||
apply_rewind_augmentation,
|
||||
compute_absolute_indices,
|
||||
compute_tau,
|
||||
find_stage_and_tau,
|
||||
normalize_stage_tau,
|
||||
temporal_proportions_to_breakpoints,
|
||||
)
|
||||
|
||||
|
||||
class TestProgressLabelsWithModes:
|
||||
"""End-to-end tests for progress label generation in different modes."""
|
||||
|
||||
def test_sparse_mode_single_stage(self):
|
||||
"""Sparse mode with single stage should give linear progress."""
|
||||
episode_length = 300
|
||||
global_names = ["task"]
|
||||
proportions = {"task": 1.0}
|
||||
|
||||
# Test at various frames
|
||||
for frame in [0, 100, 200, 299]:
|
||||
stage, tau = find_stage_and_tau(
|
||||
frame, episode_length, None, None, None, global_names, proportions
|
||||
)
|
||||
|
||||
expected_tau = frame / (episode_length - 1)
|
||||
assert stage == 0
|
||||
assert abs(tau - expected_tau) < 1e-5
|
||||
|
||||
def test_sparse_mode_multi_stage(self):
|
||||
"""Sparse mode with multiple stages."""
|
||||
global_names = ["reach", "grasp", "lift", "place"]
|
||||
proportions = {"reach": 0.2, "grasp": 0.2, "lift": 0.3, "place": 0.3}
|
||||
|
||||
subtask_names = ["reach", "grasp", "lift", "place"]
|
||||
subtask_starts = [0, 60, 120, 210]
|
||||
subtask_ends = [59, 119, 209, 299]
|
||||
|
||||
# Check stages are correctly identified
|
||||
stage_at_30, _ = find_stage_and_tau(
|
||||
30, 300, subtask_names, subtask_starts, subtask_ends, global_names, proportions
|
||||
)
|
||||
assert stage_at_30 == 0
|
||||
|
||||
stage_at_90, _ = find_stage_and_tau(
|
||||
90, 300, subtask_names, subtask_starts, subtask_ends, global_names, proportions
|
||||
)
|
||||
assert stage_at_90 == 1
|
||||
|
||||
stage_at_150, _ = find_stage_and_tau(
|
||||
150, 300, subtask_names, subtask_starts, subtask_ends, global_names, proportions
|
||||
)
|
||||
assert stage_at_150 == 2
|
||||
|
||||
def test_dense_mode_more_stages(self):
|
||||
"""Dense mode should work with more fine-grained stages."""
|
||||
global_names = ["a", "b", "c", "d", "e", "f", "g", "h"]
|
||||
proportions = dict.fromkeys(global_names, 1 / 8)
|
||||
|
||||
subtask_names = global_names
|
||||
subtask_starts = [i * 50 for i in range(8)]
|
||||
subtask_ends = [(i + 1) * 50 - 1 for i in range(8)]
|
||||
|
||||
# Each stage should occupy 50 frames
|
||||
for stage_idx in range(8):
|
||||
mid_frame = stage_idx * 50 + 25
|
||||
stage, _ = find_stage_and_tau(
|
||||
mid_frame, 400, subtask_names, subtask_starts, subtask_ends, global_names, proportions
|
||||
)
|
||||
assert stage == stage_idx
|
||||
|
||||
|
||||
class TestComputeAbsoluteIndices:
|
||||
"""Tests for compute_absolute_indices (bidirectional sampling)."""
|
||||
|
||||
def test_no_clamping_when_in_middle(self):
|
||||
"""When frame is in middle of episode, no clamping should occur."""
|
||||
frame_idx = 300
|
||||
ep_start = 0
|
||||
ep_end = 1000
|
||||
n_obs_steps = 8
|
||||
frame_gap = 30
|
||||
|
||||
indices, out_of_bounds = compute_absolute_indices(frame_idx, ep_start, ep_end, n_obs_steps, frame_gap)
|
||||
|
||||
# All should be valid (no out of bounds)
|
||||
assert out_of_bounds.sum() == 0
|
||||
|
||||
# Check bidirectional indices: [-120, -90, -60, -30, 0, 30, 60, 90, 120] from center
|
||||
half_steps = n_obs_steps // 2
|
||||
expected = (
|
||||
[frame_idx - frame_gap * i for i in range(half_steps, 0, -1)]
|
||||
+ [frame_idx]
|
||||
+ [frame_idx + frame_gap * i for i in range(1, half_steps + 1)]
|
||||
)
|
||||
assert indices.tolist() == expected
|
||||
|
||||
# Center frame (index 4) should be frame_idx
|
||||
assert indices[half_steps] == frame_idx
|
||||
|
||||
def test_clamping_at_episode_start(self):
|
||||
"""Early frames should be clamped to episode start."""
|
||||
frame_idx = 50 # Not enough history for full past window
|
||||
ep_start = 0
|
||||
ep_end = 1000
|
||||
n_obs_steps = 8
|
||||
frame_gap = 30
|
||||
|
||||
indices, out_of_bounds = compute_absolute_indices(frame_idx, ep_start, ep_end, n_obs_steps, frame_gap)
|
||||
|
||||
# Some past frames should be clamped (out_of_bounds = 1)
|
||||
assert out_of_bounds.sum() > 0
|
||||
|
||||
# All indices should be >= ep_start
|
||||
assert (indices >= ep_start).all()
|
||||
|
||||
# Center index should be frame_idx
|
||||
half_steps = n_obs_steps // 2
|
||||
assert indices[half_steps] == frame_idx
|
||||
|
||||
def test_clamping_at_episode_end(self):
|
||||
"""Late frames should be clamped to episode end."""
|
||||
frame_idx = 950 # Not enough future for full window
|
||||
ep_start = 0
|
||||
ep_end = 1000
|
||||
n_obs_steps = 8
|
||||
frame_gap = 30
|
||||
|
||||
indices, out_of_bounds = compute_absolute_indices(frame_idx, ep_start, ep_end, n_obs_steps, frame_gap)
|
||||
|
||||
# Some future frames should be clamped
|
||||
assert out_of_bounds.sum() > 0
|
||||
|
||||
# All indices should be < ep_end
|
||||
assert (indices < ep_end).all()
|
||||
|
||||
# Center index should be frame_idx
|
||||
half_steps = n_obs_steps // 2
|
||||
assert indices[half_steps] == frame_idx
|
||||
|
||||
def test_sequence_is_monotonic(self):
|
||||
"""Frame indices should be monotonically increasing."""
|
||||
for frame_idx in [50, 100, 300, 950]:
|
||||
indices, _ = compute_absolute_indices(frame_idx, 0, 1000, 8, 30)
|
||||
|
||||
# Check monotonic (non-decreasing due to clamping)
|
||||
diffs = indices[1:] - indices[:-1]
|
||||
assert (diffs >= 0).all(), f"Non-monotonic at frame {frame_idx}"
|
||||
|
||||
|
||||
class TestComputeTau:
|
||||
"""Tests for compute_tau (within-subtask progress).
|
||||
|
||||
Formula: τ_t = (t - s_k) / (e_k - s_k) ∈ [0, 1]
|
||||
"""
|
||||
|
||||
def test_at_start(self):
|
||||
"""τ should be 0 at subtask start."""
|
||||
tau = compute_tau(current_frame=10, subtask_start=10, subtask_end=50)
|
||||
assert tau == 0.0
|
||||
|
||||
def test_at_end(self):
|
||||
"""τ should be 1 at subtask end."""
|
||||
tau = compute_tau(current_frame=50, subtask_start=10, subtask_end=50)
|
||||
assert tau == 1.0
|
||||
|
||||
def test_at_middle(self):
|
||||
"""τ should be 0.5 at subtask midpoint."""
|
||||
tau = compute_tau(current_frame=30, subtask_start=10, subtask_end=50)
|
||||
assert abs(tau - 0.5) < 1e-6
|
||||
|
||||
def test_quarter_progress(self):
|
||||
"""Test τ at 25% through subtask."""
|
||||
tau = compute_tau(current_frame=20, subtask_start=0, subtask_end=80)
|
||||
assert abs(tau - 0.25) < 1e-6
|
||||
|
||||
def test_zero_duration_subtask(self):
|
||||
"""τ should be 1.0 for zero-duration subtask."""
|
||||
tau = compute_tau(current_frame=10, subtask_start=10, subtask_end=10)
|
||||
assert tau == 1.0
|
||||
|
||||
def test_clamps_below_zero(self):
|
||||
"""τ should be clamped to 0 if frame is before subtask."""
|
||||
tau = compute_tau(current_frame=5, subtask_start=10, subtask_end=50)
|
||||
assert tau == 0.0
|
||||
|
||||
def test_clamps_above_one(self):
|
||||
"""τ should be clamped to 1 if frame is after subtask."""
|
||||
tau = compute_tau(current_frame=60, subtask_start=10, subtask_end=50)
|
||||
assert tau == 1.0
|
||||
|
||||
def test_float_inputs(self):
|
||||
"""Test with float frame indices (from interpolation)."""
|
||||
tau = compute_tau(current_frame=25.5, subtask_start=10.0, subtask_end=50.0)
|
||||
expected = (25.5 - 10.0) / (50.0 - 10.0)
|
||||
assert abs(tau - expected) < 1e-6
|
||||
|
||||
|
||||
class TestFindStageAndTau:
|
||||
"""Tests for find_stage_and_tau logic.
|
||||
|
||||
This function is the core of progress label computation. It determines
|
||||
which stage a frame belongs to and the within-stage progress (tau).
|
||||
"""
|
||||
|
||||
def test_single_stage_mode_linear_progress(self):
|
||||
"""Single-stage mode should give linear progress from 0 to 1."""
|
||||
episode_length = 100
|
||||
|
||||
# Frame 0 -> tau = 0
|
||||
stage, tau = find_stage_and_tau(0, episode_length, None, None, None, ["task"], {"task": 1.0})
|
||||
assert stage == 0
|
||||
assert abs(tau - 0.0) < 1e-6
|
||||
|
||||
# Frame 50 -> tau = 0.505 (50/99)
|
||||
stage, tau = find_stage_and_tau(50, episode_length, None, None, None, ["task"], {"task": 1.0})
|
||||
assert stage == 0
|
||||
assert abs(tau - 50 / 99) < 1e-6
|
||||
|
||||
# Frame 99 -> tau = 1.0
|
||||
stage, tau = find_stage_and_tau(99, episode_length, None, None, None, ["task"], {"task": 1.0})
|
||||
assert stage == 0
|
||||
assert abs(tau - 1.0) < 1e-6
|
||||
|
||||
def test_multi_stage_within_subtask(self):
|
||||
"""Test finding stage when frame is within a subtask."""
|
||||
global_names = ["reach", "grasp", "lift"]
|
||||
proportions = {"reach": 0.3, "grasp": 0.2, "lift": 0.5}
|
||||
|
||||
subtask_names = ["reach", "grasp", "lift"]
|
||||
subtask_starts = [0, 30, 50]
|
||||
subtask_ends = [29, 49, 99]
|
||||
|
||||
# Frame 15 in "reach" stage (index 0)
|
||||
stage, tau = find_stage_and_tau(
|
||||
15, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
|
||||
)
|
||||
assert stage == 0
|
||||
assert abs(tau - 15 / 29) < 1e-6
|
||||
|
||||
# Frame 40 in "grasp" stage (index 1)
|
||||
stage, tau = find_stage_and_tau(
|
||||
40, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
|
||||
)
|
||||
assert stage == 1
|
||||
# tau = (40 - 30) / (49 - 30) = 10/19
|
||||
assert abs(tau - 10 / 19) < 1e-6
|
||||
|
||||
# Frame 75 in "lift" stage (index 2)
|
||||
stage, tau = find_stage_and_tau(
|
||||
75, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
|
||||
)
|
||||
assert stage == 2
|
||||
# tau = (75 - 50) / (99 - 50) = 25/49
|
||||
assert abs(tau - 25 / 49) < 1e-6
|
||||
|
||||
def test_frame_at_subtask_boundaries(self):
|
||||
"""Test frames exactly at subtask boundaries."""
|
||||
global_names = ["a", "b"]
|
||||
proportions = {"a": 0.5, "b": 0.5}
|
||||
|
||||
subtask_names = ["a", "b"]
|
||||
subtask_starts = [0, 50]
|
||||
subtask_ends = [49, 99]
|
||||
|
||||
# Frame at start of first subtask
|
||||
stage, tau = find_stage_and_tau(
|
||||
0, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
|
||||
)
|
||||
assert stage == 0
|
||||
assert tau == 0.0
|
||||
|
||||
# Frame at end of first subtask
|
||||
stage, tau = find_stage_and_tau(
|
||||
49, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
|
||||
)
|
||||
assert stage == 0
|
||||
assert tau == 1.0
|
||||
|
||||
# Frame at start of second subtask
|
||||
stage, tau = find_stage_and_tau(
|
||||
50, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
|
||||
)
|
||||
assert stage == 1
|
||||
assert tau == 0.0
|
||||
|
||||
def test_frame_after_last_subtask(self):
|
||||
"""Frames after last subtask should return last stage with high tau."""
|
||||
global_names = ["a", "b"]
|
||||
proportions = {"a": 0.5, "b": 0.5}
|
||||
|
||||
subtask_names = ["a", "b"]
|
||||
subtask_starts = [0, 30]
|
||||
subtask_ends = [29, 59]
|
||||
|
||||
# Frame 80 is after last subtask
|
||||
stage, tau = find_stage_and_tau(
|
||||
80, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
|
||||
)
|
||||
assert stage == 1 # Last stage
|
||||
assert tau == 0.999 # Nearly complete
|
||||
|
||||
|
||||
class TestEndToEndProgressLabeling:
|
||||
"""End-to-end tests for progress label computation using normalize_stage_tau."""
|
||||
|
||||
def test_consistent_semantic_meaning(self):
|
||||
"""Test that same subtask completion maps to same progress across trajectories.
|
||||
|
||||
This is the key semantic property: "end of subtask 1" should always
|
||||
mean the same progress value regardless of trajectory speed.
|
||||
"""
|
||||
proportions = [0.3, 0.5, 0.2]
|
||||
|
||||
# Fast trajectory: subtask 1 ends at frame 30 (of 100)
|
||||
tau_fast = compute_tau(30, 0, 30) # = 1.0
|
||||
y_fast = normalize_stage_tau(0 + tau_fast, temporal_proportions=proportions)
|
||||
|
||||
# Slow trajectory: subtask 1 ends at frame 90 (of 300)
|
||||
tau_slow = compute_tau(90, 0, 90) # = 1.0
|
||||
y_slow = normalize_stage_tau(0 + tau_slow, temporal_proportions=proportions)
|
||||
|
||||
# Both should map to same progress (0.3 = end of subtask 1)
|
||||
assert abs(y_fast - y_slow) < 1e-6
|
||||
assert abs(y_fast - 0.3) < 1e-6
|
||||
|
||||
def test_monotonic_within_subtask(self):
|
||||
"""Test that progress is monotonically increasing within a subtask."""
|
||||
proportions = [0.4, 0.6]
|
||||
|
||||
prev_y = -1
|
||||
for tau in np.linspace(0, 1, 11):
|
||||
y = normalize_stage_tau(0 + tau, temporal_proportions=proportions)
|
||||
assert y > prev_y or (tau == 0 and y == 0)
|
||||
prev_y = y
|
||||
|
||||
def test_continuous_across_subtasks(self):
|
||||
"""Test that progress is continuous at subtask boundaries."""
|
||||
proportions = [0.3, 0.5, 0.2]
|
||||
|
||||
# End of subtask 0 (stage=0, tau=1.0) -> stage.tau = 1.0
|
||||
y_end_0 = normalize_stage_tau(0 + 1.0, temporal_proportions=proportions)
|
||||
|
||||
# Start of subtask 1 (stage=1, tau=0.0) -> stage.tau = 1.0
|
||||
y_start_1 = normalize_stage_tau(1 + 0.0, temporal_proportions=proportions)
|
||||
|
||||
# Should be equal (P_1 = 0.3)
|
||||
assert abs(y_end_0 - y_start_1) < 1e-6
|
||||
|
||||
# End of subtask 1 (stage=1, tau=1.0) -> stage.tau = 2.0
|
||||
y_end_1 = normalize_stage_tau(1 + 1.0, temporal_proportions=proportions)
|
||||
|
||||
# Start of subtask 2 (stage=2, tau=0.0) -> stage.tau = 2.0
|
||||
y_start_2 = normalize_stage_tau(2 + 0.0, temporal_proportions=proportions)
|
||||
|
||||
# Should be equal (P_2 = 0.8)
|
||||
assert abs(y_end_1 - y_start_2) < 1e-6
|
||||
|
||||
|
||||
class TestTemporalProportionsToBreakpoints:
|
||||
"""Tests for temporal_proportions_to_breakpoints.
|
||||
|
||||
Converts temporal proportions to cumulative breakpoints for normalization.
|
||||
Example: [0.3, 0.5, 0.2] -> [0.0, 0.3, 0.8, 1.0]
|
||||
"""
|
||||
|
||||
def test_basic_conversion(self):
|
||||
"""Test basic conversion from proportions to breakpoints."""
|
||||
proportions = [0.3, 0.5, 0.2]
|
||||
breakpoints = temporal_proportions_to_breakpoints(proportions)
|
||||
|
||||
assert breakpoints is not None
|
||||
assert len(breakpoints) == 4
|
||||
assert breakpoints[0] == 0.0
|
||||
assert abs(breakpoints[1] - 0.3) < 1e-6
|
||||
assert abs(breakpoints[2] - 0.8) < 1e-6
|
||||
assert breakpoints[3] == 1.0
|
||||
|
||||
def test_dict_input(self):
|
||||
"""Test with dict input."""
|
||||
proportions = {"a": 0.25, "b": 0.25, "c": 0.5}
|
||||
breakpoints = temporal_proportions_to_breakpoints(proportions)
|
||||
|
||||
assert breakpoints is not None
|
||||
assert len(breakpoints) == 4
|
||||
assert breakpoints[0] == 0.0
|
||||
assert breakpoints[-1] == 1.0
|
||||
|
||||
def test_dict_with_subtask_names_order(self):
|
||||
"""Test that subtask_names determines order for dict input."""
|
||||
proportions = {"c": 0.5, "a": 0.2, "b": 0.3} # Dict order
|
||||
subtask_names = ["a", "b", "c"] # Different order
|
||||
|
||||
breakpoints = temporal_proportions_to_breakpoints(proportions, subtask_names)
|
||||
|
||||
# Breakpoints should follow subtask_names order: a=0.2, b=0.3, c=0.5
|
||||
assert abs(breakpoints[1] - 0.2) < 1e-6 # a
|
||||
assert abs(breakpoints[2] - 0.5) < 1e-6 # a + b = 0.5
|
||||
assert breakpoints[3] == 1.0 # a + b + c = 1.0
|
||||
|
||||
def test_uniform_proportions(self):
|
||||
"""Test with uniform proportions."""
|
||||
proportions = [0.25, 0.25, 0.25, 0.25]
|
||||
breakpoints = temporal_proportions_to_breakpoints(proportions)
|
||||
|
||||
expected = [0.0, 0.25, 0.5, 0.75, 1.0]
|
||||
for i, (bp, exp) in enumerate(zip(breakpoints, expected, strict=True)):
|
||||
assert abs(bp - exp) < 1e-6, f"Breakpoint {i} mismatch"
|
||||
|
||||
def test_none_input(self):
|
||||
"""Test that None input returns None."""
|
||||
result = temporal_proportions_to_breakpoints(None)
|
||||
assert result is None
|
||||
|
||||
def test_normalization(self):
|
||||
"""Test that non-normalized proportions are normalized."""
|
||||
# Proportions sum to 2.0, not 1.0
|
||||
proportions = [0.6, 1.0, 0.4]
|
||||
breakpoints = temporal_proportions_to_breakpoints(proportions)
|
||||
|
||||
# Should be normalized: [0.3, 0.5, 0.2] -> [0, 0.3, 0.8, 1.0]
|
||||
assert breakpoints[-1] == 1.0
|
||||
assert abs(breakpoints[1] - 0.3) < 1e-6
|
||||
|
||||
|
||||
class TestNormalizeStageTau:
|
||||
"""Tests for normalize_stage_tau.
|
||||
|
||||
Normalizes stage+tau values to [0, 1] using breakpoints.
|
||||
"""
|
||||
|
||||
def test_linear_fallback(self):
|
||||
"""Test linear normalization when only num_stages is provided."""
|
||||
# 4 stages, linear: [0, 0.25, 0.5, 0.75, 1.0]
|
||||
|
||||
# Stage 0 start
|
||||
assert normalize_stage_tau(0.0, num_stages=4) == 0.0
|
||||
|
||||
# Stage 0 end / Stage 1 start
|
||||
assert abs(normalize_stage_tau(1.0, num_stages=4) - 0.25) < 1e-6
|
||||
|
||||
# Stage 1 middle
|
||||
assert abs(normalize_stage_tau(1.5, num_stages=4) - 0.375) < 1e-6
|
||||
|
||||
# Stage 3 end
|
||||
assert normalize_stage_tau(4.0, num_stages=4) == 1.0
|
||||
|
||||
def test_with_custom_breakpoints(self):
|
||||
"""Test with custom breakpoints."""
|
||||
# Non-linear breakpoints
|
||||
breakpoints = [0.0, 0.1, 0.5, 1.0] # 3 stages
|
||||
|
||||
# Stage 0: maps [0, 1) to [0.0, 0.1)
|
||||
assert abs(normalize_stage_tau(0.5, breakpoints=breakpoints) - 0.05) < 1e-6
|
||||
|
||||
# Stage 1: maps [1, 2) to [0.1, 0.5)
|
||||
assert abs(normalize_stage_tau(1.5, breakpoints=breakpoints) - 0.3) < 1e-6
|
||||
|
||||
# Stage 2: maps [2, 3) to [0.5, 1.0)
|
||||
assert abs(normalize_stage_tau(2.5, breakpoints=breakpoints) - 0.75) < 1e-6
|
||||
|
||||
def test_with_temporal_proportions(self):
|
||||
"""Test with temporal proportions (auto-computed breakpoints)."""
|
||||
proportions = {"a": 0.2, "b": 0.3, "c": 0.5}
|
||||
subtask_names = ["a", "b", "c"]
|
||||
|
||||
# Stage 0 end should map to 0.2
|
||||
result = normalize_stage_tau(1.0, temporal_proportions=proportions, subtask_names=subtask_names)
|
||||
assert abs(result - 0.2) < 1e-6
|
||||
|
||||
# Stage 1 end should map to 0.5
|
||||
result = normalize_stage_tau(2.0, temporal_proportions=proportions, subtask_names=subtask_names)
|
||||
assert abs(result - 0.5) < 1e-6
|
||||
|
||||
def test_tensor_input(self):
|
||||
"""Test with tensor input."""
|
||||
x = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0])
|
||||
breakpoints = [0.0, 0.3, 0.8, 1.0] # 3 stages
|
||||
|
||||
result = normalize_stage_tau(x, breakpoints=breakpoints)
|
||||
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert result.shape == x.shape
|
||||
assert abs(result[0].item() - 0.0) < 1e-6
|
||||
assert abs(result[2].item() - 0.3) < 1e-6 # End of stage 0
|
||||
assert abs(result[4].item() - 0.8) < 1e-6 # End of stage 1
|
||||
|
||||
def test_clamping(self):
|
||||
"""Test that output is clamped to [0, 1]."""
|
||||
# Below 0
|
||||
assert normalize_stage_tau(-0.5, num_stages=4) == 0.0
|
||||
|
||||
# Above num_stages
|
||||
assert normalize_stage_tau(5.0, num_stages=4) == 1.0
|
||||
|
||||
def test_batch_tensor(self):
|
||||
"""Test with batched tensor."""
|
||||
x = torch.tensor([[0.0, 1.0, 2.0], [0.5, 1.5, 2.5]]) # (2, 3)
|
||||
|
||||
result = normalize_stage_tau(x, num_stages=3)
|
||||
|
||||
assert result.shape == (2, 3)
|
||||
assert (result >= 0).all()
|
||||
assert (result <= 1).all()
|
||||
|
||||
def test_requires_one_of_inputs(self):
|
||||
"""Test that at least one input method is required."""
|
||||
with pytest.raises(ValueError):
|
||||
normalize_stage_tau(1.0)
|
||||
|
||||
|
||||
class TestRewindAugmentation:
|
||||
"""Tests for rewind augmentation logic with bidirectional observation sampling.
|
||||
|
||||
Rewind appends frames before the earliest observation frame, going backwards.
|
||||
With bidirectional sampling centered at frame_idx:
|
||||
- Earliest obs frame = frame_idx - half_steps * frame_gap
|
||||
- Rewind goes backwards from that point
|
||||
"""
|
||||
|
||||
def test_rewind_indices_go_backwards_from_earliest_obs(self):
|
||||
"""Rewind indices should go backwards from earliest observation frame."""
|
||||
frame_idx = 300 # Center of bidirectional window
|
||||
ep_start = 0
|
||||
n_obs_steps = 4 # half_steps = 2
|
||||
frame_gap = 30
|
||||
|
||||
# Earliest obs frame = 300 - 2*30 = 240
|
||||
# Rewind goes backwards: 210, 180
|
||||
rewind_step, rewind_indices = apply_rewind_augmentation(
|
||||
frame_idx,
|
||||
ep_start,
|
||||
n_obs_steps=n_obs_steps,
|
||||
max_rewind_steps=2,
|
||||
frame_gap=frame_gap,
|
||||
rewind_step=2,
|
||||
)
|
||||
|
||||
assert rewind_step == 2
|
||||
assert len(rewind_indices) == 2
|
||||
# First rewind frame is closest to obs window, second is further back
|
||||
assert rewind_indices[0] == 210 # 240 - 30
|
||||
assert rewind_indices[1] == 180 # 240 - 60
|
||||
assert rewind_indices[0] > rewind_indices[1], "Rewind should be descending"
|
||||
|
||||
def test_rewind_goes_backward_through_history(self):
|
||||
"""Rewind frames should go backward before the observation window."""
|
||||
frame_idx = 450 # Center of bidirectional window
|
||||
ep_start = 0
|
||||
n_obs_steps = 8 # half_steps = 4
|
||||
frame_gap = 30
|
||||
|
||||
# Earliest obs frame = 450 - 4*30 = 330
|
||||
# Rewind from 330: [300, 270, 240]
|
||||
rewind_step, rewind_indices = apply_rewind_augmentation(
|
||||
frame_idx,
|
||||
ep_start,
|
||||
n_obs_steps=n_obs_steps,
|
||||
max_rewind_steps=4,
|
||||
frame_gap=frame_gap,
|
||||
rewind_step=3,
|
||||
)
|
||||
|
||||
assert rewind_step == 3
|
||||
expected = [300, 270, 240] # Going backwards from 330
|
||||
assert rewind_indices == expected
|
||||
|
||||
def test_no_rewind_when_obs_window_at_episode_start(self):
|
||||
"""No rewind when observation window reaches episode start."""
|
||||
frame_idx = 120 # Center of window
|
||||
ep_start = 0
|
||||
n_obs_steps = 8 # half_steps = 4
|
||||
frame_gap = 30
|
||||
|
||||
# Earliest obs frame = 120 - 4*30 = 0 (at episode start)
|
||||
rewind_step, rewind_indices = apply_rewind_augmentation(
|
||||
frame_idx, ep_start, n_obs_steps=n_obs_steps, max_rewind_steps=4, frame_gap=frame_gap
|
||||
)
|
||||
|
||||
# No room for rewind
|
||||
assert rewind_step == 0
|
||||
assert rewind_indices == []
|
||||
|
||||
def test_rewind_targets_are_decreasing(self):
|
||||
"""Progress targets for rewind frames should be decreasing."""
|
||||
# Simulate progress values
|
||||
obs_progress = [0.1, 0.2, 0.3, 0.4, 0.5] # Forward progress
|
||||
|
||||
# Rewind reverses progress
|
||||
rewind_indices = [4, 3, 2] # Go backwards through indices
|
||||
rewind_progress = [obs_progress[i] for i in rewind_indices]
|
||||
|
||||
# Should be decreasing
|
||||
for i in range(len(rewind_progress) - 1):
|
||||
assert rewind_progress[i] > rewind_progress[i + 1]
|
||||
Reference in New Issue
Block a user