mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
8a3d64033f
* feat(rewards): add RewardModelConfig and PreTrainedRewardModel base classes * refactor(rewards): migrate Classifier from policies/sac/reward_model/ to rewards/classifier/ * refactor(rewards): migrate SARM from policies/sarm/ to rewards/sarm/ * refactor(rewards): add rewards/factory.py and remove reward model code from policies/factory.py * refactor(rewards): update imports and delete old reward model locations * test(rewards): add reward model tests and update existing test imports * fix(rewards): restore full Classifier and SARM implementations * test(rewards): restore missing CUDA and mixed precision classifier processor tests * refactor(lerobot_train.py): remove rabc specific configuration and replace it with a generic samplerweight class in lerobot_train * refactor(lerobot_train.py): add missing sampling weight script * linter + missing files * add testing for sampl weighter * revert some useless changes, improve typing * update docs * add automatic detection of the progress path * remove type exp * improve comment * fix: move rabc.py to rewards/sarm/ and update import paths * refactor(imports): update reward model imports to new module structure * refactor(imports): update reward model imports to reflect new module structure * refactor(imports): conditionally import pandas based on availability * feat(configs): add reward_model field to TrainPipelineConfig and Hub fields to RewardModelConfig * refactor(policies): remove reward model branches from policy factory and __init__ * refactor(rewards): expand __init__ facade and fix SARMConfig __post_init__ crash * feat(train): route reward model training through rewards/factory instead of policies/factory * refactor(train): streamline reward model training logic * fix(rewards): ensure FileNotFoundError is raised for missing config_file * refactor(train): update __get_path_fields__ to include reward_model for config loading * refactor(classifier): remove redundant input normalization in predict_reward method * 
fix(train): raise ValueError for non-trainable reward models in train function * refactor(pretrained_rm): add model card template * refactor(tests): reward models * refactor(sarm): update reset method and remove unused action prediction methods * refactor(wandb): differentiate tags for reward model and policy training in cfg_to_group function * fix(train): raise ValueError for PEFT usage in reward model training * refactor(rewards): enhance RewardModelConfig with device handling and delta indices properties --------- Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
402 lines
13 KiB
Python
402 lines
13 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Tests for the sample weighting infrastructure."""
|
|
|
|
from unittest.mock import Mock
|
|
|
|
import pytest
|
|
|
|
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
|
|
|
|
import torch
|
|
|
|
from lerobot.utils.sample_weighting import (
|
|
SampleWeighter,
|
|
SampleWeightingConfig,
|
|
UniformWeighter,
|
|
make_sample_weighter,
|
|
)
|
|
|
|
# =============================================================================
|
|
# Fixtures
|
|
# =============================================================================
|
|
|
|
|
|
@pytest.fixture
def sample_progress_parquet(tmp_path):
    """Write a small progress table under ``tmp_path`` and return the parquet path.

    The table describes 2 episodes of 10 frames each: a running global
    ``index``, the ``episode_index``, the per-episode ``frame_index`` and a
    ``progress_sparse`` column equal to ``frame_index / 10``.
    """
    import pandas as pd

    num_episodes, frames_per_episode = 2, 10
    frame_ids = list(range(frames_per_episode))

    table = pd.DataFrame(
        {
            "index": list(range(num_episodes * frames_per_episode)),
            "episode_index": [ep for ep in range(num_episodes) for _ in frame_ids],
            "frame_index": frame_ids * num_episodes,
            "progress_sparse": [frame / 10.0 for frame in frame_ids] * num_episodes,
        }
    )

    out_file = tmp_path / "sarm_progress.parquet"
    table.to_parquet(out_file)
    return out_file
|
|
|
|
|
|
# =============================================================================
|
|
# SampleWeightingConfig Tests
|
|
# =============================================================================
|
|
|
|
|
|
def test_config_default_values():
    """A config built without arguments carries the expected default fields."""
    expected_defaults = {
        "type": "rabc",
        "head_mode": "sparse",
        "kappa": 0.01,
        "epsilon": 1e-6,
        "extra_params": {},
    }

    default_cfg = SampleWeightingConfig()

    # ``progress_path`` is checked by identity, the remaining fields by equality.
    assert default_cfg.progress_path is None
    for field_name, expected in expected_defaults.items():
        assert getattr(default_cfg, field_name) == expected
|
|
|
|
|
|
def test_config_custom_values():
    """Each explicitly supplied field can be read back from the config unchanged."""

    def fresh_overrides():
        # Built anew on every call so the expected values never share objects
        # with the values handed to the config.
        return {
            "type": "rabc",
            "progress_path": "/path/to/progress.parquet",
            "head_mode": "dense",
            "kappa": 0.05,
            "epsilon": 1e-8,
            "extra_params": {"fallback_weight": 0.5},
        }

    custom_cfg = SampleWeightingConfig(**fresh_overrides())

    for field_name, supplied in fresh_overrides().items():
        assert getattr(custom_cfg, field_name) == supplied
|
|
|
|
|
|
def test_config_uniform_type():
    """A config created with ``type="uniform"`` reports that same type."""
    assert SampleWeightingConfig(type="uniform").type == "uniform"
|
|
|
|
|
|
# =============================================================================
|
|
# UniformWeighter Tests
|
|
# =============================================================================
|
|
|
|
|
|
def test_uniform_weighter_inherits_from_sample_weighter():
    """A UniformWeighter instance passes ``isinstance(..., SampleWeighter)``."""
    cpu = torch.device("cpu")
    assert isinstance(UniformWeighter(device=cpu), SampleWeighter)
|
|
|
|
|
|
def test_uniform_weighter_compute_batch_weights_with_action_key():
    """A batch holding an 'action' tensor yields one unit weight per row plus uniform stats."""
    num_rows = 8
    uniform = UniformWeighter(device=torch.device("cpu"))

    out_weights, out_stats = uniform.compute_batch_weights({"action": torch.randn(num_rows, 10)})

    assert out_weights.shape == (num_rows,)
    assert torch.allclose(out_weights, torch.ones(num_rows))
    assert out_stats["mean_weight"] == 1.0
    assert out_stats["type"] == "uniform"
|
|
|
|
|
|
def test_uniform_weighter_compute_batch_weights_with_index_key():
    """A batch holding only an 'index' tensor yields one unit weight per element."""
    num_rows = 16
    uniform = UniformWeighter(device=torch.device("cpu"))

    # The stats dict is not inspected by this test.
    out_weights, _ = uniform.compute_batch_weights({"index": torch.arange(num_rows)})

    assert out_weights.shape == (num_rows,)
    assert torch.allclose(out_weights, torch.ones(num_rows))
|
|
|
|
|
|
def test_uniform_weighter_compute_batch_weights_no_tensor_keys():
    """A non-empty batch without any tensor value yields a single unit weight."""
    uniform = UniformWeighter(device=torch.device("cpu"))
    tensorless_batch = {"other_key": "some_value"}

    # The stats dict is not inspected by this test.
    out_weights, _ = uniform.compute_batch_weights(tensorless_batch)

    assert out_weights.shape == (1,)
    assert torch.allclose(out_weights, torch.ones(1))
|
|
|
|
|
|
def test_uniform_weighter_compute_batch_weights_empty_batch_raises():
    """An empty dict batch is rejected with a ValueError mentioning 'empty batch'."""
    uniform = UniformWeighter(device=torch.device("cpu"))

    with pytest.raises(ValueError, match="empty batch"):
        uniform.compute_batch_weights({})
|
|
|
|
|
|
def test_uniform_weighter_compute_batch_weights_scans_all_keys():
    """The batch size is picked up from a tensor stored under a non-standard key."""
    num_rows = 7
    uniform = UniformWeighter(device=torch.device("cpu"))

    # Only key present is neither 'action' nor 'index'.
    out_weights, _ = uniform.compute_batch_weights({"custom_tensor": torch.randn(num_rows, 3)})

    assert out_weights.shape == (num_rows,)
    assert torch.allclose(out_weights, torch.ones(num_rows))
|
|
|
|
|
|
def test_uniform_weighter_compute_batch_weights_on_cuda():
    """With a CUDA device configured, the returned weights live on a CUDA device.

    Skipped at runtime when no CUDA device is available.
    """
    cuda_ready = torch.cuda.is_available()
    if not cuda_ready:
        pytest.skip("CUDA not available")

    gpu_weighter = UniformWeighter(device=torch.device("cuda"))

    # The input tensor is created on the default device; only the output device is checked.
    out_weights, _ = gpu_weighter.compute_batch_weights({"action": torch.randn(4, 10)})

    assert out_weights.device.type == "cuda"
|
|
|
|
|
|
def test_uniform_weighter_get_stats():
    """``get_stats`` on a UniformWeighter returns exactly ``{"type": "uniform"}``."""
    uniform = UniformWeighter(device=torch.device("cpu"))
    assert uniform.get_stats() == {"type": "uniform"}
|
|
|
|
|
|
# =============================================================================
|
|
# make_sample_weighter Factory Tests
|
|
# =============================================================================
|
|
|
|
|
|
def test_factory_returns_none_for_none_config():
    """Passing ``None`` as the config makes the factory return ``None``."""
    stub_policy = Mock()
    cpu = torch.device("cpu")

    assert make_sample_weighter(None, stub_policy, cpu) is None
|
|
|
|
|
|
def test_factory_creates_uniform_weighter():
    """A ``type="uniform"`` config makes the factory build a UniformWeighter."""
    uniform_cfg = SampleWeightingConfig(type="uniform")

    built = make_sample_weighter(uniform_cfg, Mock(), torch.device("cpu"))

    # Checked against both the concrete class and the base class.
    for expected_cls in (UniformWeighter, SampleWeighter):
        assert isinstance(built, expected_cls)
|
|
|
|
|
|
def test_factory_raises_for_unknown_type():
    """An unrecognised ``type`` makes the factory raise a ValueError."""
    # The config is built outside the ``raises`` block: only the factory call
    # is expected to raise.
    bogus_cfg = SampleWeightingConfig(type="unknown_type")
    stub_policy = Mock()
    cpu = torch.device("cpu")

    with pytest.raises(ValueError, match="Unknown sample weighting type"):
        make_sample_weighter(bogus_cfg, stub_policy, cpu)
|
|
|
|
|
|
def test_factory_rabc_requires_chunk_size():
    """With ``policy.config.chunk_size`` set to None, the rabc factory path raises ValueError."""
    rabc_cfg = SampleWeightingConfig(type="rabc", progress_path="/path/to/progress.parquet")

    # Mock keyword arguments become attributes: policy.config.chunk_size is None.
    stub_policy = Mock(config=Mock(chunk_size=None))
    cpu = torch.device("cpu")

    with pytest.raises(ValueError, match="chunk_size"):
        make_sample_weighter(rabc_cfg, stub_policy, cpu)
|
|
|
|
|
|
def test_factory_rabc_requires_progress_path_or_dataset_info():
    """Without ``progress_path`` and without any dataset hint, the rabc factory path raises ValueError."""
    rabc_cfg = SampleWeightingConfig(type="rabc", progress_path=None)
    stub_policy = Mock(config=Mock(chunk_size=50))
    cpu = torch.device("cpu")

    # Neither dataset_root nor dataset_repo_id is passed, so there is nothing
    # for the factory to derive a path from.
    with pytest.raises(ValueError, match="progress_path"):
        make_sample_weighter(rabc_cfg, stub_policy, cpu)
|
|
|
|
|
|
def test_factory_rabc_auto_detects_from_dataset_root(sample_progress_parquet):
    """With no ``progress_path`` but a ``dataset_root``, the factory still builds RABCWeights.

    The dataset root handed to the factory is the directory that contains the
    fixture's ``sarm_progress.parquet`` file.
    """
    from lerobot.rewards.sarm.rabc import RABCWeights

    rabc_cfg = SampleWeightingConfig(type="rabc", progress_path=None)
    stub_policy = Mock(config=Mock(chunk_size=5))
    root_dir = str(sample_progress_parquet.parent)

    built = make_sample_weighter(
        rabc_cfg,
        stub_policy,
        torch.device("cpu"),
        dataset_root=root_dir,
    )

    assert built is not None
    assert isinstance(built, RABCWeights)
|
|
|
|
|
|
def test_factory_rabc_auto_detects_from_repo_id():
    """With only ``dataset_repo_id`` given, the factory fails for a reason other than a missing path.

    The call is expected to raise (the original comments say the derived file
    does not exist). The test accepts any exception whose message either does
    not mention ``progress_path`` or does mention ``auto-detection``.
    """
    rabc_cfg = SampleWeightingConfig(type="rabc", progress_path=None)
    stub_policy = Mock(config=Mock(chunk_size=50))
    cpu = torch.device("cpu")

    with pytest.raises(Exception) as exc_info:
        make_sample_weighter(
            rabc_cfg,
            stub_policy,
            cpu,
            dataset_repo_id="test-user/test-dataset",
        )

    # Must not be the plain "progress_path required" failure.
    message = str(exc_info.value).lower()
    assert "progress_path" not in message or "auto-detection" in message
|
|
|
|
|
|
# =============================================================================
|
|
# Integration Tests with RABCWeights
|
|
# =============================================================================
|
|
|
|
|
|
def test_rabc_weights_is_sample_weighter(sample_progress_parquet):
    """An RABCWeights instance passes ``isinstance(..., SampleWeighter)``."""
    from lerobot.rewards.sarm.rabc import RABCWeights

    rabc_kwargs = {
        "progress_path": sample_progress_parquet,
        "chunk_size": 5,
        "head_mode": "sparse",
    }

    assert isinstance(RABCWeights(**rabc_kwargs), SampleWeighter)
|
|
|
|
|
|
def test_rabc_compute_batch_weights(sample_progress_parquet):
    """``compute_batch_weights`` returns a per-sample weight tensor and a stats dict with 'mean_weight'."""
    from lerobot.rewards.sarm.rabc import RABCWeights

    rabc = RABCWeights(
        progress_path=sample_progress_parquet,
        chunk_size=5,
        head_mode="sparse",
        device=torch.device("cpu"),
    )
    frame_indices = torch.tensor([0, 1, 2, 3])

    out_weights, out_stats = rabc.compute_batch_weights({"index": frame_indices})

    # Weights: a tensor with one entry per requested index.
    assert isinstance(out_weights, torch.Tensor)
    assert out_weights.shape == (4,)

    # Stats: a dict exposing at least the mean weight.
    assert isinstance(out_stats, dict)
    assert "mean_weight" in out_stats
|
|
|
|
|
|
def test_rabc_get_stats(sample_progress_parquet):
    """``get_stats`` reports the weighter type, its settings and the delta statistics keys."""
    from lerobot.rewards.sarm.rabc import RABCWeights

    rabc = RABCWeights(
        progress_path=sample_progress_parquet,
        chunk_size=5,
        head_mode="sparse",
    )

    reported = rabc.get_stats()

    assert reported["type"] == "rabc"

    # Keys that must be present, whatever their value.
    for required_key in ("num_frames", "chunk_size", "head_mode", "delta_mean", "delta_std"):
        assert required_key in reported

    # Settings must be echoed back as they were passed to the constructor.
    assert reported["chunk_size"] == 5
    assert reported["head_mode"] == "sparse"
|
|
|
|
|
|
def test_factory_creates_rabc_weighter(sample_progress_parquet):
    """A complete rabc config plus a policy exposing ``chunk_size`` makes the factory build RABCWeights."""
    from lerobot.rewards.sarm.rabc import RABCWeights

    rabc_cfg = SampleWeightingConfig(
        type="rabc",
        progress_path=str(sample_progress_parquet),
        head_mode="sparse",
        kappa=0.01,
    )
    stub_policy = Mock(config=Mock(chunk_size=5))

    built = make_sample_weighter(rabc_cfg, stub_policy, torch.device("cpu"))

    # Checked against both the concrete class and the base class.
    for expected_cls in (RABCWeights, SampleWeighter):
        assert isinstance(built, expected_cls)
|
|
|
|
|
|
def test_rabc_weights_normalization(sample_progress_parquet):
    """The weights returned for a 4-sample batch sum to roughly 4 (within 0.1)."""
    from lerobot.rewards.sarm.rabc import RABCWeights

    rabc = RABCWeights(
        progress_path=sample_progress_parquet,
        chunk_size=5,
        head_mode="sparse",
        device=torch.device("cpu"),
    )
    frame_indices = torch.tensor([0, 1, 2, 3])
    expected_total = 4  # number of samples in the batch

    out_weights, _ = rabc.compute_batch_weights({"index": frame_indices})

    actual_total = out_weights.sum().item()
    assert abs(actual_total - expected_total) < 0.1
|