Mirror of https://github.com/huggingface/lerobot.git, synced 2026-05-15 08:39:49 +00:00
Reward models refactor (#3142)
* feat(rewards): add RewardModelConfig and PreTrainedRewardModel base classes
* refactor(rewards): migrate Classifier from policies/sac/reward_model/ to rewards/classifier/
* refactor(rewards): migrate SARM from policies/sarm/ to rewards/sarm/
* refactor(rewards): add rewards/factory.py and remove reward model code from policies/factory.py
* refactor(rewards): update imports and delete old reward model locations
* test(rewards): add reward model tests and update existing test imports
* fix(rewards): restore full Classifier and SARM implementations
* test(rewards): restore missing CUDA and mixed-precision classifier processor tests
* refactor(lerobot_train.py): remove RABC-specific configuration and replace it with a generic SampleWeighter class in lerobot_train
* refactor(lerobot_train.py): add missing sample-weighting script
* linter + missing files
* add testing for sample weighter
* revert some unneeded changes, improve typing
* update docs
* add automatic detection of the progress path
* remove type exp
* improve comment
* fix: move rabc.py to rewards/sarm/ and update import paths
* refactor(imports): update reward model imports to new module structure
* refactor(imports): update reward model imports to reflect new module structure
* refactor(imports): conditionally import pandas based on availability
* feat(configs): add reward_model field to TrainPipelineConfig and Hub fields to RewardModelConfig
* refactor(policies): remove reward model branches from policy factory and __init__
* refactor(rewards): expand __init__ facade and fix SARMConfig __post_init__ crash
* feat(train): route reward model training through rewards/factory instead of policies/factory
* refactor(train): streamline reward model training logic
* fix(rewards): ensure FileNotFoundError is raised for missing config_file
* refactor(train): update __get_path_fields__ to include reward_model for config loading
* refactor(classifier): remove redundant input normalization in predict_reward method
* fix(train): raise ValueError for non-trainable reward models in train function
* refactor(pretrained_rm): add model card template
* refactor(tests): reward models
* refactor(sarm): update reset method and remove unused action prediction methods
* refactor(wandb): differentiate tags for reward model and policy training in cfg_to_group function
* fix(train): raise ValueError for PEFT usage in reward model training
* refactor(rewards): enhance RewardModelConfig with device handling and delta indices properties

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
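For orientation, the refactor pins down a small registration contract for reward models, which the new tests in this commit exercise. The following is a minimal sketch of that contract, not code from the commit itself: the names MyRewardConfig, MyReward, and "my_reward" are illustrative, while RewardModelConfig.register_subclass, config_class, name, compute_reward, and get_optimizer_preset are the hooks the tests below rely on.

from dataclasses import dataclass

import torch

from lerobot.configs.rewards import RewardModelConfig
from lerobot.optim.optimizers import AdamWConfig
from lerobot.rewards.pretrained import PreTrainedRewardModel


@RewardModelConfig.register_subclass(name="my_reward")  # makes type="my_reward" resolvable by the factory
@dataclass
class MyRewardConfig(RewardModelConfig):
    scale: float = 1.0  # illustrative extra field

    def get_optimizer_preset(self):
        return AdamWConfig(lr=1e-4)


class MyReward(PreTrainedRewardModel):
    # Both class attributes are mandatory; the base class raises TypeError otherwise.
    config_class = MyRewardConfig
    name = "my_reward"

    def __init__(self, config):
        super().__init__(config)
        self.bias = torch.nn.Parameter(torch.zeros(1))

    def compute_reward(self, batch):
        # Zero-shot inference path. Without a forward() override the model
        # reports is_trainable == False and forward() raises NotImplementedError,
        # so trainable and zero-shot reward models can coexist.
        return self.config.scale * self.bias.expand(1)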
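Likewise, the generic sample-weighting hook that replaces the RABC-specific code in lerobot_train.py is driven through make_sample_weighter. A minimal usage sketch, inferred from the tests at the end of this diff (the Mock policy stands in for a real policy object; "rabc" would additionally need a progress parquet file and the policy's chunk_size):

from unittest.mock import Mock

import torch

from lerobot.utils.sample_weighting import SampleWeightingConfig, make_sample_weighter

# "uniform" needs no side files; passing config=None returns no weighter at all.
config = SampleWeightingConfig(type="uniform")
policy = Mock()  # placeholder for a real policy

weighter = make_sample_weighter(config, policy, torch.device("cpu"))
weights, stats = weighter.compute_batch_weights({"index": torch.arange(8)})

# Per-sample weights are normalized so they sum to approximately the batch size,
# and stats carries logging info such as "mean_weight" and "type".
assert weights.shape == (8,) and stats["type"] == "uniform"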
+10 -34
@@ -1,5 +1,3 @@
-#!/usr/bin/env python
-
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,8 +19,6 @@ import pytest
import torch

from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
-from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
-from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor
from lerobot.processor import (
    DataProcessorPipeline,
    DeviceProcessorStep,
@@ -31,6 +27,8 @@ from lerobot.processor import (
    TransitionKey,
)
from lerobot.processor.converters import create_transition, transition_to_batch
+from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
+from lerobot.rewards.classifier.processor_classifier import make_classifier_processor
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE


@@ -42,7 +40,7 @@ def create_default_config():
        OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
    }
    config.output_features = {
-        "reward": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),  # Classifier output
+        "reward": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
    }
    config.normalization_mapping = {
        FeatureType.STATE: NormalizationMode.MEAN_STD,
@@ -90,17 +88,14 @@ def test_classifier_processor_normalization():
    config = create_default_config()
    stats = create_default_stats()

-    preprocessor, postprocessor = make_classifier_processor(
-        config,
-        stats,
-    )
+    preprocessor, postprocessor = make_classifier_processor(config, stats)

    # Create test data
    observation = {
        OBS_STATE: torch.randn(10),
        OBS_IMAGE: torch.randn(3, 224, 224),
    }
-    action = torch.randn(1)  # Dummy action/reward
+    action = torch.randn(1)
    transition = create_transition(observation, action)
    batch = transition_to_batch(transition)

@@ -120,10 +115,7 @@ def test_classifier_processor_cuda():
    config.device = "cuda"
    stats = create_default_stats()

-    preprocessor, postprocessor = make_classifier_processor(
-        config,
-        stats,
-    )
+    preprocessor, postprocessor = make_classifier_processor(config, stats)

    # Create CPU data
    observation = {
@@ -132,7 +124,6 @@ def test_classifier_processor_cuda():
    }
    action = torch.randn(1)
    transition = create_transition(observation, action)
-
    batch = transition_to_batch(transition)

    # Process through preprocessor
@@ -158,10 +149,7 @@ def test_classifier_processor_accelerate_scenario():
    config.device = "cuda:0"
    stats = create_default_stats()

-    preprocessor, postprocessor = make_classifier_processor(
-        config,
-        stats,
-    )
+    preprocessor, postprocessor = make_classifier_processor(config, stats)

    # Simulate Accelerate: data already on GPU
    device = torch.device("cuda:0")
@@ -171,7 +159,6 @@ def test_classifier_processor_accelerate_scenario():
    }
    action = torch.randn(1).to(device)
    transition = create_transition(observation, action)
-
    batch = transition_to_batch(transition)

    # Process through preprocessor
@@ -201,7 +188,6 @@ def test_classifier_processor_multi_gpu():
    }
    action = torch.randn(1).to(device)
    transition = create_transition(observation, action)
-
    batch = transition_to_batch(transition)

    # Process through preprocessor
@@ -231,7 +217,6 @@ def test_classifier_processor_without_stats():
    }
    action = torch.randn(1)
    transition = create_transition(observation, action)
-
    batch = transition_to_batch(transition)

    processed = preprocessor(batch)
@@ -294,7 +279,6 @@ def test_classifier_processor_mixed_precision():
    }
    action = torch.randn(1, dtype=torch.float32)
    transition = create_transition(observation, action)
-
    batch = transition_to_batch(transition)

    # Process through preprocessor
@@ -312,10 +296,7 @@ def test_classifier_processor_batch_data():
    config = create_default_config()
    stats = create_default_stats()

-    preprocessor, postprocessor = make_classifier_processor(
-        config,
-        stats,
-    )
+    preprocessor, postprocessor = make_classifier_processor(config, stats)

    # Test with batched data
    batch_size = 16
@@ -325,7 +306,6 @@ def test_classifier_processor_batch_data():
    }
    action = torch.randn(batch_size, 1)
    transition = create_transition(observation, action)
-
    batch = transition_to_batch(transition)

    # Process through preprocessor
@@ -343,15 +323,11 @@ def test_classifier_processor_postprocessor_identity():
    config = create_default_config()
    stats = create_default_stats()

-    preprocessor, postprocessor = make_classifier_processor(
-        config,
-        stats,
-    )
+    preprocessor, postprocessor = make_classifier_processor(config, stats)

    # Create test data for postprocessor
-    reward = torch.tensor([[0.8], [0.3], [0.9]])  # Batch of rewards/predictions
+    reward = torch.tensor([[0.8], [0.3], [0.9]])
    transition = create_transition(action=reward)
-
    _ = transition_to_batch(transition)

    # Process through postprocessor
+15 -9
@@ -1,5 +1,3 @@
-# !/usr/bin/env python
-
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,8 +16,8 @@ import pytest
import torch

from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
-from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
-from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput
+from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
+from lerobot.rewards.classifier.modeling_classifier import ClassifierOutput
from lerobot.utils.constants import OBS_IMAGE, REWARD
from tests.utils import skip_if_package_missing

@@ -42,7 +40,7 @@ def test_classifier_output():
    reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_binary_classifier_with_default_params():
-    from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
+    from lerobot.rewards.classifier.modeling_classifier import Classifier

    config = RewardClassifierConfig()
    config.input_features = {
@@ -86,7 +84,7 @@ def test_binary_classifier_with_default_params():
    reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_multiclass_classifier():
-    from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
+    from lerobot.rewards.classifier.modeling_classifier import Classifier

    num_classes = 5
    config = RewardClassifierConfig()
@@ -128,11 +126,15 @@ def test_multiclass_classifier():
    reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_default_device():
-    from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
+    from lerobot.rewards.classifier.modeling_classifier import Classifier

    config = RewardClassifierConfig()
-    assert config.device == "cpu"
+    assert config.device is None or config.device == "cpu"

+    config.input_features = {
+        OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
+    }
+    config.num_cameras = 1
    classifier = Classifier(config)
    for p in classifier.parameters():
        assert p.device == torch.device("cpu")
@@ -143,11 +145,15 @@ def test_default_device():
    reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_explicit_device_setup():
-    from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
+    from lerobot.rewards.classifier.modeling_classifier import Classifier

    config = RewardClassifierConfig(device="cpu")
    assert config.device == "cpu"

+    config.input_features = {
+        OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
+    }
+    config.num_cameras = 1
    classifier = Classifier(config)
    for p in classifier.parameters():
        assert p.device == torch.device("cpu")
@@ -0,0 +1,373 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the reward model base classes and registry."""

from dataclasses import dataclass
from pathlib import Path
from types import SimpleNamespace

import pytest
import torch

from lerobot.configs.rewards import RewardModelConfig
from lerobot.optim.optimizers import AdamWConfig
from lerobot.rewards.pretrained import PreTrainedRewardModel


@RewardModelConfig.register_subclass(name="_dummy_hub_reward")
@dataclass
class _DummyHubRewardConfig(RewardModelConfig):
    def get_optimizer_preset(self):
        return AdamWConfig(lr=1e-4)


class _DummyHubReward(PreTrainedRewardModel):
    config_class = _DummyHubRewardConfig
    name = "_dummy_hub_reward"

    def __init__(self, config):
        super().__init__(config)
        self.bias = torch.nn.Parameter(torch.zeros(1))

    def compute_reward(self, batch):
        return self.bias.expand(1)


def test_reward_model_config_registry():
    """Verify that classifier and sarm are registered."""
    known = RewardModelConfig.get_known_choices()
    assert "reward_classifier" in known
    assert "sarm" in known


def test_reward_model_config_lookup():
    """Verify that we can look up configs by name."""
    cls = RewardModelConfig.get_choice_class("reward_classifier")
    from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig

    assert cls is RewardClassifierConfig


def test_factory_get_reward_model_class():
    """Test the get_reward_model_class factory."""
    from lerobot.rewards.factory import get_reward_model_class

    cls = get_reward_model_class("sarm")
    from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel

    assert cls is SARMRewardModel


def test_factory_unknown_raises():
    """Unknown name should raise ValueError."""
    from lerobot.rewards.factory import get_reward_model_class

    with pytest.raises(ValueError, match="not available"):
        get_reward_model_class("nonexistent_reward_model")


def test_pretrained_reward_model_requires_config_class():
    """Subclass without config_class should fail."""
    with pytest.raises(TypeError, match="must define 'config_class'"):

        class BadModel(PreTrainedRewardModel):
            name = "bad"

            def compute_reward(self, batch):
                pass


def test_pretrained_reward_model_requires_name():
    """Subclass without name should fail."""
    with pytest.raises(TypeError, match="must define 'name'"):

        class BadModel(PreTrainedRewardModel):
            config_class = RewardModelConfig

            def compute_reward(self, batch):
                pass


def test_non_trainable_forward_raises():
    """Non-trainable model should raise on forward()."""
    from dataclasses import dataclass

    from lerobot.optim.optimizers import AdamWConfig

    @dataclass
    class DummyConfig(RewardModelConfig):
        def get_optimizer_preset(self):
            return AdamWConfig(lr=1e-4)

    class DummyReward(PreTrainedRewardModel):
        config_class = DummyConfig
        name = "dummy_test"

        def compute_reward(self, batch):
            return torch.zeros(1)

    config = DummyConfig()
    model = DummyReward(config)

    with pytest.raises(NotImplementedError, match="not trainable"):
        model.forward({"x": torch.zeros(1)})


# ---------------------------------------------------------------------------
# Trainable vs zero-shot (general-purpose) reward models.
# The proposal explicitly supports models like TOPReward that wrap a pretrained
# VLM and produce a reward signal without any training step. These tests pin
# the contract that lets such models coexist with trainable ones.
# ---------------------------------------------------------------------------


def test_is_trainable_false_when_forward_not_overridden():
    """A reward model that only implements ``compute_reward`` is zero-shot."""
    model, _ = _make_dummy_reward_model()
    assert model.is_trainable is False


def test_is_trainable_true_when_forward_overridden():
    """Overriding ``forward`` flips ``is_trainable`` to True."""

    class _TrainableReward(_DummyHubReward):
        name = "_trainable_dummy_reward"

        def forward(self, batch):
            loss = (self.bias**2).sum()
            return loss, {}

    # Register a fresh config subclass so the subclass check passes.
    @RewardModelConfig.register_subclass(name="_trainable_dummy_reward")
    @dataclass
    class _TrainableConfig(_DummyHubRewardConfig):
        pass

    _TrainableReward.config_class = _TrainableConfig
    model = _TrainableReward(_TrainableConfig())
    assert model.is_trainable is True


# ---------------------------------------------------------------------------
# RewardModelConfig.from_pretrained
# ---------------------------------------------------------------------------


def test_reward_model_config_from_pretrained_raises_when_config_missing(tmp_path):
    """``from_pretrained`` must surface a clear ``FileNotFoundError`` when the
    target directory exists but does not contain ``config.json``, instead of
    crashing later inside ``draccus.parse``.
    """
    # tmp_path exists but has no config.json
    with pytest.raises(FileNotFoundError, match="config.json not found"):
        RewardModelConfig.from_pretrained(tmp_path)


def test_reward_model_config_from_pretrained_roundtrip(tmp_path):
    """Round-trip: save a RewardClassifierConfig, reload it, fields must match."""
    from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig

    original = RewardClassifierConfig(
        num_classes=3,
        hidden_dim=128,
        latent_dim=64,
        num_cameras=1,
        learning_rate=5e-4,
    )
    original._save_pretrained(tmp_path)

    loaded = RewardModelConfig.from_pretrained(tmp_path)

    assert isinstance(loaded, RewardClassifierConfig)
    assert loaded.num_classes == 3
    assert loaded.hidden_dim == 128
    assert loaded.latent_dim == 64
    assert loaded.num_cameras == 1
    assert loaded.learning_rate == 5e-4


# ---------------------------------------------------------------------------
# TrainPipelineConfig — reward model training path
# ---------------------------------------------------------------------------


def test_train_pipeline_config_path_fields_includes_reward_model():
    """``--reward_model.path=local/dir`` requires ``reward_model`` to be listed
    as a draccus path-field on ``TrainPipelineConfig``."""
    from lerobot.configs.train import TrainPipelineConfig

    fields = TrainPipelineConfig.__get_path_fields__()
    assert "policy" in fields
    assert "reward_model" in fields


def test_train_pipeline_config_trainable_config_returns_reward_model_when_set():
    """When only ``reward_model`` is set, ``trainable_config`` (used by the
    trainer for e.g. ``.device``) must return it — not ``None`` from ``policy``."""
    from lerobot.configs.default import DatasetConfig
    from lerobot.configs.train import TrainPipelineConfig
    from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig

    reward_cfg = RewardClassifierConfig(device="cpu")
    cfg = TrainPipelineConfig(
        dataset=DatasetConfig(repo_id="user/repo"),
        reward_model=reward_cfg,
    )

    assert cfg.is_reward_model_training is True
    assert cfg.trainable_config is reward_cfg
    # This is what lerobot_train.py uses to decide force_cpu; ``cfg.policy.device``
    # would AttributeError here because policy is None.
    assert cfg.trainable_config.device == "cpu"


def test_train_pipeline_config_trainable_config_returns_policy_when_set():
    """Mirror of the reward-model case: when only ``policy`` is set,
    ``trainable_config`` must return it."""
    from lerobot.configs.default import DatasetConfig
    from lerobot.configs.train import TrainPipelineConfig
    from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig

    policy_cfg = DiffusionConfig(device="cpu")
    cfg = TrainPipelineConfig(
        dataset=DatasetConfig(repo_id="user/repo"),
        policy=policy_cfg,
    )

    assert cfg.is_reward_model_training is False
    assert cfg.trainable_config is policy_cfg
    assert cfg.trainable_config.device == "cpu"


# ---------------------------------------------------------------------------
# PreTrainedRewardModel hub upload: push_model_to_hub + generate_model_card.
# We test the generation side (offline) fully, and the upload side with HfApi
# mocked so nothing actually hits the network.
# ---------------------------------------------------------------------------


def _make_dummy_reward_model(**config_kwargs):
    return _DummyHubReward(_DummyHubRewardConfig(**config_kwargs)), _DummyHubRewardConfig


@pytest.fixture
def _offline_model_card(monkeypatch):
    """``ModelCard.validate`` does a live ``POST`` to huggingface.co — bypass it
    so tests can run offline."""
    from huggingface_hub import ModelCard

    monkeypatch.setattr(ModelCard, "validate", lambda self, *a, **kw: None)


def test_reward_model_generate_model_card_renders_expected_fields(_offline_model_card):
    """``generate_model_card`` must produce a card with the right metadata and
    body, using the dedicated reward-model template."""
    model, _ = _make_dummy_reward_model(
        license="mit",
        tags=["robot", "sim"],
    )

    card = model.generate_model_card(
        dataset_repo_id="user/my_dataset",
        model_type=model.config.type,
        license=model.config.license,
        tags=model.config.tags,
    )

    # Metadata (YAML header) — ModelCardData fields.
    assert card.data.license == "mit"
    assert card.data.library_name == "lerobot"
    assert card.data.pipeline_tag == "robotics"
    assert "reward-model" in card.data.tags
    assert model.config.type in card.data.tags
    assert card.data.model_name == model.config.type
    assert card.data.datasets == "user/my_dataset"

    # Body — specific to the reward-model template, NOT the policy one.
    body = str(card)
    assert "Reward Model Card" in body
    assert "This reward model has been trained" in body
    assert "--reward_model.type=" in body  # reward-model-specific usage block


def test_reward_model_generate_model_card_uses_default_license(_offline_model_card):
    """When config.license is None the card falls back to apache-2.0."""
    model, _ = _make_dummy_reward_model()

    card = model.generate_model_card(
        dataset_repo_id="user/my_dataset",
        model_type=model.config.type,
        license=model.config.license,
        tags=None,
    )

    assert card.data.license == "apache-2.0"


def test_reward_model_push_model_to_hub_uploads_expected_files(monkeypatch, _offline_model_card):
    """``push_model_to_hub`` must:
    1. create the repo,
    2. assemble a temp folder with weights + config.json + train_config.json + README.md,
    3. call ``api.upload_folder`` on that folder.
    All network calls are mocked.
    """
    from huggingface_hub.constants import CONFIG_NAME

    from lerobot.configs.default import DatasetConfig
    from lerobot.configs.train import TRAIN_CONFIG_NAME, TrainPipelineConfig

    model, _ = _make_dummy_reward_model(
        repo_id="user/my_reward",
        license="apache-2.0",
    )
    # Point the reward model's train config at a dummy dataset repo.
    train_cfg = TrainPipelineConfig(
        dataset=DatasetConfig(repo_id="user/my_dataset"),
        reward_model=model.config,
    )

    uploaded: dict = {}
    fake_commit_info = SimpleNamespace(repo_url=SimpleNamespace(url="https://huggingface.co/user/my_reward"))

    class _FakeHfApi:
        def create_repo(self, repo_id, private=None, exist_ok=False):
            uploaded["create_repo_id"] = repo_id
            uploaded["create_private"] = private
            return SimpleNamespace(repo_id=repo_id)

        def upload_folder(self, *, repo_id, repo_type, folder_path, commit_message, **_kwargs):
            uploaded["upload_repo_id"] = repo_id
            uploaded["upload_repo_type"] = repo_type
            uploaded["commit_message"] = commit_message
            # Snapshot files assembled in the temp folder — this is the real
            # contract we care about.
            uploaded["files"] = sorted(p.name for p in Path(folder_path).iterdir())
            return fake_commit_info

    from lerobot.rewards import pretrained as reward_pretrained

    monkeypatch.setattr(reward_pretrained, "HfApi", lambda *a, **kw: _FakeHfApi())

    model.push_model_to_hub(train_cfg)

    assert uploaded["create_repo_id"] == "user/my_reward"
    assert uploaded["upload_repo_id"] == "user/my_reward"
    assert uploaded["upload_repo_type"] == "model"
    assert uploaded["commit_message"] == "Upload reward model weights, train config and readme"
    # Minimum required files that must be uploaded with a reward model.
    assert CONFIG_NAME in uploaded["files"]  # config.json
    assert TRAIN_CONFIG_NAME in uploaded["files"]  # train_config.json
    assert "README.md" in uploaded["files"]
    assert any(name.endswith(".safetensors") for name in uploaded["files"])
@@ -104,8 +104,8 @@ class TestSARMEncodingProcessorStepEndToEnd:
    def mock_clip_model(self):
        """Mock CLIP model to avoid loading real weights."""
        with (
-            patch("lerobot.policies.sarm.processor_sarm.CLIPModel") as mock_model_cls,
-            patch("lerobot.policies.sarm.processor_sarm.CLIPProcessor") as mock_processor_cls,
+            patch("lerobot.rewards.sarm.processor_sarm.CLIPModel") as mock_model_cls,
+            patch("lerobot.rewards.sarm.processor_sarm.CLIPProcessor") as mock_processor_cls,
        ):
            # Mock the CLIP model - return embeddings based on input batch size
            mock_model = MagicMock()
@@ -142,7 +142,7 @@ class TestSARMEncodingProcessorStepEndToEnd:
    @pytest.fixture
    def processor_with_mocks(self, mock_clip_model):
        """Create a processor with mocked CLIP and dataset metadata for dual mode."""
-        from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
+        from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep

        # Dual mode config with both sparse and dense annotations
        config = MockConfig(
@@ -256,7 +256,7 @@ class TestSARMEncodingProcessorStepEndToEnd:

    def test_call_with_batched_input(self, mock_clip_model):
        """Test processor __call__ with a batched input (multiple frames) in dual mode."""
-        from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
+        from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep

        config = MockConfig(
            n_obs_steps=8,
@@ -332,7 +332,7 @@ class TestSARMEncodingProcessorStepEndToEnd:

    def test_targets_increase_with_progress(self, mock_clip_model):
        """Test that both sparse and dense targets increase as frame index progresses."""
-        from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
+        from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep

        config = MockConfig(
            n_obs_steps=8,
@@ -404,7 +404,7 @@ class TestSARMEncodingProcessorStepEndToEnd:

    def test_progress_labels_exact_values(self, mock_clip_model):
        """Test that progress labels (stage.tau) are computed correctly for known positions."""
-        from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
+        from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep

        # Simple setup: 2 sparse stages, 4 dense stages, 100 frame episode
        config = MockConfig(
@@ -495,7 +495,7 @@ class TestSARMEncodingProcessorStepEndToEnd:
        """Test that rewind augmentation correctly extends sequence and generates targets."""
        import random

-        from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
+        from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep

        config = MockConfig(
            n_obs_steps=8,
@@ -587,8 +587,8 @@ class TestSARMEncodingProcessorStepEndToEnd:

    def test_full_sequence_target_consistency(self, mock_clip_model):
        """Test that the full sequence of targets is consistent with frame positions."""
-        from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
-        from lerobot.policies.sarm.sarm_utils import find_stage_and_tau
+        from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep
+        from lerobot.rewards.sarm.sarm_utils import find_stage_and_tau

        config = MockConfig(
            n_obs_steps=8,
@@ -18,7 +18,7 @@ import numpy as np
import pytest
import torch

-from lerobot.policies.sarm.sarm_utils import (
+from lerobot.rewards.sarm.sarm_utils import (
    apply_rewind_augmentation,
    compute_absolute_indices,
    compute_tau,
@@ -0,0 +1,401 @@
#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the sample weighting infrastructure."""

from unittest.mock import Mock

import pytest

pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")

import torch

from lerobot.utils.sample_weighting import (
    SampleWeighter,
    SampleWeightingConfig,
    UniformWeighter,
    make_sample_weighter,
)

# =============================================================================
# Fixtures
# =============================================================================


@pytest.fixture
def sample_progress_parquet(tmp_path):
    """Create a sample progress parquet file for testing."""
    import pandas as pd

    # Create sample progress data for 2 episodes with 10 frames each
    data = {
        "index": list(range(20)),
        "episode_index": [0] * 10 + [1] * 10,
        "frame_index": list(range(10)) * 2,
        "progress_sparse": [i / 10.0 for i in range(10)] * 2,
    }
    df = pd.DataFrame(data)
    parquet_path = tmp_path / "sarm_progress.parquet"
    df.to_parquet(parquet_path)
    return parquet_path


# =============================================================================
# SampleWeightingConfig Tests
# =============================================================================


def test_config_default_values():
    """Test default configuration values."""
    config = SampleWeightingConfig()
    assert config.type == "rabc"
    assert config.progress_path is None
    assert config.head_mode == "sparse"
    assert config.kappa == 0.01
    assert config.epsilon == 1e-6
    assert config.extra_params == {}


def test_config_custom_values():
    """Test configuration with custom values."""
    config = SampleWeightingConfig(
        type="rabc",
        progress_path="/path/to/progress.parquet",
        head_mode="dense",
        kappa=0.05,
        epsilon=1e-8,
        extra_params={"fallback_weight": 0.5},
    )
    assert config.type == "rabc"
    assert config.progress_path == "/path/to/progress.parquet"
    assert config.head_mode == "dense"
    assert config.kappa == 0.05
    assert config.epsilon == 1e-8
    assert config.extra_params == {"fallback_weight": 0.5}


def test_config_uniform_type():
    """Test configuration for uniform weighting."""
    config = SampleWeightingConfig(type="uniform")
    assert config.type == "uniform"


# =============================================================================
# UniformWeighter Tests
# =============================================================================


def test_uniform_weighter_inherits_from_sample_weighter():
    """Test that UniformWeighter is a SampleWeighter."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    assert isinstance(weighter, SampleWeighter)


def test_uniform_weighter_compute_batch_weights_with_action_key():
    """Test weight computation with 'action' key in batch."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    batch = {"action": torch.randn(8, 10)}

    weights, stats = weighter.compute_batch_weights(batch)

    assert weights.shape == (8,)
    assert torch.allclose(weights, torch.ones(8))
    assert stats["mean_weight"] == 1.0
    assert stats["type"] == "uniform"


def test_uniform_weighter_compute_batch_weights_with_index_key():
    """Test weight computation with 'index' key in batch."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    batch = {"index": torch.arange(16)}

    weights, stats = weighter.compute_batch_weights(batch)

    assert weights.shape == (16,)
    assert torch.allclose(weights, torch.ones(16))


def test_uniform_weighter_compute_batch_weights_no_tensor_keys():
    """Test weight computation with no tensor keys (fallback to size 1)."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    batch = {"other_key": "some_value"}

    weights, stats = weighter.compute_batch_weights(batch)

    assert weights.shape == (1,)
    assert torch.allclose(weights, torch.ones(1))


def test_uniform_weighter_compute_batch_weights_empty_batch_raises():
    """Test that empty batch raises ValueError."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    batch = {}

    with pytest.raises(ValueError, match="empty batch"):
        weighter.compute_batch_weights(batch)


def test_uniform_weighter_compute_batch_weights_scans_all_keys():
    """Test that batch size is determined by scanning all tensor values."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    # Batch with non-standard key containing a tensor
    batch = {"custom_tensor": torch.randn(7, 3)}

    weights, stats = weighter.compute_batch_weights(batch)

    assert weights.shape == (7,)
    assert torch.allclose(weights, torch.ones(7))


def test_uniform_weighter_compute_batch_weights_on_cuda():
    """Test that weights are placed on the correct device."""
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    weighter = UniformWeighter(device=torch.device("cuda"))
    batch = {"action": torch.randn(4, 10)}

    weights, _ = weighter.compute_batch_weights(batch)

    assert weights.device.type == "cuda"


def test_uniform_weighter_get_stats():
    """Test get_stats returns expected structure."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    stats = weighter.get_stats()

    assert stats == {"type": "uniform"}


# =============================================================================
# make_sample_weighter Factory Tests
# =============================================================================


def test_factory_returns_none_for_none_config():
    """Test that None config returns None weighter."""
    policy = Mock()
    device = torch.device("cpu")

    result = make_sample_weighter(None, policy, device)

    assert result is None


def test_factory_creates_uniform_weighter():
    """Test creation of UniformWeighter."""
    config = SampleWeightingConfig(type="uniform")
    policy = Mock()
    device = torch.device("cpu")

    weighter = make_sample_weighter(config, policy, device)

    assert isinstance(weighter, UniformWeighter)
    assert isinstance(weighter, SampleWeighter)


def test_factory_raises_for_unknown_type():
    """Test that unknown type raises ValueError."""
    config = SampleWeightingConfig(type="unknown_type")
    policy = Mock()
    device = torch.device("cpu")

    with pytest.raises(ValueError, match="Unknown sample weighting type"):
        make_sample_weighter(config, policy, device)


def test_factory_rabc_requires_chunk_size():
    """Test that RABC weighter requires chunk_size in policy config."""
    config = SampleWeightingConfig(
        type="rabc",
        progress_path="/path/to/progress.parquet",
    )
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = None  # No chunk_size
    device = torch.device("cpu")

    with pytest.raises(ValueError, match="chunk_size"):
        make_sample_weighter(config, policy, device)


def test_factory_rabc_requires_progress_path_or_dataset_info():
    """Test that RABC weighter requires progress_path or dataset info for auto-detection."""
    config = SampleWeightingConfig(
        type="rabc",
        progress_path=None,  # No progress path
    )
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = 50
    device = torch.device("cpu")

    # Should fail when no progress_path AND no dataset info
    with pytest.raises(ValueError, match="progress_path"):
        make_sample_weighter(config, policy, device)


def test_factory_rabc_auto_detects_from_dataset_root(sample_progress_parquet):
    """Test that RABC weighter auto-detects progress_path from dataset_root."""
    config = SampleWeightingConfig(
        type="rabc",
        progress_path=None,  # Not provided, should auto-detect
    )
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = 5
    device = torch.device("cpu")

    # The parquet file is at sample_progress_parquet, get its parent directory
    dataset_root = sample_progress_parquet.parent
    weighter = make_sample_weighter(
        config,
        policy,
        device,
        dataset_root=str(dataset_root),
    )

    assert weighter is not None
    from lerobot.rewards.sarm.rabc import RABCWeights

    assert isinstance(weighter, RABCWeights)


def test_factory_rabc_auto_detects_from_repo_id():
    """Test that RABC weighter constructs HF path from repo_id."""
    config = SampleWeightingConfig(
        type="rabc",
        progress_path=None,  # Not provided, should auto-detect
    )
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = 50
    device = torch.device("cpu")

    # This will construct the path but fail when trying to load (file doesn't exist)
    # We just verify it doesn't raise the "progress_path required" error
    with pytest.raises(Exception) as exc_info:
        make_sample_weighter(
            config,
            policy,
            device,
            dataset_repo_id="test-user/test-dataset",
        )
    # Should NOT be the "progress_path required" error - it should try to load the file
    assert (
        "progress_path" not in str(exc_info.value).lower() or "auto-detection" in str(exc_info.value).lower()
    )


# =============================================================================
# Integration Tests with RABCWeights
# =============================================================================


def test_rabc_weights_is_sample_weighter(sample_progress_parquet):
    """Test that RABCWeights inherits from SampleWeighter."""
    from lerobot.rewards.sarm.rabc import RABCWeights

    weighter = RABCWeights(
        progress_path=sample_progress_parquet,
        chunk_size=5,
        head_mode="sparse",
    )
    assert isinstance(weighter, SampleWeighter)


def test_rabc_compute_batch_weights(sample_progress_parquet):
    """Test RABCWeights.compute_batch_weights returns correct structure."""
    from lerobot.rewards.sarm.rabc import RABCWeights

    weighter = RABCWeights(
        progress_path=sample_progress_parquet,
        chunk_size=5,
        head_mode="sparse",
        device=torch.device("cpu"),
    )

    batch = {"index": torch.tensor([0, 1, 2, 3])}
    weights, stats = weighter.compute_batch_weights(batch)

    assert isinstance(weights, torch.Tensor)
    assert weights.shape == (4,)
    assert isinstance(stats, dict)
    assert "mean_weight" in stats


def test_rabc_get_stats(sample_progress_parquet):
    """Test RABCWeights.get_stats returns expected structure."""
    from lerobot.rewards.sarm.rabc import RABCWeights

    weighter = RABCWeights(
        progress_path=sample_progress_parquet,
        chunk_size=5,
        head_mode="sparse",
    )

    stats = weighter.get_stats()

    assert stats["type"] == "rabc"
    assert "num_frames" in stats
    assert "chunk_size" in stats
    assert stats["chunk_size"] == 5
    assert "head_mode" in stats
    assert stats["head_mode"] == "sparse"
    assert "delta_mean" in stats
    assert "delta_std" in stats


def test_factory_creates_rabc_weighter(sample_progress_parquet):
    """Test factory creates RABCWeights with valid config."""
    from lerobot.rewards.sarm.rabc import RABCWeights

    config = SampleWeightingConfig(
        type="rabc",
        progress_path=str(sample_progress_parquet),
        head_mode="sparse",
        kappa=0.01,
    )
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = 5
    device = torch.device("cpu")

    weighter = make_sample_weighter(config, policy, device)

    assert isinstance(weighter, RABCWeights)
    assert isinstance(weighter, SampleWeighter)


def test_rabc_weights_normalization(sample_progress_parquet):
    """Test that RABCWeights normalizes weights to sum to batch_size."""
    from lerobot.rewards.sarm.rabc import RABCWeights

    weighter = RABCWeights(
        progress_path=sample_progress_parquet,
        chunk_size=5,
        head_mode="sparse",
        device=torch.device("cpu"),
    )

    batch = {"index": torch.tensor([0, 1, 2, 3])}
    weights, _ = weighter.compute_batch_weights(batch)

    # Weights should be normalized to sum approximately to batch_size
    batch_size = 4
    assert abs(weights.sum().item() - batch_size) < 0.1