Files
lerobot/tests/utils/test_sample_weighting.py
T
Khalil Meftah 8a3d64033f Reward models refactor (#3142)
* feat(rewards): add RewardModelConfig and PreTrainedRewardModel base classes

* refactor(rewards): migrate Classifier from policies/sac/reward_model/ to rewards/classifier/

* refactor(rewards): migrate SARM from policies/sarm/ to rewards/sarm/

* refactor(rewards): add rewards/factory.py and remove reward model code from policies/factory.py

* refactor(rewards): update imports and delete old reward model locations

* test(rewards): add reward model tests and update existing test imports

* fix(rewards): restore full Classifier and SARM implementations

* test(rewards): restore missing CUDA and mixed precision classifier processor tests

* refactor(lerobot_train.py): remove rabc specific configuration and replace it with a generic sampler-weight class in lerobot_train

* refactor(lerobot_train.py): add missing sampling weight script

* linter + missing files

* add testing for sample weighter

* revert some useless changes, improve typing

* update docs

* add automatic detection of the progress path

* remove type exp

* improve comment

* fix: move rabc.py to rewards/sarm/ and update import paths

* refactor(imports): update reward model imports to new module structure

* refactor(imports): update reward model imports to reflect new module structure

* refactor(imports): conditionally import pandas based on availability

* feat(configs): add reward_model field to TrainPipelineConfig and Hub fields to RewardModelConfig

* refactor(policies): remove reward model branches from policy factory and __init__

* refactor(rewards): expand __init__ facade and fix SARMConfig __post_init__ crash

* feat(train): route reward model training through rewards/factory instead of policies/factory

* refactor(train): streamline reward model training logic

* fix(rewards): ensure FileNotFoundError is raised for missing config_file

* refactor(train): update __get_path_fields__ to include reward_model for config loading

* refactor(classifier): remove redundant input normalization in predict_reward method

* fix(train): raise ValueError for non-trainable reward models in train function

* refactor(pretrained_rm): add model card template

* refactor(tests): reward models

* refactor(sarm): update reset method and remove unused action prediction methods

* refactor(wandb): differentiate tags for reward model and policy training in cfg_to_group function

* fix(train): raise ValueError for PEFT usage in reward model training

* refactor(rewards): enhance RewardModelConfig with device handling and delta indices properties

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2026-04-28 17:56:24 +02:00

402 lines
13 KiB
Python

#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the sample weighting infrastructure."""
from unittest.mock import Mock
import pytest
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
import torch
from lerobot.utils.sample_weighting import (
SampleWeighter,
SampleWeightingConfig,
UniformWeighter,
make_sample_weighter,
)
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture
def sample_progress_parquet(tmp_path):
    """Write a small SARM progress parquet (2 episodes x 10 frames) and return its path."""
    import pandas as pd

    num_episodes, frames_per_episode = 2, 10
    frame_ids = list(range(frames_per_episode))
    df = pd.DataFrame(
        {
            "index": list(range(num_episodes * frames_per_episode)),
            "episode_index": [ep for ep in range(num_episodes) for _ in frame_ids],
            "frame_index": frame_ids * num_episodes,
            "progress_sparse": [i / 10.0 for i in frame_ids] * num_episodes,
        }
    )
    out_path = tmp_path / "sarm_progress.parquet"
    df.to_parquet(out_path)
    return out_path
# =============================================================================
# SampleWeightingConfig Tests
# =============================================================================
def test_config_default_values():
    """Defaults of SampleWeightingConfig match the documented values."""
    cfg = SampleWeightingConfig()
    expected = {
        "type": "rabc",
        "progress_path": None,
        "head_mode": "sparse",
        "kappa": 0.01,
        "epsilon": 1e-6,
        "extra_params": {},
    }
    for attr, value in expected.items():
        assert getattr(cfg, attr) == value
def test_config_custom_values():
    """Explicitly supplied values are stored on the config unchanged."""
    overrides = {
        "type": "rabc",
        "progress_path": "/path/to/progress.parquet",
        "head_mode": "dense",
        "kappa": 0.05,
        "epsilon": 1e-8,
        "extra_params": {"fallback_weight": 0.5},
    }
    cfg = SampleWeightingConfig(**overrides)
    for attr, value in overrides.items():
        assert getattr(cfg, attr) == value
def test_config_uniform_type():
    """A config can be constructed for uniform weighting."""
    assert SampleWeightingConfig(type="uniform").type == "uniform"
# =============================================================================
# UniformWeighter Tests
# =============================================================================
def test_uniform_weighter_inherits_from_sample_weighter():
    """UniformWeighter implements the SampleWeighter interface."""
    assert isinstance(UniformWeighter(device=torch.device("cpu")), SampleWeighter)
def test_uniform_weighter_compute_batch_weights_with_action_key():
    """Batch size is inferred from the 'action' tensor; every weight is 1."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    batch_size = 8
    weights, stats = weighter.compute_batch_weights({"action": torch.randn(batch_size, 10)})
    assert weights.shape == (batch_size,)
    assert torch.allclose(weights, torch.ones(batch_size))
    assert stats["mean_weight"] == 1.0
    assert stats["type"] == "uniform"
def test_uniform_weighter_compute_batch_weights_with_index_key():
    """Batch size is inferred from the 'index' tensor when 'action' is absent."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    batch_size = 16
    weights, stats = weighter.compute_batch_weights({"index": torch.arange(batch_size)})
    assert weights.shape == (batch_size,)
    assert torch.allclose(weights, torch.ones(batch_size))
def test_uniform_weighter_compute_batch_weights_no_tensor_keys():
    """Without any tensor values the weighter falls back to batch size 1."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    weights, _ = weighter.compute_batch_weights({"other_key": "some_value"})
    assert weights.shape == (1,)
    assert torch.allclose(weights, torch.ones(1))
def test_uniform_weighter_compute_batch_weights_empty_batch_raises():
    """An empty batch cannot define a batch size and must raise ValueError."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    with pytest.raises(ValueError, match="empty batch"):
        weighter.compute_batch_weights({})
def test_uniform_weighter_compute_batch_weights_scans_all_keys():
    """Any tensor value, even under a non-standard key, determines the batch size."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    weights, _ = weighter.compute_batch_weights({"custom_tensor": torch.randn(7, 3)})
    assert weights.shape == (7,)
    assert torch.allclose(weights, torch.ones(7))
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_uniform_weighter_compute_batch_weights_on_cuda():
    """Weights are returned on the device the weighter was constructed with.

    Uses the ``skipif`` marker instead of an in-body ``pytest.skip`` so the
    skip is decided at collection time and reported with its reason.
    """
    weighter = UniformWeighter(device=torch.device("cuda"))
    weights, _ = weighter.compute_batch_weights({"action": torch.randn(4, 10)})
    assert weights.device.type == "cuda"
def test_uniform_weighter_get_stats():
    """get_stats reports only the weighter type for uniform weighting."""
    assert UniformWeighter(device=torch.device("cpu")).get_stats() == {"type": "uniform"}
# =============================================================================
# make_sample_weighter Factory Tests
# =============================================================================
def test_factory_returns_none_for_none_config():
    """A missing (None) config disables sample weighting entirely."""
    assert make_sample_weighter(None, Mock(), torch.device("cpu")) is None
def test_factory_creates_uniform_weighter():
    """The factory builds a UniformWeighter for type='uniform'."""
    cfg = SampleWeightingConfig(type="uniform")
    built = make_sample_weighter(cfg, Mock(), torch.device("cpu"))
    assert isinstance(built, UniformWeighter)
    assert isinstance(built, SampleWeighter)
def test_factory_raises_for_unknown_type():
    """An unrecognized weighting type is rejected with ValueError."""
    cfg = SampleWeightingConfig(type="unknown_type")
    with pytest.raises(ValueError, match="Unknown sample weighting type"):
        make_sample_weighter(cfg, Mock(), torch.device("cpu"))
def test_factory_rabc_requires_chunk_size():
    """RABC weighting needs chunk_size on the policy config."""
    cfg = SampleWeightingConfig(type="rabc", progress_path="/path/to/progress.parquet")
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = None  # deliberately unset to trigger validation
    with pytest.raises(ValueError, match="chunk_size"):
        make_sample_weighter(cfg, policy, torch.device("cpu"))
def test_factory_rabc_requires_progress_path_or_dataset_info():
    """With neither progress_path nor any dataset info, RABC creation fails."""
    cfg = SampleWeightingConfig(type="rabc", progress_path=None)
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = 50
    # No dataset_root / dataset_repo_id supplied, so auto-detection is impossible.
    with pytest.raises(ValueError, match="progress_path"):
        make_sample_weighter(cfg, policy, torch.device("cpu"))
def test_factory_rabc_auto_detects_from_dataset_root(sample_progress_parquet):
    """progress_path is discovered inside dataset_root when not given explicitly."""
    from lerobot.rewards.sarm.rabc import RABCWeights

    cfg = SampleWeightingConfig(type="rabc", progress_path=None)
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = 5
    # Point the factory at the directory containing the fixture parquet.
    built = make_sample_weighter(
        cfg,
        policy,
        torch.device("cpu"),
        dataset_root=str(sample_progress_parquet.parent),
    )
    assert built is not None
    assert isinstance(built, RABCWeights)
def test_factory_rabc_auto_detects_from_repo_id():
    """Test that RABC weighter constructs HF path from repo_id."""
    config = SampleWeightingConfig(
        type="rabc",
        progress_path=None,  # Not provided, should auto-detect
    )
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = 50
    device = torch.device("cpu")
    # This will construct the path but fail when trying to load (file doesn't exist)
    # We just verify it doesn't raise the "progress_path required" error
    with pytest.raises(Exception) as exc_info:
        make_sample_weighter(
            config,
            policy,
            device,
            dataset_repo_id="test-user/test-dataset",
        )
    # Should NOT be the "progress_path required" error - it should try to load the file
    # NOTE(review): the `or "auto-detection"` clause tolerates error messages that
    # mention progress_path only in an auto-detection context — confirm this matches
    # make_sample_weighter's actual error wording.
    assert (
        "progress_path" not in str(exc_info.value).lower() or "auto-detection" in str(exc_info.value).lower()
    )
# =============================================================================
# Integration Tests with RABCWeights
# =============================================================================
def test_rabc_weights_is_sample_weighter(sample_progress_parquet):
    """RABCWeights implements the SampleWeighter interface."""
    from lerobot.rewards.sarm.rabc import RABCWeights

    rabc = RABCWeights(
        progress_path=sample_progress_parquet,
        chunk_size=5,
        head_mode="sparse",
    )
    assert isinstance(rabc, SampleWeighter)
def test_rabc_compute_batch_weights(sample_progress_parquet):
    """compute_batch_weights yields a per-sample weight tensor plus a stats dict."""
    from lerobot.rewards.sarm.rabc import RABCWeights

    rabc = RABCWeights(
        progress_path=sample_progress_parquet,
        chunk_size=5,
        head_mode="sparse",
        device=torch.device("cpu"),
    )
    sample_indices = torch.tensor([0, 1, 2, 3])
    weights, stats = rabc.compute_batch_weights({"index": sample_indices})
    assert isinstance(weights, torch.Tensor)
    assert weights.shape == (4,)
    assert isinstance(stats, dict)
    assert "mean_weight" in stats
def test_rabc_get_stats(sample_progress_parquet):
    """get_stats exposes frame count, chunking info and delta statistics."""
    from lerobot.rewards.sarm.rabc import RABCWeights

    rabc = RABCWeights(
        progress_path=sample_progress_parquet,
        chunk_size=5,
        head_mode="sparse",
    )
    stats = rabc.get_stats()
    assert stats["type"] == "rabc"
    for key in ("num_frames", "chunk_size", "head_mode", "delta_mean", "delta_std"):
        assert key in stats
    assert stats["chunk_size"] == 5
    assert stats["head_mode"] == "sparse"
def test_factory_creates_rabc_weighter(sample_progress_parquet):
    """The factory builds RABCWeights when given a fully specified rabc config."""
    from lerobot.rewards.sarm.rabc import RABCWeights

    cfg = SampleWeightingConfig(
        type="rabc",
        progress_path=str(sample_progress_parquet),
        head_mode="sparse",
        kappa=0.01,
    )
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = 5
    built = make_sample_weighter(cfg, policy, torch.device("cpu"))
    assert isinstance(built, RABCWeights)
    assert isinstance(built, SampleWeighter)
def test_rabc_weights_normalization(sample_progress_parquet):
    """Per-batch weights are renormalized so their sum is close to batch_size."""
    from lerobot.rewards.sarm.rabc import RABCWeights

    rabc = RABCWeights(
        progress_path=sample_progress_parquet,
        chunk_size=5,
        head_mode="sparse",
        device=torch.device("cpu"),
    )
    batch_size = 4
    weights, _ = rabc.compute_batch_weights({"index": torch.arange(batch_size)})
    # Normalized weights should sum approximately to the number of samples.
    assert abs(weights.sum().item() - batch_size) < 0.1