Reward models refactor (#3142)

* feat(rewards): add RewardModelConfig and PreTrainedRewardModel base classes

* refactor(rewards): migrate Classifier from policies/sac/reward_model/ to rewards/classifier/

* refactor(rewards): migrate SARM from policies/sarm/ to rewards/sarm/

* refactor(rewards): add rewards/factory.py and remove reward model code from policies/factory.py

* refactor(rewards): update imports and delete old reward model locations

* test(rewards): add reward model tests and update existing test imports

* fix(rewards): restore full Classifier and SARM implementations

* test(rewards): restore missing CUDA and mixed precision classifier processor tests

* refactor(lerobot_train.py): remove RABC-specific configuration and replace it with a generic SampleWeighter class in lerobot_train

* refactor(lerobot_train.py): add missing sampling weight script

* chore: fix linter issues and add missing files

* test: add tests for the sample weighter

* chore: revert unnecessary changes and improve typing

* update docs

* add automatic detection of the progress path

* remove type exp

* improve comment

* fix: move rabc.py to rewards/sarm/ and update import paths

* refactor(imports): update reward model imports to new module structure

* refactor(imports): update reward model imports to reflect new module structure

* refactor(imports): conditionally import pandas based on availability

* feat(configs): add reward_model field to TrainPipelineConfig and Hub fields to RewardModelConfig

* refactor(policies): remove reward model branches from policy factory and __init__

* refactor(rewards): expand __init__ facade and fix SARMConfig __post_init__ crash

* feat(train): route reward model training through rewards/factory instead of policies/factory

* refactor(train): streamline reward model training logic

* fix(rewards): ensure FileNotFoundError is raised for missing config_file

* refactor(train): update __get_path_fields__ to include reward_model for config loading

* refactor(classifier): remove redundant input normalization in predict_reward method

* fix(train): raise ValueError for non-trainable reward models in train function

* refactor(pretrained_rm): add model card template

* refactor(tests): restructure reward model tests

* refactor(sarm): update reset method and remove unused action prediction methods

* refactor(wandb): differentiate tags for reward model and policy training in cfg_to_group function

* fix(train): raise ValueError for PEFT usage in reward model training

* refactor(rewards): enhance RewardModelConfig with device handling and delta indices properties
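
A minimal sketch of the new public entry points after this refactor, using only names exercised by the tests added in this PR (paths and return values may differ in detail):

    from lerobot.configs.rewards import RewardModelConfig
    from lerobot.rewards.factory import get_reward_model_class

    # Registered reward models now live under lerobot/rewards/.
    RewardModelConfig.get_known_choices()      # includes "reward_classifier" and "sarm"

    # Resolve a model class from its registered name.
    sarm_cls = get_reward_model_class("sarm")  # -> SARMRewardModel

    # Configs save/load like hub artifacts; a missing config.json raises FileNotFoundError.
    cfg = RewardModelConfig.from_pretrained("path/to/pretrained/reward_model")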

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Author: Khalil Meftah
Date: 2026-04-28 17:56:24 +02:00
Committed by: GitHub
Parent: 03ee50e08f
Commit: 8a3d64033f
37 changed files with 2091 additions and 381 deletions
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,8 +19,6 @@ import pytest
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor
from lerobot.processor import (
DataProcessorPipeline,
DeviceProcessorStep,
@@ -31,6 +27,8 @@ from lerobot.processor import (
TransitionKey,
)
from lerobot.processor.converters import create_transition, transition_to_batch
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
from lerobot.rewards.classifier.processor_classifier import make_classifier_processor
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
@@ -42,7 +40,7 @@ def create_default_config():
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"reward": PolicyFeature(type=FeatureType.ACTION, shape=(1,)), # Classifier output
"reward": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
}
config.normalization_mapping = {
FeatureType.STATE: NormalizationMode.MEAN_STD,
@@ -90,17 +88,14 @@ def test_classifier_processor_normalization():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_classifier_processor(
config,
stats,
)
preprocessor, postprocessor = make_classifier_processor(config, stats)
# Create test data
observation = {
OBS_STATE: torch.randn(10),
OBS_IMAGE: torch.randn(3, 224, 224),
}
action = torch.randn(1) # Dummy action/reward
action = torch.randn(1)
transition = create_transition(observation, action)
batch = transition_to_batch(transition)
@@ -120,10 +115,7 @@ def test_classifier_processor_cuda():
config.device = "cuda"
stats = create_default_stats()
preprocessor, postprocessor = make_classifier_processor(
config,
stats,
)
preprocessor, postprocessor = make_classifier_processor(config, stats)
# Create CPU data
observation = {
@@ -132,7 +124,6 @@ def test_classifier_processor_cuda():
}
action = torch.randn(1)
transition = create_transition(observation, action)
batch = transition_to_batch(transition)
# Process through preprocessor
@@ -158,10 +149,7 @@ def test_classifier_processor_accelerate_scenario():
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_classifier_processor(
config,
stats,
)
preprocessor, postprocessor = make_classifier_processor(config, stats)
# Simulate Accelerate: data already on GPU
device = torch.device("cuda:0")
@@ -171,7 +159,6 @@ def test_classifier_processor_accelerate_scenario():
}
action = torch.randn(1).to(device)
transition = create_transition(observation, action)
batch = transition_to_batch(transition)
# Process through preprocessor
@@ -201,7 +188,6 @@ def test_classifier_processor_multi_gpu():
}
action = torch.randn(1).to(device)
transition = create_transition(observation, action)
batch = transition_to_batch(transition)
# Process through preprocessor
@@ -231,7 +217,6 @@ def test_classifier_processor_without_stats():
}
action = torch.randn(1)
transition = create_transition(observation, action)
batch = transition_to_batch(transition)
processed = preprocessor(batch)
@@ -294,7 +279,6 @@ def test_classifier_processor_mixed_precision():
}
action = torch.randn(1, dtype=torch.float32)
transition = create_transition(observation, action)
batch = transition_to_batch(transition)
# Process through preprocessor
@@ -312,10 +296,7 @@ def test_classifier_processor_batch_data():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_classifier_processor(
config,
stats,
)
preprocessor, postprocessor = make_classifier_processor(config, stats)
# Test with batched data
batch_size = 16
@@ -325,7 +306,6 @@ def test_classifier_processor_batch_data():
}
action = torch.randn(batch_size, 1)
transition = create_transition(observation, action)
batch = transition_to_batch(transition)
# Process through preprocessor
@@ -343,15 +323,11 @@ def test_classifier_processor_postprocessor_identity():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_classifier_processor(
config,
stats,
)
preprocessor, postprocessor = make_classifier_processor(config, stats)
# Create test data for postprocessor
reward = torch.tensor([[0.8], [0.3], [0.9]]) # Batch of rewards/predictions
reward = torch.tensor([[0.8], [0.3], [0.9]])
transition = create_transition(action=reward)
_ = transition_to_batch(transition)
# Process through postprocessor
@@ -1,5 +1,3 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,8 +16,8 @@ import pytest
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
from lerobot.rewards.classifier.modeling_classifier import ClassifierOutput
from lerobot.utils.constants import OBS_IMAGE, REWARD
from tests.utils import skip_if_package_missing
@@ -42,7 +40,7 @@ def test_classifier_output():
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_binary_classifier_with_default_params():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.rewards.classifier.modeling_classifier import Classifier
config = RewardClassifierConfig()
config.input_features = {
@@ -86,7 +84,7 @@ def test_binary_classifier_with_default_params():
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_multiclass_classifier():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.rewards.classifier.modeling_classifier import Classifier
num_classes = 5
config = RewardClassifierConfig()
@@ -128,11 +126,15 @@ def test_multiclass_classifier():
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_default_device():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.rewards.classifier.modeling_classifier import Classifier
config = RewardClassifierConfig()
assert config.device == "cpu"
assert config.device is None or config.device == "cpu"
config.input_features = {
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.num_cameras = 1
classifier = Classifier(config)
for p in classifier.parameters():
assert p.device == torch.device("cpu")
@@ -143,11 +145,15 @@ def test_default_device():
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_explicit_device_setup():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.rewards.classifier.modeling_classifier import Classifier
config = RewardClassifierConfig(device="cpu")
assert config.device == "cpu"
config.input_features = {
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.num_cameras = 1
classifier = Classifier(config)
for p in classifier.parameters():
assert p.device == torch.device("cpu")
@@ -0,0 +1,373 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the reward model base classes and registry."""
from dataclasses import dataclass
from pathlib import Path
from types import SimpleNamespace
import pytest
import torch
from lerobot.configs.rewards import RewardModelConfig
from lerobot.optim.optimizers import AdamWConfig
from lerobot.rewards.pretrained import PreTrainedRewardModel
@RewardModelConfig.register_subclass(name="_dummy_hub_reward")
@dataclass
class _DummyHubRewardConfig(RewardModelConfig):
def get_optimizer_preset(self):
return AdamWConfig(lr=1e-4)
class _DummyHubReward(PreTrainedRewardModel):
config_class = _DummyHubRewardConfig
name = "_dummy_hub_reward"
def __init__(self, config):
super().__init__(config)
self.bias = torch.nn.Parameter(torch.zeros(1))
def compute_reward(self, batch):
return self.bias.expand(1)
def test_reward_model_config_registry():
"""Verify that classifier and sarm are registered."""
known = RewardModelConfig.get_known_choices()
assert "reward_classifier" in known
assert "sarm" in known
def test_reward_model_config_lookup():
"""Verify that we can look up configs by name."""
cls = RewardModelConfig.get_choice_class("reward_classifier")
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
assert cls is RewardClassifierConfig
def test_factory_get_reward_model_class():
"""Test the get_reward_model_class factory."""
from lerobot.rewards.factory import get_reward_model_class
cls = get_reward_model_class("sarm")
from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel
assert cls is SARMRewardModel
def test_factory_unknown_raises():
"""Unknown name should raise ValueError."""
from lerobot.rewards.factory import get_reward_model_class
with pytest.raises(ValueError, match="not available"):
get_reward_model_class("nonexistent_reward_model")
def test_pretrained_reward_model_requires_config_class():
"""Subclass without config_class should fail."""
with pytest.raises(TypeError, match="must define 'config_class'"):
class BadModel(PreTrainedRewardModel):
name = "bad"
def compute_reward(self, batch):
pass
def test_pretrained_reward_model_requires_name():
"""Subclass without name should fail."""
with pytest.raises(TypeError, match="must define 'name'"):
class BadModel(PreTrainedRewardModel):
config_class = RewardModelConfig
def compute_reward(self, batch):
pass
def test_non_trainable_forward_raises():
"""Non-trainable model should raise on forward()."""
from dataclasses import dataclass
from lerobot.optim.optimizers import AdamWConfig
@dataclass
class DummyConfig(RewardModelConfig):
def get_optimizer_preset(self):
return AdamWConfig(lr=1e-4)
class DummyReward(PreTrainedRewardModel):
config_class = DummyConfig
name = "dummy_test"
def compute_reward(self, batch):
return torch.zeros(1)
config = DummyConfig()
model = DummyReward(config)
with pytest.raises(NotImplementedError, match="not trainable"):
model.forward({"x": torch.zeros(1)})
# ---------------------------------------------------------------------------
# Trainable vs zero-shot (general-purpose) reward models.
# The proposal explicitly supports models like TOPReward that wrap a pretrained
# VLM and produce a reward signal without any training step. These tests pin
# the contract that lets such models coexist with trainable ones.
# ---------------------------------------------------------------------------
def test_is_trainable_false_when_forward_not_overridden():
"""A reward model that only implements ``compute_reward`` is zero-shot."""
model, _ = _make_dummy_reward_model()
assert model.is_trainable is False
def test_is_trainable_true_when_forward_overridden():
"""Overriding ``forward`` flips ``is_trainable`` to True."""
class _TrainableReward(_DummyHubReward):
name = "_trainable_dummy_reward"
def forward(self, batch):
loss = (self.bias**2).sum()
return loss, {}
# Register a fresh config subclass so the subclass check passes.
@RewardModelConfig.register_subclass(name="_trainable_dummy_reward")
@dataclass
class _TrainableConfig(_DummyHubRewardConfig):
pass
_TrainableReward.config_class = _TrainableConfig
model = _TrainableReward(_TrainableConfig())
assert model.is_trainable is True
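# Illustrative sketch, not part of this diff: the zero-shot contract above is
# what lets a VLM-backed model (e.g. a TOPReward-style wrapper) plug in without
# a training step. All names below are hypothetical.
#
#     @RewardModelConfig.register_subclass(name="_vlm_reward")
#     @dataclass
#     class _VLMRewardConfig(RewardModelConfig):
#         def get_optimizer_preset(self):
#             return AdamWConfig(lr=1e-4)
#
#     class _VLMReward(PreTrainedRewardModel):
#         config_class = _VLMRewardConfig
#         name = "_vlm_reward"
#
#         def compute_reward(self, batch):
#             # Score the batch with a frozen pretrained backbone. No forward()
#             # override, so is_trainable stays False and train() refuses it.
#             return torch.zeros(1)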
# ---------------------------------------------------------------------------
# RewardModelConfig.from_pretrained
# ---------------------------------------------------------------------------
def test_reward_model_config_from_pretrained_raises_when_config_missing(tmp_path):
"""``from_pretrained`` must surface a clear ``FileNotFoundError`` when the
target directory exists but does not contain ``config.json``, instead of
crashing later inside ``draccus.parse``.
"""
# tmp_path exists but has no config.json
with pytest.raises(FileNotFoundError, match="config.json not found"):
RewardModelConfig.from_pretrained(tmp_path)
def test_reward_model_config_from_pretrained_roundtrip(tmp_path):
"""Round-trip: save a RewardClassifierConfig, reload it, fields must match."""
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
original = RewardClassifierConfig(
num_classes=3,
hidden_dim=128,
latent_dim=64,
num_cameras=1,
learning_rate=5e-4,
)
original._save_pretrained(tmp_path)
loaded = RewardModelConfig.from_pretrained(tmp_path)
assert isinstance(loaded, RewardClassifierConfig)
assert loaded.num_classes == 3
assert loaded.hidden_dim == 128
assert loaded.latent_dim == 64
assert loaded.num_cameras == 1
assert loaded.learning_rate == 5e-4
# ---------------------------------------------------------------------------
# TrainPipelineConfig — reward model training path
# ---------------------------------------------------------------------------
def test_train_pipeline_config_path_fields_includes_reward_model():
"""``--reward_model.path=local/dir`` requires ``reward_model`` to be listed
as a draccus path-field on ``TrainPipelineConfig``."""
from lerobot.configs.train import TrainPipelineConfig
fields = TrainPipelineConfig.__get_path_fields__()
assert "policy" in fields
assert "reward_model" in fields
def test_train_pipeline_config_trainable_config_returns_reward_model_when_set():
"""When only ``reward_model`` is set, ``trainable_config`` (used by the
trainer for e.g. ``.device``) must return it — not ``None`` from ``policy``."""
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
reward_cfg = RewardClassifierConfig(device="cpu")
cfg = TrainPipelineConfig(
dataset=DatasetConfig(repo_id="user/repo"),
reward_model=reward_cfg,
)
assert cfg.is_reward_model_training is True
assert cfg.trainable_config is reward_cfg
# This is what lerobot_train.py uses to decide force_cpu; ``cfg.policy.device``
# would AttributeError here because policy is None.
assert cfg.trainable_config.device == "cpu"
def test_train_pipeline_config_trainable_config_returns_policy_when_set():
"""Mirror of the reward-model case: when only ``policy`` is set,
``trainable_config`` must return it."""
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
policy_cfg = DiffusionConfig(device="cpu")
cfg = TrainPipelineConfig(
dataset=DatasetConfig(repo_id="user/repo"),
policy=policy_cfg,
)
assert cfg.is_reward_model_training is False
assert cfg.trainable_config is policy_cfg
assert cfg.trainable_config.device == "cpu"
# ---------------------------------------------------------------------------
# PreTrainedRewardModel hub upload: push_model_to_hub + generate_model_card.
# We test the generation side (offline) fully, and the upload side with HfApi
# mocked so nothing actually hits the network.
# ---------------------------------------------------------------------------
def _make_dummy_reward_model(**config_kwargs):
return _DummyHubReward(_DummyHubRewardConfig(**config_kwargs)), _DummyHubRewardConfig
@pytest.fixture
def _offline_model_card(monkeypatch):
"""``ModelCard.validate`` does a live ``POST`` to huggingface.co — bypass it
so tests can run offline."""
from huggingface_hub import ModelCard
monkeypatch.setattr(ModelCard, "validate", lambda self, *a, **kw: None)
def test_reward_model_generate_model_card_renders_expected_fields(_offline_model_card):
"""``generate_model_card`` must produce a card with the right metadata and
body, using the dedicated reward-model template."""
model, _ = _make_dummy_reward_model(
license="mit",
tags=["robot", "sim"],
)
card = model.generate_model_card(
dataset_repo_id="user/my_dataset",
model_type=model.config.type,
license=model.config.license,
tags=model.config.tags,
)
# Metadata (YAML header) — ModelCardData fields.
assert card.data.license == "mit"
assert card.data.library_name == "lerobot"
assert card.data.pipeline_tag == "robotics"
assert "reward-model" in card.data.tags
assert model.config.type in card.data.tags
assert card.data.model_name == model.config.type
assert card.data.datasets == "user/my_dataset"
# Body — specific to the reward-model template, NOT the policy one.
body = str(card)
assert "Reward Model Card" in body
assert "This reward model has been trained" in body
assert "--reward_model.type=" in body # reward-model-specific usage block
def test_reward_model_generate_model_card_uses_default_license(_offline_model_card):
"""When config.license is None the card falls back to apache-2.0."""
model, _ = _make_dummy_reward_model()
card = model.generate_model_card(
dataset_repo_id="user/my_dataset",
model_type=model.config.type,
license=model.config.license,
tags=None,
)
assert card.data.license == "apache-2.0"
def test_reward_model_push_model_to_hub_uploads_expected_files(monkeypatch, _offline_model_card):
"""``push_model_to_hub`` must:
1. create the repo,
2. assemble a temp folder with weights + config.json + train_config.json + README.md,
3. call ``api.upload_folder`` on that folder.
All network calls are mocked.
"""
from huggingface_hub.constants import CONFIG_NAME
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TRAIN_CONFIG_NAME, TrainPipelineConfig
model, _ = _make_dummy_reward_model(
repo_id="user/my_reward",
license="apache-2.0",
)
# Point the reward model's train config at a dummy dataset repo.
train_cfg = TrainPipelineConfig(
dataset=DatasetConfig(repo_id="user/my_dataset"),
reward_model=model.config,
)
uploaded: dict = {}
fake_commit_info = SimpleNamespace(repo_url=SimpleNamespace(url="https://huggingface.co/user/my_reward"))
class _FakeHfApi:
def create_repo(self, repo_id, private=None, exist_ok=False):
uploaded["create_repo_id"] = repo_id
uploaded["create_private"] = private
return SimpleNamespace(repo_id=repo_id)
def upload_folder(self, *, repo_id, repo_type, folder_path, commit_message, **_kwargs):
uploaded["upload_repo_id"] = repo_id
uploaded["upload_repo_type"] = repo_type
uploaded["commit_message"] = commit_message
# Snapshot files assembled in the temp folder — this is the real
# contract we care about.
uploaded["files"] = sorted(p.name for p in Path(folder_path).iterdir())
return fake_commit_info
from lerobot.rewards import pretrained as reward_pretrained
monkeypatch.setattr(reward_pretrained, "HfApi", lambda *a, **kw: _FakeHfApi())
model.push_model_to_hub(train_cfg)
assert uploaded["create_repo_id"] == "user/my_reward"
assert uploaded["upload_repo_id"] == "user/my_reward"
assert uploaded["upload_repo_type"] == "model"
assert uploaded["commit_message"] == "Upload reward model weights, train config and readme"
# Minimum required files that must be uploaded with a reward model.
assert CONFIG_NAME in uploaded["files"] # config.json
assert TRAIN_CONFIG_NAME in uploaded["files"] # train_config.json
assert "README.md" in uploaded["files"]
assert any(name.endswith(".safetensors") for name in uploaded["files"])
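
Taken together, these tests pin the hub-upload contract end to end. A minimal sketch of the intended call pattern (repo ids and the trained model are placeholders):

    from lerobot.configs.default import DatasetConfig
    from lerobot.configs.train import TrainPipelineConfig

    train_cfg = TrainPipelineConfig(
        dataset=DatasetConfig(repo_id="user/my_dataset"),
        reward_model=model.config,  # a trained PreTrainedRewardModel's config
    )
    model.push_model_to_hub(train_cfg)
    # Creates the repo and uploads the weights (*.safetensors), config.json,
    # train_config.json and a README.md rendered from the reward-model card template.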
@@ -104,8 +104,8 @@ class TestSARMEncodingProcessorStepEndToEnd:
def mock_clip_model(self):
"""Mock CLIP model to avoid loading real weights."""
with (
patch("lerobot.policies.sarm.processor_sarm.CLIPModel") as mock_model_cls,
patch("lerobot.policies.sarm.processor_sarm.CLIPProcessor") as mock_processor_cls,
patch("lerobot.rewards.sarm.processor_sarm.CLIPModel") as mock_model_cls,
patch("lerobot.rewards.sarm.processor_sarm.CLIPProcessor") as mock_processor_cls,
):
# Mock the CLIP model - return embeddings based on input batch size
mock_model = MagicMock()
@@ -142,7 +142,7 @@ class TestSARMEncodingProcessorStepEndToEnd:
@pytest.fixture
def processor_with_mocks(self, mock_clip_model):
"""Create a processor with mocked CLIP and dataset metadata for dual mode."""
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep
# Dual mode config with both sparse and dense annotations
config = MockConfig(
@@ -256,7 +256,7 @@ class TestSARMEncodingProcessorStepEndToEnd:
def test_call_with_batched_input(self, mock_clip_model):
"""Test processor __call__ with a batched input (multiple frames) in dual mode."""
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep
config = MockConfig(
n_obs_steps=8,
@@ -332,7 +332,7 @@ class TestSARMEncodingProcessorStepEndToEnd:
def test_targets_increase_with_progress(self, mock_clip_model):
"""Test that both sparse and dense targets increase as frame index progresses."""
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep
config = MockConfig(
n_obs_steps=8,
@@ -404,7 +404,7 @@ class TestSARMEncodingProcessorStepEndToEnd:
def test_progress_labels_exact_values(self, mock_clip_model):
"""Test that progress labels (stage.tau) are computed correctly for known positions."""
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep
# Simple setup: 2 sparse stages, 4 dense stages, 100 frame episode
config = MockConfig(
@@ -495,7 +495,7 @@ class TestSARMEncodingProcessorStepEndToEnd:
"""Test that rewind augmentation correctly extends sequence and generates targets."""
import random
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep
config = MockConfig(
n_obs_steps=8,
@@ -587,8 +587,8 @@ class TestSARMEncodingProcessorStepEndToEnd:
def test_full_sequence_target_consistency(self, mock_clip_model):
"""Test that the full sequence of targets is consistent with frame positions."""
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
from lerobot.policies.sarm.sarm_utils import find_stage_and_tau
from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep
from lerobot.rewards.sarm.sarm_utils import find_stage_and_tau
config = MockConfig(
n_obs_steps=8,
@@ -18,7 +18,7 @@ import numpy as np
import pytest
import torch
from lerobot.policies.sarm.sarm_utils import (
from lerobot.rewards.sarm.sarm_utils import (
apply_rewind_augmentation,
compute_absolute_indices,
compute_tau,
@@ -0,0 +1,401 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the sample weighting infrastructure."""
from unittest.mock import Mock
import pytest
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
import torch
from lerobot.utils.sample_weighting import (
SampleWeighter,
SampleWeightingConfig,
UniformWeighter,
make_sample_weighter,
)
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture
def sample_progress_parquet(tmp_path):
"""Create a sample progress parquet file for testing."""
import pandas as pd
# Create sample progress data for 2 episodes with 10 frames each
data = {
"index": list(range(20)),
"episode_index": [0] * 10 + [1] * 10,
"frame_index": list(range(10)) * 2,
"progress_sparse": [i / 10.0 for i in range(10)] * 2,
}
df = pd.DataFrame(data)
parquet_path = tmp_path / "sarm_progress.parquet"
df.to_parquet(parquet_path)
return parquet_path
# =============================================================================
# SampleWeightingConfig Tests
# =============================================================================
def test_config_default_values():
"""Test default configuration values."""
config = SampleWeightingConfig()
assert config.type == "rabc"
assert config.progress_path is None
assert config.head_mode == "sparse"
assert config.kappa == 0.01
assert config.epsilon == 1e-6
assert config.extra_params == {}
def test_config_custom_values():
"""Test configuration with custom values."""
config = SampleWeightingConfig(
type="rabc",
progress_path="/path/to/progress.parquet",
head_mode="dense",
kappa=0.05,
epsilon=1e-8,
extra_params={"fallback_weight": 0.5},
)
assert config.type == "rabc"
assert config.progress_path == "/path/to/progress.parquet"
assert config.head_mode == "dense"
assert config.kappa == 0.05
assert config.epsilon == 1e-8
assert config.extra_params == {"fallback_weight": 0.5}
def test_config_uniform_type():
"""Test configuration for uniform weighting."""
config = SampleWeightingConfig(type="uniform")
assert config.type == "uniform"
# =============================================================================
# UniformWeighter Tests
# =============================================================================
def test_uniform_weighter_inherits_from_sample_weighter():
"""Test that UniformWeighter is a SampleWeighter."""
weighter = UniformWeighter(device=torch.device("cpu"))
assert isinstance(weighter, SampleWeighter)
def test_uniform_weighter_compute_batch_weights_with_action_key():
"""Test weight computation with 'action' key in batch."""
weighter = UniformWeighter(device=torch.device("cpu"))
batch = {"action": torch.randn(8, 10)}
weights, stats = weighter.compute_batch_weights(batch)
assert weights.shape == (8,)
assert torch.allclose(weights, torch.ones(8))
assert stats["mean_weight"] == 1.0
assert stats["type"] == "uniform"
def test_uniform_weighter_compute_batch_weights_with_index_key():
"""Test weight computation with 'index' key in batch."""
weighter = UniformWeighter(device=torch.device("cpu"))
batch = {"index": torch.arange(16)}
weights, stats = weighter.compute_batch_weights(batch)
assert weights.shape == (16,)
assert torch.allclose(weights, torch.ones(16))
def test_uniform_weighter_compute_batch_weights_no_tensor_keys():
"""Test weight computation with no tensor keys (fallback to size 1)."""
weighter = UniformWeighter(device=torch.device("cpu"))
batch = {"other_key": "some_value"}
weights, stats = weighter.compute_batch_weights(batch)
assert weights.shape == (1,)
assert torch.allclose(weights, torch.ones(1))
def test_uniform_weighter_compute_batch_weights_empty_batch_raises():
"""Test that empty batch raises ValueError."""
weighter = UniformWeighter(device=torch.device("cpu"))
batch = {}
with pytest.raises(ValueError, match="empty batch"):
weighter.compute_batch_weights(batch)
def test_uniform_weighter_compute_batch_weights_scans_all_keys():
"""Test that batch size is determined by scanning all tensor values."""
weighter = UniformWeighter(device=torch.device("cpu"))
# Batch with non-standard key containing a tensor
batch = {"custom_tensor": torch.randn(7, 3)}
weights, stats = weighter.compute_batch_weights(batch)
assert weights.shape == (7,)
assert torch.allclose(weights, torch.ones(7))
def test_uniform_weighter_compute_batch_weights_on_cuda():
"""Test that weights are placed on the correct device."""
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
weighter = UniformWeighter(device=torch.device("cuda"))
batch = {"action": torch.randn(4, 10)}
weights, _ = weighter.compute_batch_weights(batch)
assert weights.device.type == "cuda"
def test_uniform_weighter_get_stats():
"""Test get_stats returns expected structure."""
weighter = UniformWeighter(device=torch.device("cpu"))
stats = weighter.get_stats()
assert stats == {"type": "uniform"}
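# Illustrative sketch (an assumption, not part of this diff) of the
# compute_batch_weights contract the tests above pin for UniformWeighter:
#
#     def compute_batch_weights(self, batch):
#         if not batch:
#             raise ValueError("Cannot compute weights for an empty batch")
#         sizes = [v.shape[0] for v in batch.values() if isinstance(v, torch.Tensor)]
#         batch_size = sizes[0] if sizes else 1  # fall back to size 1 when no tensors
#         weights = torch.ones(batch_size, device=self.device)
#         return weights, {"type": "uniform", "mean_weight": 1.0}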
# =============================================================================
# make_sample_weighter Factory Tests
# =============================================================================
def test_factory_returns_none_for_none_config():
"""Test that None config returns None weighter."""
policy = Mock()
device = torch.device("cpu")
result = make_sample_weighter(None, policy, device)
assert result is None
def test_factory_creates_uniform_weighter():
"""Test creation of UniformWeighter."""
config = SampleWeightingConfig(type="uniform")
policy = Mock()
device = torch.device("cpu")
weighter = make_sample_weighter(config, policy, device)
assert isinstance(weighter, UniformWeighter)
assert isinstance(weighter, SampleWeighter)
def test_factory_raises_for_unknown_type():
"""Test that unknown type raises ValueError."""
config = SampleWeightingConfig(type="unknown_type")
policy = Mock()
device = torch.device("cpu")
with pytest.raises(ValueError, match="Unknown sample weighting type"):
make_sample_weighter(config, policy, device)
def test_factory_rabc_requires_chunk_size():
"""Test that RABC weighter requires chunk_size in policy config."""
config = SampleWeightingConfig(
type="rabc",
progress_path="/path/to/progress.parquet",
)
policy = Mock()
policy.config = Mock()
policy.config.chunk_size = None # No chunk_size
device = torch.device("cpu")
with pytest.raises(ValueError, match="chunk_size"):
make_sample_weighter(config, policy, device)
def test_factory_rabc_requires_progress_path_or_dataset_info():
"""Test that RABC weighter requires progress_path or dataset info for auto-detection."""
config = SampleWeightingConfig(
type="rabc",
progress_path=None, # No progress path
)
policy = Mock()
policy.config = Mock()
policy.config.chunk_size = 50
device = torch.device("cpu")
# Should fail when no progress_path AND no dataset info
with pytest.raises(ValueError, match="progress_path"):
make_sample_weighter(config, policy, device)
def test_factory_rabc_auto_detects_from_dataset_root(sample_progress_parquet):
"""Test that RABC weighter auto-detects progress_path from dataset_root."""
config = SampleWeightingConfig(
type="rabc",
progress_path=None, # Not provided, should auto-detect
)
policy = Mock()
policy.config = Mock()
policy.config.chunk_size = 5
device = torch.device("cpu")
# The parquet file is at sample_progress_parquet, get its parent directory
dataset_root = sample_progress_parquet.parent
weighter = make_sample_weighter(
config,
policy,
device,
dataset_root=str(dataset_root),
)
assert weighter is not None
from lerobot.rewards.sarm.rabc import RABCWeights
assert isinstance(weighter, RABCWeights)
def test_factory_rabc_auto_detects_from_repo_id():
"""Test that RABC weighter constructs HF path from repo_id."""
config = SampleWeightingConfig(
type="rabc",
progress_path=None, # Not provided, should auto-detect
)
policy = Mock()
policy.config = Mock()
policy.config.chunk_size = 50
device = torch.device("cpu")
# This will construct the path but fail when trying to load (file doesn't exist)
# We just verify it doesn't raise the "progress_path required" error
with pytest.raises(Exception) as exc_info:
make_sample_weighter(
config,
policy,
device,
dataset_repo_id="test-user/test-dataset",
)
# Should NOT be the "progress_path required" error - it should try to load the file
assert (
"progress_path" not in str(exc_info.value).lower() or "auto-detection" in str(exc_info.value).lower()
)
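# Illustrative sketch (an assumption) of the progress-path resolution order
# these factory tests exercise; the remote path scheme in particular is a guess:
#
#     def _resolve_progress_path(config, dataset_root=None, dataset_repo_id=None):
#         if config.progress_path is not None:
#             return Path(config.progress_path)
#         if dataset_root is not None:
#             candidate = Path(dataset_root) / "sarm_progress.parquet"
#             if candidate.exists():
#                 return candidate
#         if dataset_repo_id is not None:
#             # Build a remote path from the repo id; loading may still fail later.
#             return f"hf://datasets/{dataset_repo_id}/sarm_progress.parquet"
#         raise ValueError("progress_path is required: auto-detection found no dataset info")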
# =============================================================================
# Integration Tests with RABCWeights
# =============================================================================
def test_rabc_weights_is_sample_weighter(sample_progress_parquet):
"""Test that RABCWeights inherits from SampleWeighter."""
from lerobot.rewards.sarm.rabc import RABCWeights
weighter = RABCWeights(
progress_path=sample_progress_parquet,
chunk_size=5,
head_mode="sparse",
)
assert isinstance(weighter, SampleWeighter)
def test_rabc_compute_batch_weights(sample_progress_parquet):
"""Test RABCWeights.compute_batch_weights returns correct structure."""
from lerobot.rewards.sarm.rabc import RABCWeights
weighter = RABCWeights(
progress_path=sample_progress_parquet,
chunk_size=5,
head_mode="sparse",
device=torch.device("cpu"),
)
batch = {"index": torch.tensor([0, 1, 2, 3])}
weights, stats = weighter.compute_batch_weights(batch)
assert isinstance(weights, torch.Tensor)
assert weights.shape == (4,)
assert isinstance(stats, dict)
assert "mean_weight" in stats
def test_rabc_get_stats(sample_progress_parquet):
"""Test RABCWeights.get_stats returns expected structure."""
from lerobot.rewards.sarm.rabc import RABCWeights
weighter = RABCWeights(
progress_path=sample_progress_parquet,
chunk_size=5,
head_mode="sparse",
)
stats = weighter.get_stats()
assert stats["type"] == "rabc"
assert "num_frames" in stats
assert "chunk_size" in stats
assert stats["chunk_size"] == 5
assert "head_mode" in stats
assert stats["head_mode"] == "sparse"
assert "delta_mean" in stats
assert "delta_std" in stats
def test_factory_creates_rabc_weighter(sample_progress_parquet):
"""Test factory creates RABCWeights with valid config."""
from lerobot.rewards.sarm.rabc import RABCWeights
config = SampleWeightingConfig(
type="rabc",
progress_path=str(sample_progress_parquet),
head_mode="sparse",
kappa=0.01,
)
policy = Mock()
policy.config = Mock()
policy.config.chunk_size = 5
device = torch.device("cpu")
weighter = make_sample_weighter(config, policy, device)
assert isinstance(weighter, RABCWeights)
assert isinstance(weighter, SampleWeighter)
def test_rabc_weights_normalization(sample_progress_parquet):
"""Test that RABCWeights normalizes weights to sum to batch_size."""
from lerobot.rewards.sarm.rabc import RABCWeights
weighter = RABCWeights(
progress_path=sample_progress_parquet,
chunk_size=5,
head_mode="sparse",
device=torch.device("cpu"),
)
batch = {"index": torch.tensor([0, 1, 2, 3])}
weights, _ = weighter.compute_batch_weights(batch)
# Weights should be normalized to sum approximately to batch_size
batch_size = 4
assert abs(weights.sum().item() - batch_size) < 0.1
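
For context, a sketch of how a SampleWeighter is meant to plug into the lerobot_train update step; the per-sample loss hook and the config field name are assumptions, only the weighter API is the one tested above:

    weighter = make_sample_weighter(cfg.sample_weighting, policy, device, dataset_root=str(dataset.root))
    for batch in dataloader:
        per_sample_loss = policy.compute_per_sample_loss(batch)  # hypothetical hook
        if weighter is not None:
            weights, weight_stats = weighter.compute_batch_weights(batch)
            loss = (per_sample_loss * weights).mean()  # weights sum to ~batch_size
        else:
            loss = per_sample_loss.mean()
        loss.backward()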