mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 00:59:46 +00:00
Reward models refactor (#3142)
* feat(rewards): add RewardModelConfig and PreTrainedRewardModel base classes * refactor(rewards): migrate Classifier from policies/sac/reward_model/ to rewards/classifier/ * refactor(rewards): migrate SARM from policies/sarm/ to rewards/sarm/ * refactor(rewards): add rewards/factory.py and remove reward model code from policies/factory.py * refactor(rewards): update imports and delete old reward model locations * test(rewards): add reward model tests and update existing test imports * fix(rewards): restore full Classifier and SARM implementations * test(rewards): restore missing CUDA and mixed precision classifier processor tests * refactor(lerobot_train.py): remove rabc specific configuration and replace it with a generic samplerweight class in lerobot_train * refactor(lerobot_train.py): add missing sampling weight script * linter + missing files * add testing for sampl weighter * revert some useless changes, improve typing * update docs * add automatic detection of the progress path * remove type exp * improve comment * fix: move rabc.py to rewards/sarm/ and update import paths * refactor(imports): update reward model imports to new module structure * refactor(imports): update reward model imports to reflect new module structure * refactor(imports): conditionally import pandas based on availability * feat(configs): add reward_model field to TrainPipelineConfig and Hub fields to RewardModelConfig * refactor(policies): remove reward model branches from policy factory and __init__ * refactor(rewards): expand __init__ facade and fix SARMConfig __post_init__ crash * feat(train): route reward model training through rewards/factory instead of policies/factory * refactor(train): streamline reward model training logic * fix(rewards): ensure FileNotFoundError is raised for missing config_file * refactor(train): update __get_path_fields__ to include reward_model for config loading * refactor(classifier): remove redundant input normalization in predict_reward method * fix(train): raise ValueError for non-trainable reward models in train function * refactor(pretrained_rm): add model card template * refactor(tests): reward models * refactor(sarm): update reset method and remove unused action prediction methods * refactor(wandb): differentiate tags for reward model and policy training in cfg_to_group function * fix(train): raise ValueError for PEFT usage in reward model training * refactor(rewards): enhance RewardModelConfig with device handling and delta indices properties --------- Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
This commit is contained in:
@@ -0,0 +1,401 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the sample weighting infrastructure."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.utils.sample_weighting import (
|
||||
SampleWeighter,
|
||||
SampleWeightingConfig,
|
||||
UniformWeighter,
|
||||
make_sample_weighter,
|
||||
)
|
||||
|
||||
# =============================================================================
|
||||
# Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_progress_parquet(tmp_path):
|
||||
"""Create a sample progress parquet file for testing."""
|
||||
import pandas as pd
|
||||
|
||||
# Create sample progress data for 2 episodes with 10 frames each
|
||||
data = {
|
||||
"index": list(range(20)),
|
||||
"episode_index": [0] * 10 + [1] * 10,
|
||||
"frame_index": list(range(10)) * 2,
|
||||
"progress_sparse": [i / 10.0 for i in range(10)] * 2,
|
||||
}
|
||||
df = pd.DataFrame(data)
|
||||
parquet_path = tmp_path / "sarm_progress.parquet"
|
||||
df.to_parquet(parquet_path)
|
||||
return parquet_path
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SampleWeightingConfig Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_config_default_values():
|
||||
"""Test default configuration values."""
|
||||
config = SampleWeightingConfig()
|
||||
assert config.type == "rabc"
|
||||
assert config.progress_path is None
|
||||
assert config.head_mode == "sparse"
|
||||
assert config.kappa == 0.01
|
||||
assert config.epsilon == 1e-6
|
||||
assert config.extra_params == {}
|
||||
|
||||
|
||||
def test_config_custom_values():
|
||||
"""Test configuration with custom values."""
|
||||
config = SampleWeightingConfig(
|
||||
type="rabc",
|
||||
progress_path="/path/to/progress.parquet",
|
||||
head_mode="dense",
|
||||
kappa=0.05,
|
||||
epsilon=1e-8,
|
||||
extra_params={"fallback_weight": 0.5},
|
||||
)
|
||||
assert config.type == "rabc"
|
||||
assert config.progress_path == "/path/to/progress.parquet"
|
||||
assert config.head_mode == "dense"
|
||||
assert config.kappa == 0.05
|
||||
assert config.epsilon == 1e-8
|
||||
assert config.extra_params == {"fallback_weight": 0.5}
|
||||
|
||||
|
||||
def test_config_uniform_type():
|
||||
"""Test configuration for uniform weighting."""
|
||||
config = SampleWeightingConfig(type="uniform")
|
||||
assert config.type == "uniform"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# UniformWeighter Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_uniform_weighter_inherits_from_sample_weighter():
|
||||
"""Test that UniformWeighter is a SampleWeighter."""
|
||||
weighter = UniformWeighter(device=torch.device("cpu"))
|
||||
assert isinstance(weighter, SampleWeighter)
|
||||
|
||||
|
||||
def test_uniform_weighter_compute_batch_weights_with_action_key():
|
||||
"""Test weight computation with 'action' key in batch."""
|
||||
weighter = UniformWeighter(device=torch.device("cpu"))
|
||||
batch = {"action": torch.randn(8, 10)}
|
||||
|
||||
weights, stats = weighter.compute_batch_weights(batch)
|
||||
|
||||
assert weights.shape == (8,)
|
||||
assert torch.allclose(weights, torch.ones(8))
|
||||
assert stats["mean_weight"] == 1.0
|
||||
assert stats["type"] == "uniform"
|
||||
|
||||
|
||||
def test_uniform_weighter_compute_batch_weights_with_index_key():
|
||||
"""Test weight computation with 'index' key in batch."""
|
||||
weighter = UniformWeighter(device=torch.device("cpu"))
|
||||
batch = {"index": torch.arange(16)}
|
||||
|
||||
weights, stats = weighter.compute_batch_weights(batch)
|
||||
|
||||
assert weights.shape == (16,)
|
||||
assert torch.allclose(weights, torch.ones(16))
|
||||
|
||||
|
||||
def test_uniform_weighter_compute_batch_weights_no_tensor_keys():
|
||||
"""Test weight computation with no tensor keys (fallback to size 1)."""
|
||||
weighter = UniformWeighter(device=torch.device("cpu"))
|
||||
batch = {"other_key": "some_value"}
|
||||
|
||||
weights, stats = weighter.compute_batch_weights(batch)
|
||||
|
||||
assert weights.shape == (1,)
|
||||
assert torch.allclose(weights, torch.ones(1))
|
||||
|
||||
|
||||
def test_uniform_weighter_compute_batch_weights_empty_batch_raises():
|
||||
"""Test that empty batch raises ValueError."""
|
||||
weighter = UniformWeighter(device=torch.device("cpu"))
|
||||
batch = {}
|
||||
|
||||
with pytest.raises(ValueError, match="empty batch"):
|
||||
weighter.compute_batch_weights(batch)
|
||||
|
||||
|
||||
def test_uniform_weighter_compute_batch_weights_scans_all_keys():
|
||||
"""Test that batch size is determined by scanning all tensor values."""
|
||||
weighter = UniformWeighter(device=torch.device("cpu"))
|
||||
# Batch with non-standard key containing a tensor
|
||||
batch = {"custom_tensor": torch.randn(7, 3)}
|
||||
|
||||
weights, stats = weighter.compute_batch_weights(batch)
|
||||
|
||||
assert weights.shape == (7,)
|
||||
assert torch.allclose(weights, torch.ones(7))
|
||||
|
||||
|
||||
def test_uniform_weighter_compute_batch_weights_on_cuda():
|
||||
"""Test that weights are placed on the correct device."""
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA not available")
|
||||
|
||||
weighter = UniformWeighter(device=torch.device("cuda"))
|
||||
batch = {"action": torch.randn(4, 10)}
|
||||
|
||||
weights, _ = weighter.compute_batch_weights(batch)
|
||||
|
||||
assert weights.device.type == "cuda"
|
||||
|
||||
|
||||
def test_uniform_weighter_get_stats():
|
||||
"""Test get_stats returns expected structure."""
|
||||
weighter = UniformWeighter(device=torch.device("cpu"))
|
||||
stats = weighter.get_stats()
|
||||
|
||||
assert stats == {"type": "uniform"}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# make_sample_weighter Factory Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_factory_returns_none_for_none_config():
|
||||
"""Test that None config returns None weighter."""
|
||||
policy = Mock()
|
||||
device = torch.device("cpu")
|
||||
|
||||
result = make_sample_weighter(None, policy, device)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_factory_creates_uniform_weighter():
|
||||
"""Test creation of UniformWeighter."""
|
||||
config = SampleWeightingConfig(type="uniform")
|
||||
policy = Mock()
|
||||
device = torch.device("cpu")
|
||||
|
||||
weighter = make_sample_weighter(config, policy, device)
|
||||
|
||||
assert isinstance(weighter, UniformWeighter)
|
||||
assert isinstance(weighter, SampleWeighter)
|
||||
|
||||
|
||||
def test_factory_raises_for_unknown_type():
|
||||
"""Test that unknown type raises ValueError."""
|
||||
config = SampleWeightingConfig(type="unknown_type")
|
||||
policy = Mock()
|
||||
device = torch.device("cpu")
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown sample weighting type"):
|
||||
make_sample_weighter(config, policy, device)
|
||||
|
||||
|
||||
def test_factory_rabc_requires_chunk_size():
|
||||
"""Test that RABC weighter requires chunk_size in policy config."""
|
||||
config = SampleWeightingConfig(
|
||||
type="rabc",
|
||||
progress_path="/path/to/progress.parquet",
|
||||
)
|
||||
policy = Mock()
|
||||
policy.config = Mock()
|
||||
policy.config.chunk_size = None # No chunk_size
|
||||
device = torch.device("cpu")
|
||||
|
||||
with pytest.raises(ValueError, match="chunk_size"):
|
||||
make_sample_weighter(config, policy, device)
|
||||
|
||||
|
||||
def test_factory_rabc_requires_progress_path_or_dataset_info():
|
||||
"""Test that RABC weighter requires progress_path or dataset info for auto-detection."""
|
||||
config = SampleWeightingConfig(
|
||||
type="rabc",
|
||||
progress_path=None, # No progress path
|
||||
)
|
||||
policy = Mock()
|
||||
policy.config = Mock()
|
||||
policy.config.chunk_size = 50
|
||||
device = torch.device("cpu")
|
||||
|
||||
# Should fail when no progress_path AND no dataset info
|
||||
with pytest.raises(ValueError, match="progress_path"):
|
||||
make_sample_weighter(config, policy, device)
|
||||
|
||||
|
||||
def test_factory_rabc_auto_detects_from_dataset_root(sample_progress_parquet):
|
||||
"""Test that RABC weighter auto-detects progress_path from dataset_root."""
|
||||
config = SampleWeightingConfig(
|
||||
type="rabc",
|
||||
progress_path=None, # Not provided, should auto-detect
|
||||
)
|
||||
policy = Mock()
|
||||
policy.config = Mock()
|
||||
policy.config.chunk_size = 5
|
||||
device = torch.device("cpu")
|
||||
|
||||
# The parquet file is at sample_progress_parquet, get its parent directory
|
||||
dataset_root = sample_progress_parquet.parent
|
||||
weighter = make_sample_weighter(
|
||||
config,
|
||||
policy,
|
||||
device,
|
||||
dataset_root=str(dataset_root),
|
||||
)
|
||||
|
||||
assert weighter is not None
|
||||
from lerobot.rewards.sarm.rabc import RABCWeights
|
||||
|
||||
assert isinstance(weighter, RABCWeights)
|
||||
|
||||
|
||||
def test_factory_rabc_auto_detects_from_repo_id():
|
||||
"""Test that RABC weighter constructs HF path from repo_id."""
|
||||
config = SampleWeightingConfig(
|
||||
type="rabc",
|
||||
progress_path=None, # Not provided, should auto-detect
|
||||
)
|
||||
policy = Mock()
|
||||
policy.config = Mock()
|
||||
policy.config.chunk_size = 50
|
||||
device = torch.device("cpu")
|
||||
|
||||
# This will construct the path but fail when trying to load (file doesn't exist)
|
||||
# We just verify it doesn't raise the "progress_path required" error
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
make_sample_weighter(
|
||||
config,
|
||||
policy,
|
||||
device,
|
||||
dataset_repo_id="test-user/test-dataset",
|
||||
)
|
||||
# Should NOT be the "progress_path required" error - it should try to load the file
|
||||
assert (
|
||||
"progress_path" not in str(exc_info.value).lower() or "auto-detection" in str(exc_info.value).lower()
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Integration Tests with RABCWeights
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_rabc_weights_is_sample_weighter(sample_progress_parquet):
|
||||
"""Test that RABCWeights inherits from SampleWeighter."""
|
||||
from lerobot.rewards.sarm.rabc import RABCWeights
|
||||
|
||||
weighter = RABCWeights(
|
||||
progress_path=sample_progress_parquet,
|
||||
chunk_size=5,
|
||||
head_mode="sparse",
|
||||
)
|
||||
assert isinstance(weighter, SampleWeighter)
|
||||
|
||||
|
||||
def test_rabc_compute_batch_weights(sample_progress_parquet):
|
||||
"""Test RABCWeights.compute_batch_weights returns correct structure."""
|
||||
from lerobot.rewards.sarm.rabc import RABCWeights
|
||||
|
||||
weighter = RABCWeights(
|
||||
progress_path=sample_progress_parquet,
|
||||
chunk_size=5,
|
||||
head_mode="sparse",
|
||||
device=torch.device("cpu"),
|
||||
)
|
||||
|
||||
batch = {"index": torch.tensor([0, 1, 2, 3])}
|
||||
weights, stats = weighter.compute_batch_weights(batch)
|
||||
|
||||
assert isinstance(weights, torch.Tensor)
|
||||
assert weights.shape == (4,)
|
||||
assert isinstance(stats, dict)
|
||||
assert "mean_weight" in stats
|
||||
|
||||
|
||||
def test_rabc_get_stats(sample_progress_parquet):
|
||||
"""Test RABCWeights.get_stats returns expected structure."""
|
||||
from lerobot.rewards.sarm.rabc import RABCWeights
|
||||
|
||||
weighter = RABCWeights(
|
||||
progress_path=sample_progress_parquet,
|
||||
chunk_size=5,
|
||||
head_mode="sparse",
|
||||
)
|
||||
|
||||
stats = weighter.get_stats()
|
||||
|
||||
assert stats["type"] == "rabc"
|
||||
assert "num_frames" in stats
|
||||
assert "chunk_size" in stats
|
||||
assert stats["chunk_size"] == 5
|
||||
assert "head_mode" in stats
|
||||
assert stats["head_mode"] == "sparse"
|
||||
assert "delta_mean" in stats
|
||||
assert "delta_std" in stats
|
||||
|
||||
|
||||
def test_factory_creates_rabc_weighter(sample_progress_parquet):
|
||||
"""Test factory creates RABCWeights with valid config."""
|
||||
from lerobot.rewards.sarm.rabc import RABCWeights
|
||||
|
||||
config = SampleWeightingConfig(
|
||||
type="rabc",
|
||||
progress_path=str(sample_progress_parquet),
|
||||
head_mode="sparse",
|
||||
kappa=0.01,
|
||||
)
|
||||
policy = Mock()
|
||||
policy.config = Mock()
|
||||
policy.config.chunk_size = 5
|
||||
device = torch.device("cpu")
|
||||
|
||||
weighter = make_sample_weighter(config, policy, device)
|
||||
|
||||
assert isinstance(weighter, RABCWeights)
|
||||
assert isinstance(weighter, SampleWeighter)
|
||||
|
||||
|
||||
def test_rabc_weights_normalization(sample_progress_parquet):
|
||||
"""Test that RABCWeights normalizes weights to sum to batch_size."""
|
||||
from lerobot.rewards.sarm.rabc import RABCWeights
|
||||
|
||||
weighter = RABCWeights(
|
||||
progress_path=sample_progress_parquet,
|
||||
chunk_size=5,
|
||||
head_mode="sparse",
|
||||
device=torch.device("cpu"),
|
||||
)
|
||||
|
||||
batch = {"index": torch.tensor([0, 1, 2, 3])}
|
||||
weights, _ = weighter.compute_batch_weights(batch)
|
||||
|
||||
# Weights should be normalized to sum approximately to batch_size
|
||||
batch_size = 4
|
||||
assert abs(weights.sum().item() - batch_size) < 0.1
|
||||
Reference in New Issue
Block a user