Reward models refactor (#3142)

* feat(rewards): add RewardModelConfig and PreTrainedRewardModel base classes

* refactor(rewards): migrate Classifier from policies/sac/reward_model/ to rewards/classifier/

* refactor(rewards): migrate SARM from policies/sarm/ to rewards/sarm/

* refactor(rewards): add rewards/factory.py and remove reward model code from policies/factory.py

* refactor(rewards): update imports and delete old reward model locations

* test(rewards): add reward model tests and update existing test imports

* fix(rewards): restore full Classifier and SARM implementations

* test(rewards): restore missing CUDA and mixed precision classifier processor tests

* refactor(lerobot_train.py): remove RA-BC-specific configuration and replace it with a generic SampleWeighter class in lerobot_train

* refactor(lerobot_train.py): add missing sample weighting script

* fix lint issues and add missing files

* add tests for the sample weighter

* revert some unnecessary changes, improve typing

* update docs

* add automatic detection of the progress path

* remove type exp

* improve comment

* fix: move rabc.py to rewards/sarm/ and update import paths

* refactor(imports): update reward model imports to new module structure

* refactor(imports): update reward model imports to reflect new module structure

* refactor(imports): conditionally import pandas based on availability

* feat(configs): add reward_model field to TrainPipelineConfig and Hub fields to RewardModelConfig

* refactor(policies): remove reward model branches from policy factory and __init__

* refactor(rewards): expand __init__ facade and fix SARMConfig __post_init__ crash

* feat(train): route reward model training through rewards/factory instead of policies/factory

* refactor(train): streamline reward model training logic

* fix(rewards): ensure FileNotFoundError is raised for missing config_file

* refactor(train): update __get_path_fields__ to include reward_model for config loading

* refactor(classifier): remove redundant input normalization in predict_reward method

* fix(train): raise ValueError for non-trainable reward models in train function

* refactor(pretrained_rm): add model card template

* refactor(tests): reward models

* refactor(sarm): update reset method and remove unused action prediction methods

* refactor(wandb): differentiate tags for reward model and policy training in cfg_to_group function

* fix(train): raise ValueError for PEFT usage in reward model training

* refactor(rewards): enhance RewardModelConfig with device handling and delta indices properties

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Authored by Khalil Meftah on 2026-04-28 17:56:24 +02:00; committed by GitHub
parent 03ee50e08f
commit 8a3d64033f
37 changed files with 2091 additions and 381 deletions
+29 -28
@@ -46,7 +46,7 @@ This ensures identical task states map to consistent progress values, even acros
## Inputs and Targets (What the new code expects)
-SARM is trained through its processor (`src/lerobot/policies/sarm/processor_sarm.py`), which:
+SARM is trained through its processor (`src/lerobot/rewards/sarm/processor_sarm.py`), which:
- **Encodes** images and task text with CLIP (ViT-B/32) into `video_features` and `text_features`
- **Pads/truncates** robot state into `state_features` (up to `max_state_dim`)
@@ -347,7 +347,7 @@ Use `compute_rabc_weights.py` with `--visualize-only` to visualize model predict
<hfoption id="single_stage">
```bash
-python src/lerobot/policies/sarm/compute_rabc_weights.py \
+python -m lerobot.rewards.sarm.compute_rabc_weights \
--dataset-repo-id your-username/your-dataset \
--reward-model-path your-username/sarm-model \
--visualize-only \
@@ -360,7 +360,7 @@ python src/lerobot/policies/sarm/compute_rabc_weights.py \
<hfoption id="dense_only">
```bash
-python src/lerobot/policies/sarm/compute_rabc_weights.py \
+python -m lerobot.rewards.sarm.compute_rabc_weights \
--dataset-repo-id your-username/your-dataset \
--reward-model-path your-username/sarm-model \
--visualize-only \
@@ -373,7 +373,7 @@ python src/lerobot/policies/sarm/compute_rabc_weights.py \
<hfoption id="dual">
```bash
-python src/lerobot/policies/sarm/compute_rabc_weights.py \
+python -m lerobot.rewards.sarm.compute_rabc_weights \
--dataset-repo-id your-username/your-dataset \
--reward-model-path your-username/sarm-model \
--visualize-only \
@@ -429,7 +429,7 @@ The weighting follows **Equations 8-9** from the paper:
First, run the SARM model on all frames in your dataset to compute progress values:
```bash
-python src/lerobot/policies/sarm/compute_rabc_weights.py \
+python -m lerobot.rewards.sarm.compute_rabc_weights \
--dataset-repo-id your-username/your-dataset \
--reward-model-path your-username/sarm-model \
--head-mode sparse \
@@ -465,15 +465,15 @@ This script:
### Step 5b: Train Policy with RA-BC
-Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
+Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`) if not explicitly provided. Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
```bash
lerobot-train \
--dataset.repo_id=your-username/your-dataset \
--policy.type=pi0 \
---use_rabc=true \
---rabc_head_mode=sparse \
---rabc_kappa=0.01 \
+--sample_weighting.type=rabc \
+--sample_weighting.head_mode=sparse \
+--sample_weighting.kappa=0.01 \
--output_dir=outputs/train/policy_rabc \
--batch_size=32 \
--steps=40000
@@ -488,12 +488,13 @@ The training script automatically:
**RA-BC Arguments:**
-| Argument | Description | Default |
-| ---------------------- | ---------------------------------------------------------- | ---------------------------------- |
-| `--use_rabc` | Enable RA-BC sample weighting | `false` |
-| `--rabc_progress_path` | Path to progress parquet file (auto-detected from dataset) | `sarm_progress.parquet` in dataset |
-| `--rabc_head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
-| `--rabc_kappa` | Threshold κ for high-quality samples | `0.01` |
+| Argument | Description | Default |
+| ---------------------------------- | ------------------------------------------------------ | ----------------------- |
+| `--sample_weighting.type` | Weighting strategy type (`rabc` or `uniform`) | `rabc` |
+| `--sample_weighting.progress_path` | Path to progress parquet file | `sarm_progress.parquet` |
+| `--sample_weighting.head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
+| `--sample_weighting.kappa` | Threshold κ for high-quality samples | `0.01` |
+| `--sample_weighting.epsilon` | Small constant for numerical stability | `1e-6` |
### Tuning RA-BC Kappa
@@ -511,30 +512,30 @@ The `kappa` parameter is the threshold that determines which samples get full we
Monitor these WandB metrics during training:
-| Metric | Healthy Range | Problem Indicator |
-| ------------------ | ------------- | ------------------------- |
-| `rabc_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
-| `rabc_delta_mean` | > 0 | Should be positive |
-| `rabc_delta_std` | > 0 | Variance in data quality |
+| Metric | Healthy Range | Problem Indicator |
+| ----------------------------- | ------------- | ------------------------- |
+| `sample_weight_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
+| `sample_weighting/delta_mean` | > 0 | Should be positive |
+| `sample_weighting/delta_std` | > 0 | Variance in data quality |
-**If `rabc_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
+**If `sample_weight_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
**Setting kappa based on your data:**
-The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `rabc_delta_mean` and `rabc_delta_std`:
+The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `sample_weighting/delta_mean` and `sample_weighting/delta_std`:
```
# If delta_mean ≈ 0.03 and delta_std ≈ 0.02:
# Most deltas fall in range [0.01, 0.05]
# Option 1: Set kappa = delta_mean (medium selectivity)
---rabc_kappa=0.03
+--sample_weighting.kappa=0.03
# Option 2: Set kappa = delta_mean + delta_std (high selectivity)
---rabc_kappa=0.05
+--sample_weighting.kappa=0.05
# Option 3: Set kappa = delta_mean + 2*delta_std (very selective)
---rabc_kappa=0.07
+--sample_weighting.kappa=0.07
```
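For intuition, here is a minimal sketch of the soft-weighting behavior described above (an illustration consistent with the thresholding described in this section, not the verbatim Eq. 8-9 implementation in `rabc.py`; the defaults mirror `--sample_weighting.kappa` and `--sample_weighting.epsilon`):

```python
import torch


def rabc_weight_sketch(delta: torch.Tensor, kappa: float = 0.01, epsilon: float = 1e-6) -> torch.Tensor:
    # Negative progress (going backwards) gets zero weight; deltas at or above
    # kappa get (approximately) full weight; in between, the weight ramps up linearly.
    return (delta.clamp(min=0.0) / (kappa + epsilon)).clamp(max=1.0)


# With kappa=0.01: delta=-0.02 -> 0.0, delta=0.005 -> ~0.5, delta=0.03 -> 1.0
print(rabc_weight_sketch(torch.tensor([-0.02, 0.005, 0.03])))
```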
**When RA-BC may not help:**
@@ -550,8 +551,8 @@ accelerate launch \
src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/your-dataset \
--policy.type=pi0 \
---use_rabc=true \
---rabc_kappa=0.01 \
+--sample_weighting.type=rabc \
+--sample_weighting.kappa=0.01 \
--output_dir=outputs/train/policy_rabc \
--batch_size=32 \
--steps=40000
@@ -576,7 +577,7 @@ accelerate launch \
### RA-BC
1. **Train SARM first**: RA-BC quality depends entirely on SARM quality
-2. **Monitor `rabc_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
+2. **Monitor `sample_weight_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
---
+1 -1
@@ -69,7 +69,7 @@ class ComputeProgressShards(PipelineStep):
import torch
from tqdm import tqdm
-from lerobot.policies.sarm.compute_rabc_weights import (
+from lerobot.rewards.sarm.compute_rabc_weights import (
generate_all_frame_indices,
interpolate_progress,
load_sarm_resources,
+1 -1
@@ -10,7 +10,7 @@ from lerobot.datasets import LeRobotDataset
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
from lerobot.policies import SACConfig
from lerobot.policies.sac.modeling_sac import SACPolicy
-from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
+from lerobot.rewards.classifier.modeling_classifier import Classifier
from lerobot.rl.buffer import ReplayBuffer
from lerobot.rl.gym_manipulator import make_robot_env
from lerobot.robots.so_follower import SO100FollowerConfig
@@ -1,7 +1,7 @@
import torch
from lerobot.datasets import LeRobotDataset
-from lerobot.policies import RewardClassifierConfig, make_policy, make_pre_post_processors
+from lerobot.rewards import RewardClassifierConfig, make_reward_model, make_reward_pre_post_processors
def main():
@@ -22,10 +22,10 @@ def main():
model_name="microsoft/resnet-18",
)
-# Make policy, preprocessor, and optimizer
-policy = make_policy(config, ds_meta=dataset.meta)
-optimizer = config.get_optimizer_preset().build(policy.parameters())
-preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats)
+# Make reward model, preprocessor, and optimizer
+reward_model = make_reward_model(config, dataset_stats=dataset.meta.stats)
+optimizer = config.get_optimizer_preset().build(reward_model.parameters())
+preprocessor, _ = make_reward_pre_post_processors(config, dataset_stats=dataset.meta.stats)
classifier_id = "<user>/reward_classifier_hil_serl_example"
@@ -42,7 +42,7 @@ def main():
batch = preprocessor(batch)
# Forward pass
-loss, output_dict = policy.forward(batch)
+loss, output_dict = reward_model.forward(batch)
# Backward pass and optimization
optimizer.zero_grad()
@@ -58,8 +58,8 @@ def main():
print("Training finished!")
-# You can now save the trained policy.
-policy.push_to_hub(classifier_id)
+# You can now save the trained reward model.
+reward_model.push_to_hub(classifier_id)
if __name__ == "__main__":
+5 -1
@@ -41,8 +41,12 @@ def cfg_to_group(
return tag
return tag[:max_tag_length]
+if cfg.is_reward_model_training:
+trainable_tag = f"reward_model:{cfg.reward_model.type}"
+else:
+trainable_tag = f"policy:{cfg.policy.type}"
lst = [
f"policy:{cfg.policy.type}",
trainable_tag,
f"seed:{cfg.seed}",
]
if cfg.dataset is not None:
+163
@@ -0,0 +1,163 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import builtins
import json
import logging
import os
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, TypeVar
import draccus
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.errors import HfHubHTTPError
from lerobot.configs.types import PolicyFeature
from lerobot.optim.optimizers import OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig
from lerobot.utils.device_utils import auto_select_torch_device, is_torch_device_available
from lerobot.utils.hub import HubMixin
T = TypeVar("T", bound="RewardModelConfig")
logger = logging.getLogger(__name__)
@dataclass
class RewardModelConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
"""Base configuration for reward models.
Args:
input_features: A dictionary defining the PolicyFeature of the input data for the reward. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the reward. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
"""
# Reuses PolicyFeature
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
device: str | None = None
pretrained_path: str | None = None
push_to_hub: bool = False
repo_id: str | None = None
# Hub metadata
license: str | None = None
tags: list[str] | None = None
private: bool | None = None
def __post_init__(self) -> None:
if not self.device or not is_torch_device_available(self.device):
auto_device = auto_select_torch_device()
logger.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
self.device = auto_device.type
@property
def type(self) -> str:
choice_name = self.get_choice_name(self.__class__)
if not isinstance(choice_name, str):
raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
return choice_name
@property
def observation_delta_indices(self) -> list | None: # type: ignore[type-arg]
return None
@property
def action_delta_indices(self) -> list | None: # type: ignore[type-arg]
return None
@property
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg]
return None
@abc.abstractmethod
def get_optimizer_preset(self) -> OptimizerConfig:
raise NotImplementedError
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
return None
def validate_features(self) -> None:
pass
def _save_pretrained(self, save_directory: Path) -> None:
with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
draccus.dump(self, f, indent=4)
@classmethod
def from_pretrained(
cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict[Any, Any] | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
**reward_kwargs: Any,
) -> T:
model_id = str(pretrained_name_or_path)
config_file: str | None = None
if Path(model_id).is_dir():
if CONFIG_NAME in os.listdir(model_id):
config_file = os.path.join(model_id, CONFIG_NAME)
else:
logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
else:
try:
config_file = hf_hub_download(
repo_id=model_id,
filename=CONFIG_NAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
except HfHubHTTPError as e:
raise FileNotFoundError(
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
) from e
if config_file is None:
raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}")
# HACK: Parse the original config to get the config subclass, so that we can
# apply cli overrides.
with draccus.config_type("json"):
orig_config = draccus.parse(cls, config_file, args=[])
with open(config_file) as f:
config = json.load(f)
config.pop("type", None)
with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
json.dump(config, f)
config_file = f.name
cli_overrides = reward_kwargs.pop("cli_overrides", [])
with draccus.config_type("json"):
return draccus.parse(orig_config.__class__, config_file, args=cli_overrides)
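For reference, a hedged sketch of how a concrete config plugs into this base class; the `my_reward` choice name and the `learning_rate` field are hypothetical, while the registration pattern mirrors `RewardClassifierConfig` and `SARMConfig` further down in this diff:

```python
from dataclasses import dataclass

from lerobot.configs.rewards import RewardModelConfig
from lerobot.optim import AdamWConfig, OptimizerConfig


@RewardModelConfig.register_subclass("my_reward")  # hypothetical choice name
@dataclass
class MyRewardConfig(RewardModelConfig):
    learning_rate: float = 1e-4  # hypothetical field

    def get_optimizer_preset(self) -> OptimizerConfig:
        # get_optimizer_preset() is the only abstract method; the scheduler,
        # feature validation, and delta indices all have base-class defaults.
        return AdamWConfig(lr=self.learning_rate)
```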
+40 -29
@@ -26,9 +26,11 @@ from lerobot import envs
from lerobot.configs import parser
from lerobot.optim import LRSchedulerConfig, OptimizerConfig
from lerobot.utils.hub import HubMixin
+from lerobot.utils.sample_weighting import SampleWeightingConfig
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
from .policies import PreTrainedConfig
+from .rewards import RewardModelConfig
TRAIN_CONFIG_NAME = "train_config.json"
@@ -38,6 +40,7 @@ class TrainPipelineConfig(HubMixin):
dataset: DatasetConfig
env: envs.EnvConfig | None = None
policy: PreTrainedConfig | None = None
+reward_model: RewardModelConfig | None = None
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
output_dir: Path | None = None
@@ -72,27 +75,41 @@ class TrainPipelineConfig(HubMixin):
wandb: WandBConfig = field(default_factory=WandBConfig)
peft: PeftConfig | None = None
-# RA-BC (Reward-Aligned Behavior Cloning) parameters
-use_rabc: bool = False # Enable reward-weighted training
-rabc_progress_path: str | None = None # Path to precomputed SARM progress parquet file
-rabc_kappa: float = 0.01 # Hard threshold for high-quality samples
-rabc_epsilon: float = 1e-6 # Small constant for numerical stability
-rabc_head_mode: str | None = "sparse" # For dual-head models: "sparse" or "dense"
+# Sample weighting configuration (e.g., for RA-BC training)
+sample_weighting: SampleWeightingConfig | None = None
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
checkpoint_path: Path | None = field(init=False, default=None)
+@property
+def is_reward_model_training(self) -> bool:
+"""True when the config targets a reward model rather than a policy."""
+return self.reward_model is not None
+@property
+def trainable_config(self) -> PreTrainedConfig | RewardModelConfig:
+"""Return whichever config (policy or reward_model) is active."""
+if self.is_reward_model_training:
+return self.reward_model # type: ignore[return-value]
+return self.policy # type: ignore[return-value]
def validate(self) -> None:
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
policy_path = parser.get_path_arg("policy")
-if policy_path:
-# Only load the policy config
+reward_model_path = parser.get_path_arg("reward_model")
+if reward_model_path:
+cli_overrides = parser.get_cli_overrides("reward_model")
+self.reward_model = RewardModelConfig.from_pretrained(
+reward_model_path, cli_overrides=cli_overrides
+)
+self.reward_model.pretrained_path = str(Path(reward_model_path))
+elif policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = Path(policy_path)
elif self.resume:
# The entire train config is already loaded, we just need to get the checkpoint dir
config_path = parser.parse_arg("config_path")
if not config_path:
raise ValueError(
@@ -108,18 +125,22 @@ class TrainPipelineConfig(HubMixin):
policy_dir = Path(config_path).parent
if self.policy is not None:
self.policy.pretrained_path = policy_dir
+if self.reward_model is not None:
+self.reward_model.pretrained_path = str(policy_dir)
self.checkpoint_path = policy_dir.parent
-if self.policy is None:
+if self.policy is None and self.reward_model is None:
raise ValueError(
"Policy is not configured. Please specify a pretrained policy with `--policy.path`."
"Neither policy nor reward_model is configured. "
"Please specify one with `--policy.path` or `--reward_model.path`."
)
+active_cfg = self.trainable_config
if not self.job_name:
if self.env is None:
self.job_name = f"{self.policy.type}"
self.job_name = f"{active_cfg.type}"
else:
self.job_name = f"{self.env.type}_{self.policy.type}"
self.job_name = f"{self.env.type}_{active_cfg.type}"
if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
raise FileExistsError(
@@ -137,26 +158,16 @@ class TrainPipelineConfig(HubMixin):
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
elif self.use_policy_training_preset and not self.resume:
-self.optimizer = self.policy.get_optimizer_preset()
-self.scheduler = self.policy.get_scheduler_preset()
+self.optimizer = active_cfg.get_optimizer_preset()
+self.scheduler = active_cfg.get_scheduler_preset()
-if self.policy.push_to_hub and not self.policy.repo_id:
-raise ValueError(
-"'policy.repo_id' argument missing. Please specify it to push the model to the hub."
-)
-if self.use_rabc and not self.rabc_progress_path:
-# Auto-detect from dataset path
-repo_id = self.dataset.repo_id
-if self.dataset.root:
-self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet")
-else:
-self.rabc_progress_path = f"hf://datasets/{repo_id}/sarm_progress.parquet"
+if hasattr(active_cfg, "push_to_hub") and active_cfg.push_to_hub and not active_cfg.repo_id:
+raise ValueError("'repo_id' argument missing. Please specify it to push the model to the hub.")
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]
"""Keys for draccus pretrained-path loading."""
return ["policy", "reward_model"]
def to_dict(self) -> dict[str, Any]:
return draccus.encode(self) # type: ignore[no-any-return] # because of the third-party library draccus uses Any as the return type
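A small sketch of how the new dispatch behaves (import paths assumed from this diff; `DatasetConfig` lives next to `TrainPipelineConfig` per the relative import above):

```python
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.rewards import RewardClassifierConfig

cfg = TrainPipelineConfig(
    dataset=DatasetConfig(repo_id="your-username/your-dataset"),
    reward_model=RewardClassifierConfig(),
)
assert cfg.is_reward_model_training
# Optimizer/scheduler presets and the job name now come from the reward model.
assert cfg.trainable_config is cfg.reward_model
```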
+7 -4
@@ -19,6 +19,7 @@ from pprint import pformat
import torch
from lerobot.configs import PreTrainedConfig
+from lerobot.configs.rewards import RewardModelConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.transforms import ImageTransforms
from lerobot.utils.constants import ACTION, IMAGENET_STATS, OBS_PREFIX, REWARD
@@ -30,12 +31,14 @@ from .streaming_dataset import StreamingLeRobotDataset
def resolve_delta_timestamps(
-cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata
+cfg: PreTrainedConfig | RewardModelConfig, ds_meta: LeRobotDatasetMetadata
) -> dict[str, list] | None:
"""Resolves delta_timestamps by reading from the 'delta_indices' properties of the PreTrainedConfig.
"""Resolves delta_timestamps by reading from the 'delta_indices' properties of the config.
Args:
-cfg (PreTrainedConfig): The PreTrainedConfig to read delta_indices from.
+cfg (PreTrainedConfig | RewardModelConfig): The config to read delta_indices from. Both
+``PreTrainedConfig`` and concrete ``RewardModelConfig`` subclasses expose the
+``{observation,action,reward}_delta_indices`` properties used below.
ds_meta (LeRobotDatasetMetadata): The dataset from which features and fps are used to build
delta_timestamps against.
@@ -82,7 +85,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
ds_meta = LeRobotDatasetMetadata(
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
)
-delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
+delta_timestamps = resolve_delta_timestamps(cfg.trainable_config, ds_meta)
if not cfg.dataset.streaming:
dataset = LeRobotDataset(
cfg.dataset.repo_id,
-4
@@ -24,8 +24,6 @@ from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
from .pi05.configuration_pi05 import PI05Config as PI05Config
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
from .sac.configuration_sac import SACConfig as SACConfig
-from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
-from .sarm.configuration_sarm import SARMConfig as SARMConfig
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .utils import make_robot_action, prepare_observation_for_inference
@@ -46,9 +44,7 @@ __all__ = [
"PI0Config",
"PI0FastConfig",
"PI05Config",
"RewardClassifierConfig",
"SACConfig",
"SARMConfig",
"SmolVLAConfig",
"TDMPCConfig",
"VQBeTConfig",
+2 -30
@@ -52,8 +52,6 @@ from .pi0.configuration_pi0 import PI0Config
from .pi05.configuration_pi05 import PI05Config
from .pretrained import PreTrainedPolicy
from .sac.configuration_sac import SACConfig
-from .sac.reward_model.configuration_classifier import RewardClassifierConfig
-from .sarm.configuration_sarm import SARMConfig
from .smolvla.configuration_smolvla import SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig
from .utils import validate_visual_features_consistency
@@ -89,7 +87,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "smolvla", "wall_x".
Returns:
The policy class corresponding to the given name.
@@ -132,18 +130,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .sac.modeling_sac import SACPolicy
return SACPolicy
elif name == "reward_classifier":
from .sac.reward_model.modeling_classifier import Classifier
return Classifier
elif name == "smolvla":
from .smolvla.modeling_smolvla import SmolVLAPolicy
return SmolVLAPolicy
elif name == "sarm":
from .sarm.modeling_sarm import SARMRewardModel
return SARMRewardModel
elif name == "groot":
from .groot.modeling_groot import GrootPolicy
@@ -173,7 +163,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
"smolvla", "reward_classifier", "wall_x".
"smolvla", "wall_x".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
@@ -200,8 +190,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return SACConfig(**kwargs)
elif policy_type == "smolvla":
return SmolVLAConfig(**kwargs)
elif policy_type == "reward_classifier":
return RewardClassifierConfig(**kwargs)
elif policy_type == "groot":
return GrootConfig(**kwargs)
elif policy_type == "xvla":
@@ -378,14 +366,6 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
-elif isinstance(policy_cfg, RewardClassifierConfig):
-from .sac.reward_model.processor_classifier import make_classifier_processor
-processors = make_classifier_processor(
-config=policy_cfg,
-dataset_stats=kwargs.get("dataset_stats"),
-)
elif isinstance(policy_cfg, SmolVLAConfig):
from .smolvla.processor_smolvla import make_smolvla_pre_post_processors
@@ -394,14 +374,6 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
-elif isinstance(policy_cfg, SARMConfig):
-from .sarm.processor_sarm import make_sarm_pre_post_processors
-processors = make_sarm_pre_post_processors(
-config=policy_cfg,
-dataset_stats=kwargs.get("dataset_stats"),
-dataset_meta=kwargs.get("dataset_meta"),
-)
elif isinstance(policy_cfg, GrootConfig):
from .groot.processor_groot import make_groot_pre_post_processors
-1
@@ -1 +0,0 @@
../../../../docs/source/policy_sarm_README.md
-18
@@ -1,18 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_sarm import SARMConfig
from .modeling_sarm import SARMRewardModel
__all__ = ["SARMConfig", "SARMRewardModel"]
+1 -1
@@ -557,7 +557,7 @@ class RewardClassifierProcessorStep(ProcessorStep):
def __post_init__(self):
"""Initializes the reward classifier model after the dataclass is created."""
if self.pretrained_path is not None:
-from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
+from lerobot.rewards.classifier.modeling_classifier import Classifier
self.reward_classifier = Classifier.from_pretrained(self.pretrained_path)
self.reward_classifier.to(self.device)
+36
@@ -0,0 +1,36 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .classifier.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
from .factory import (
get_reward_model_class as get_reward_model_class,
make_reward_model as make_reward_model,
make_reward_model_config as make_reward_model_config,
make_reward_pre_post_processors as make_reward_pre_post_processors,
)
from .pretrained import PreTrainedRewardModel as PreTrainedRewardModel
from .sarm.configuration_sarm import SARMConfig as SARMConfig
__all__ = [
# Configuration classes
"RewardClassifierConfig",
"SARMConfig",
# Base class
"PreTrainedRewardModel",
# Factory functions
"get_reward_model_class",
"make_reward_model",
"make_reward_model_config",
"make_reward_pre_post_processors",
]
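With this facade in place, downstream code only needs the `lerobot.rewards` namespace. A short sketch mirroring the updated example script in this commit:

```python
from lerobot.datasets import LeRobotDataset
from lerobot.rewards import RewardClassifierConfig, make_reward_model, make_reward_pre_post_processors

dataset = LeRobotDataset("your-username/your-dataset")
config = RewardClassifierConfig(model_name="microsoft/resnet-18")
reward_model = make_reward_model(config, dataset_stats=dataset.meta.stats)
preprocessor, _ = make_reward_pre_post_processors(config, dataset_stats=dataset.meta.stats)
```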
@@ -1,5 +1,3 @@
-# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,14 +13,15 @@
# limitations under the License.
from dataclasses import dataclass, field
-from lerobot.configs import NormalizationMode, PreTrainedConfig
+from lerobot.configs import NormalizationMode
+from lerobot.configs.rewards import RewardModelConfig
from lerobot.optim import AdamWConfig, LRSchedulerConfig, OptimizerConfig
from lerobot.utils.constants import OBS_IMAGE
-@PreTrainedConfig.register_subclass(name="reward_classifier")
+@RewardModelConfig.register_subclass(name="reward_classifier")
@dataclass
-class RewardClassifierConfig(PreTrainedConfig):
+class RewardClassifierConfig(RewardModelConfig):
"""Configuration for the Reward Classifier model."""
name: str = "reward_classifier"
@@ -1,5 +1,3 @@
-# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,11 +17,10 @@ import logging
import torch
from torch import Tensor, nn
+from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
+from lerobot.rewards.pretrained import PreTrainedRewardModel
from lerobot.utils.constants import OBS_IMAGE, REWARD
-from ...pretrained import PreTrainedPolicy
-from .configuration_classifier import RewardClassifierConfig
class ClassifierOutput:
"""Wrapper for classifier outputs with additional metadata."""
@@ -99,7 +96,7 @@ class SpatialLearnedEmbeddings(nn.Module):
return output
-class Classifier(PreTrainedPolicy):
+class Classifier(PreTrainedRewardModel):
"""Image classifier built on top of a pre-trained encoder."""
name = "reward_classifier"
@@ -235,6 +232,16 @@ class Classifier(PreTrainedPolicy):
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
+def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
+"""Returns 1.0 for success, 0.0 for failure based on image observations."""
+images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
+output = self.predict(images)
+if self.config.num_classes == 2:
+return (output.probabilities > 0.5).float()
+else:
+return torch.argmax(output.probabilities, dim=1).float()
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
"""Standard forward pass for training compatible with train.py."""
# Extract images and labels
@@ -269,10 +276,6 @@ class Classifier(PreTrainedPolicy):
def predict_reward(self, batch, threshold=0.5):
"""Eval method. Returns predicted reward with the decision threshold as argument."""
# Check for both OBS_IMAGE and OBS_IMAGES prefixes
-batch = self.normalize_inputs(batch)
-batch = self.normalize_targets(batch)
# Extract images from batch dict
images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
@@ -282,28 +285,3 @@ class Classifier(PreTrainedPolicy):
return (probs > threshold).float()
else:
return torch.argmax(self.predict(images).probabilities, dim=1)
-def get_optim_params(self):
-"""Return optimizer parameters for the policy."""
-return self.parameters()
-def select_action(self, batch: dict[str, Tensor]) -> Tensor:
-"""
-This method is required by PreTrainedPolicy but not used for reward classifiers.
-The reward classifier is not an actor and does not select actions.
-"""
-raise NotImplementedError("Reward classifiers do not select actions")
-def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
-"""
-This method is required by PreTrainedPolicy but not used for reward classifiers.
-The reward classifier is not an actor and does not produce action chunks.
-"""
-raise NotImplementedError("Reward classifiers do not predict action chunks")
-def reset(self):
-"""
-This method is required by PreTrainedPolicy but not used for reward classifiers.
-The reward classifier is not an actor and does not select actions.
-"""
-pass
@@ -1,5 +1,3 @@
-# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -27,8 +25,7 @@ from lerobot.processor import (
policy_action_to_transition,
transition_to_policy_action,
)
-from .configuration_classifier import RewardClassifierConfig
+from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
def make_classifier_processor(
@@ -52,8 +49,6 @@ def make_classifier_processor(
Args:
config: The configuration object for the RewardClassifier.
dataset_stats: A dictionary of statistics for normalization.
-preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
-postprocessor_kwargs: Additional arguments for the post-processor pipeline.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
+238
@@ -0,0 +1,238 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import logging
from typing import Any
import torch
from lerobot.configs.rewards import RewardModelConfig
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
from lerobot.rewards.pretrained import PreTrainedRewardModel
from lerobot.rewards.sarm.configuration_sarm import SARMConfig
def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
"""
Retrieves a reward model class by its registered name.
This function uses dynamic imports to avoid loading all reward model classes into
memory at once, improving startup time and reducing dependencies.
Args:
name: The name of the reward model. Supported names are "reward_classifier",
"sarm".
Returns:
The reward model class corresponding to the given name.
Raises:
ValueError: If the reward model name is not recognized.
"""
if name == "reward_classifier":
from lerobot.rewards.classifier.modeling_classifier import Classifier
return Classifier
elif name == "sarm":
from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel
return SARMRewardModel
else:
try:
return _get_reward_model_cls_from_name(name=name)
except Exception as e:
raise ValueError(f"Reward model type '{name}' is not available.") from e
def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
"""
Instantiates a reward model configuration object based on the reward type.
This factory function simplifies the creation of reward model configuration objects
by mapping a string identifier to the corresponding config class.
Args:
reward_type: The type of the reward model. Supported types include
"reward_classifier", "sarm".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
An instance of a `RewardModelConfig` subclass.
Raises:
ValueError: If the `reward_type` is not recognized.
"""
if reward_type == "reward_classifier":
return RewardClassifierConfig(**kwargs)
elif reward_type == "sarm":
return SARMConfig(**kwargs)
else:
try:
config_cls = RewardModelConfig.get_choice_class(reward_type)
return config_cls(**kwargs)
except Exception as e:
raise ValueError(f"Reward model type '{reward_type}' is not available.") from e
def make_reward_model(cfg: RewardModelConfig, **kwargs) -> PreTrainedRewardModel:
"""
Instantiate a reward model from its configuration.
Args:
cfg: The configuration for the reward model to be created. If
`cfg.pretrained_path` is set, the model will be loaded with weights
from that path.
**kwargs: Additional keyword arguments forwarded to the model constructor
(e.g., ``dataset_stats``, ``dataset_meta``).
Returns:
An instantiated and device-placed reward model.
"""
reward_cls = get_reward_model_class(cfg.type)
kwargs["config"] = cfg
if cfg.pretrained_path:
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
reward_model = reward_cls.from_pretrained(**kwargs)
else:
reward_model = reward_cls(**kwargs)
reward_model.to(cfg.device)
assert isinstance(reward_model, torch.nn.Module)
return reward_model
def make_reward_pre_post_processors(
reward_cfg: RewardModelConfig,
**kwargs,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Create pre- and post-processor pipelines for a given reward model.
Each reward model type has a dedicated factory function for its processors.
Args:
reward_cfg: The configuration of the reward model for which to create processors.
**kwargs: Additional keyword arguments passed to the processor factory
(e.g., ``dataset_stats``, ``dataset_meta``).
Returns:
A tuple containing the input (pre-processor) and output (post-processor) pipelines.
Raises:
ValueError: If a processor factory is not implemented for the given reward
model configuration type.
"""
# Create a new processor based on reward model type
if isinstance(reward_cfg, RewardClassifierConfig):
from lerobot.rewards.classifier.processor_classifier import make_classifier_processor
return make_classifier_processor(
config=reward_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(reward_cfg, SARMConfig):
from lerobot.rewards.sarm.processor_sarm import make_sarm_pre_post_processors
return make_sarm_pre_post_processors(
config=reward_cfg,
dataset_stats=kwargs.get("dataset_stats"),
dataset_meta=kwargs.get("dataset_meta"),
)
else:
try:
processors = _make_processors_from_reward_model_config(
config=reward_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
except Exception as e:
raise ValueError(
f"Processor for reward model type '{reward_cfg.type}' is not implemented."
) from e
return processors
def _get_reward_model_cls_from_name(name: str) -> type[PreTrainedRewardModel]:
"""Get reward model class from its registered name using dynamic imports.
This is used as a helper function to import reward models from 3rd party lerobot
plugins.
Args:
name: The name of the reward model.
Returns:
The reward model class corresponding to the given name.
"""
if name not in RewardModelConfig.get_known_choices():
raise ValueError(
f"Unknown reward model name '{name}'. "
f"Available reward models: {RewardModelConfig.get_known_choices()}"
)
config_cls = RewardModelConfig.get_choice_class(name)
config_cls_name = config_cls.__name__
model_name = config_cls_name.removesuffix("Config")
if model_name == config_cls_name:
raise ValueError(
f"The config class name '{config_cls_name}' does not follow the expected naming convention. "
f"Make sure it ends with 'Config'!"
)
cls_name = model_name + "RewardModel"
module_path = config_cls.__module__.replace("configuration_", "modeling_")
module = importlib.import_module(module_path)
reward_cls = getattr(module, cls_name)
return reward_cls
def _make_processors_from_reward_model_config(
config: RewardModelConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[Any, Any]:
"""Create pre- and post-processors from a reward model configuration using dynamic imports.
This is used as a helper function to import processor factories from 3rd party
lerobot reward model plugins.
Args:
config: The reward model configuration object.
dataset_stats: Dataset statistics for normalization.
Returns:
A tuple containing the input (pre-processor) and output (post-processor) pipelines.
"""
reward_type = config.type
function_name = f"make_{reward_type}_pre_post_processors"
module_path = config.__class__.__module__.replace("configuration_", "processor_")
logging.debug(
f"Instantiating reward pre/post processors using function '{function_name}' "
f"from module '{module_path}'"
)
module = importlib.import_module(module_path)
function = getattr(module, function_name)
return function(config, dataset_stats=dataset_stats)
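To make the dynamic-import contract above concrete, here is the file and naming layout a hypothetical third-party plugin would follow (all names below are illustrative):

```python
# my_plugin/configuration_myreward.py
#     @RewardModelConfig.register_subclass("myreward")
#     @dataclass
#     class MyrewardConfig(RewardModelConfig): ...
#
# my_plugin/modeling_myreward.py            # "configuration_" -> "modeling_"
#     class MyrewardRewardModel(PreTrainedRewardModel): ...  # strip "Config", append "RewardModel"
#
# my_plugin/processor_myreward.py           # "configuration_" -> "processor_"
#     def make_myreward_pre_post_processors(config, dataset_stats=None): ...
```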
+244
@@ -0,0 +1,244 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import builtins
import logging
import os
from importlib.resources import files
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, TypeVar
import packaging
import safetensors
from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
from torch import Tensor, nn
from lerobot.configs.rewards import RewardModelConfig
from lerobot.utils.hub import HubMixin
if TYPE_CHECKING:
from lerobot.configs.train import TrainPipelineConfig
T = TypeVar("T", bound="PreTrainedRewardModel")
class PreTrainedRewardModel(nn.Module, HubMixin, abc.ABC):
"""Base class for reward models."""
config_class: None
name: None
def __init__(self, config: RewardModelConfig, *inputs, **kwargs):
super().__init__()
if not isinstance(config, RewardModelConfig):
raise ValueError(
f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
"`RewardModelConfig`. To create a model from a pretrained model use "
f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.config = config
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not getattr(cls, "config_class", None):
raise TypeError(f"Class {cls.__name__} must define 'config_class'")
if not getattr(cls, "name", None):
raise TypeError(f"Class {cls.__name__} must define 'name'")
def _save_pretrained(self, save_directory: Path) -> None:
self.config._save_pretrained(save_directory)
model_to_save = self.module if hasattr(self, "module") else self
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
@classmethod
def from_pretrained(
cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
config: RewardModelConfig | None = None,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
strict: bool = False,
**kwargs,
) -> T:
"""
The reward model is set in evaluation mode by default using `reward.eval()` (dropout modules are
deactivated). To train it, you should first set it back in training mode with `reward.train()`.
"""
if config is None:
config = RewardModelConfig.from_pretrained(
pretrained_name_or_path=pretrained_name_or_path,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
**kwargs,
)
model_id = str(pretrained_name_or_path)
instance = cls(config, **kwargs)
if os.path.isdir(model_id):
print("Loading weights from local directory")
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
reward = cls._load_as_safetensor(instance, model_file, config.device or "cpu", strict)
else:
try:
model_file = hf_hub_download(
repo_id=model_id,
filename=SAFETENSORS_SINGLE_FILE,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
reward = cls._load_as_safetensor(instance, model_file, config.device or "cpu", strict)
except HfHubHTTPError as e:
raise FileNotFoundError(
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
) from e
reward.to(config.device)
reward.eval()
return reward
@classmethod
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
# Create base kwargs
kwargs = {"strict": strict}
# Add device parameter for newer versions that support it
if packaging.version.parse(safetensors.__version__) >= packaging.version.parse("0.4.3"):
kwargs["device"] = map_location
# Load the model with appropriate kwargs
missing_keys, unexpected_keys = load_model_as_safetensor(model, model_file, **kwargs)
if missing_keys:
logging.warning(f"Missing key(s) when loading model: {missing_keys}")
if unexpected_keys:
logging.warning(f"Unexpected key(s) when loading model: {unexpected_keys}")
# For older versions, manually move to device if needed
if "device" not in kwargs and map_location != "cpu":
logging.warning(
"Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
" This means that the model is loaded on 'cpu' first and then copied to the device."
" This leads to a slower loading time."
" Please update safetensors to version 0.4.3 or above for improved performance."
)
model.to(map_location)
return model
def get_optim_params(self):
"""
Returns the reward-model-specific parameters dict to be passed on to the optimizer.
"""
return self.parameters()
def reset(self) -> None:
"""Reset any internal state."""
pass
@abc.abstractmethod
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
"""Compute a scalar reward signal for a batch of observations.
Args:
batch: Dictionary containing at minimum observation tensors.
May also contain "action", "next_observation.*", etc.
Returns:
Tensor of shape ``(batch_size,)`` with reward values.
"""
...
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Any]]:
"""Training forward pass — override for trainable reward models."""
raise NotImplementedError(
f"{self.__class__.__name__} is not trainable. Only use compute_reward() for inference."
)
@property
def is_trainable(self) -> bool:
"""Whether this reward model can be trained via ``lerobot-train``.
Trainable reward models override :meth:`forward`; zero-shot models
inherit the base implementation that raises ``NotImplementedError``.
"""
return type(self).forward is not PreTrainedRewardModel.forward
def push_model_to_hub(self, cfg: "TrainPipelineConfig"):
api = HfApi()
repo_id = api.create_repo(
repo_id=self.config.repo_id, private=self.config.private, exist_ok=True
).repo_id
# Push the files to the repo in a single commit
with TemporaryDirectory(ignore_cleanup_errors=True) as tmp:
saved_path = Path(tmp) / repo_id
self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors
card = self.generate_model_card(
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags
)
card.save(str(saved_path / "README.md"))
cfg.save_pretrained(saved_path) # Calls _save_pretrained and stores train config
commit_info = api.upload_folder(
repo_id=repo_id,
repo_type="model",
folder_path=saved_path,
commit_message="Upload reward model weights, train config and readme",
allow_patterns=["*.safetensors", "*.json", "*.yaml", "*.md"],
ignore_patterns=["*.tmp", "*.log"],
)
logging.info(f"Model pushed to {commit_info.repo_url.url}")
def generate_model_card(
self, dataset_repo_id: str, model_type: str, license: str | None, tags: list[str] | None
) -> ModelCard:
card_data = ModelCardData(
license=license or "apache-2.0",
library_name="lerobot",
pipeline_tag="robotics",
tags=list(set(tags or []).union({"robotics", "lerobot", "reward-model", model_type})),
model_name=model_type,
datasets=dataset_repo_id,
)
template_card = (
files("lerobot.templates")
.joinpath("lerobot_rewardmodel_modelcard_template.md")
.read_text(encoding="utf-8")
)
card = ModelCard.from_template(card_data, template_str=template_card)
card.validate()
return card
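As a usage sketch, the smallest possible subclass might look like this (the class and its reuse of `RewardClassifierConfig` as `config_class` are hypothetical; `__init_subclass__` enforces that both `config_class` and `name` are set):

```python
import torch
from torch import Tensor

from lerobot.rewards import RewardClassifierConfig  # any concrete config class works here
from lerobot.rewards.pretrained import PreTrainedRewardModel


class ConstantRewardModel(PreTrainedRewardModel):
    """Hypothetical zero-shot reward model that always returns zero reward."""

    config_class = RewardClassifierConfig
    name = "constant_reward"

    def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
        batch_size = next(iter(batch.values())).shape[0]
        return torch.zeros(batch_size)


# forward() is not overridden, so is_trainable is False and
# `lerobot-train` rejects this model with a ValueError.
```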
@@ -25,18 +25,18 @@ need ~num_frames/30 queries instead of one per frame (~30x speedup).
Usage:
# Full RA-BC computation with visualizations
-python src/lerobot/policies/sarm/compute_rabc_weights.py \\
+python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path <USER>/sarm_single_uni4
# Faster computation with stride (compute every 5 frames, interpolate the rest)
-python src/lerobot/policies/sarm/compute_rabc_weights.py \\
+python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path <USER>/sarm_single_uni4 \\
--stride 5
# Visualize predictions only (no RA-BC computation)
-python src/lerobot/policies/sarm/compute_rabc_weights.py \\
+python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path <USER>/sarm_single_uni4 \\
--visualize-only \\
@@ -58,10 +58,9 @@ import torch
from tqdm import tqdm
from lerobot.datasets import LeRobotDataset
-from .modeling_sarm import SARMRewardModel
-from .processor_sarm import make_sarm_pre_post_processors
-from .sarm_utils import normalize_stage_tau
+from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel
+from lerobot.rewards.sarm.processor_sarm import make_sarm_pre_post_processors
+from lerobot.rewards.sarm.sarm_utils import normalize_stage_tau
def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None:
@@ -713,12 +712,12 @@ def main():
epilog="""
Examples:
# Full RA-BC computation with visualizations
-python src/lerobot/policies/sarm/compute_rabc_weights.py \\
+python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path <USER>/sarm_single_uni4
# Visualize predictions only (no RA-BC computation)
-python src/lerobot/policies/sarm/compute_rabc_weights.py \\
+python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path <USER>/sarm_single_uni4 \\
--visualize-only \\
@@ -1,5 +1,3 @@
-#!/usr/bin/env python
# Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu
# and The HuggingFace Inc. team. All rights reserved.
#
@@ -22,14 +20,15 @@ Paper: https://arxiv.org/abs/2509.25358
from dataclasses import dataclass, field
-from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
+from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature
+from lerobot.configs.rewards import RewardModelConfig
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
@PreTrainedConfig.register_subclass("sarm")
@RewardModelConfig.register_subclass("sarm")
@dataclass
-class SARMConfig(PreTrainedConfig):
+class SARMConfig(RewardModelConfig):
"""Configuration class for SARM (Stage-Aware Reward Modeling).
Supports three annotation modes:
@@ -110,7 +109,6 @@ class SARMConfig(PreTrainedConfig):
def __post_init__(self):
super().__post_init__()
if self.annotation_mode not in ["single_stage", "dense_only", "dual"]:
raise ValueError(
f"annotation_mode must be 'single_stage', 'dense_only', or 'dual', got {self.annotation_mode}"
@@ -1,5 +1,3 @@
-#!/usr/bin/env python
# Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu
# and The HuggingFace Inc. team. All rights reserved.
#
@@ -34,14 +32,13 @@ import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
-from lerobot.utils.constants import OBS_STR
-from ..pretrained import PreTrainedPolicy
-from .configuration_sarm import SARMConfig
-from .sarm_utils import (
+from lerobot.rewards.pretrained import PreTrainedRewardModel
+from lerobot.rewards.sarm.configuration_sarm import SARMConfig
+from lerobot.rewards.sarm.sarm_utils import (
normalize_stage_tau,
pad_state_to_max_dim,
)
+from lerobot.utils.constants import OBS_STR
class StageTransformer(nn.Module):
@@ -353,7 +350,7 @@ def gen_stage_emb(num_classes: int, targets: torch.Tensor) -> torch.Tensor:
return stage_onehot
-class SARMRewardModel(PreTrainedPolicy):
+class SARMRewardModel(PreTrainedRewardModel):
"""
SARM Reward Model for stage-aware task completion rewards.
@@ -471,6 +468,23 @@ class SARMRewardModel(PreTrainedPolicy):
self.subtask_model.to(device)
return self
+def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
+"""Compute dense progress reward in [0, 1] from batch.
+Expects batch to contain:
+- "observation_features" or video embeddings: (B, T, 512)
+- "language_embedding" or text embeddings: (B, 512)
+- optionally "observation.state": (B, T, state_dim)
+"""
+text_emb = batch.get("language_embedding", batch.get("text_features"))
+video_emb = batch.get("observation_features", batch.get("video_features"))
+state = batch.get("observation.state", batch.get("state_features"))
+rewards = self.calculate_rewards(text_emb, video_emb, state)
+if isinstance(rewards, np.ndarray):
+rewards = torch.from_numpy(rewards).float()
+return rewards
@torch.no_grad()
def calculate_rewards(
self,
@@ -631,17 +645,9 @@ class SARMRewardModel(PreTrainedPolicy):
return self.parameters()
def reset(self):
"""Required by PreTrainedPolicy but not used for reward models."""
"""SARM has no episode-level state to reset."""
pass
-def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
-"""Required by PreTrainedPolicy but not used for reward models."""
-raise NotImplementedError("SARM model does not predict action chunks")
-def select_action(self, batch: dict[str, Tensor]) -> Tensor:
-"""Required by PreTrainedPolicy but not used for SARM."""
-raise NotImplementedError("SARM model does not select actions")
def _train_step(
self,
img_emb: torch.Tensor, # (B, N, T, D)
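As a quick orientation, a minimal sketch of the `compute_reward` contract described above (the repo id and batch shapes are placeholders, and it assumes `PreTrainedRewardModel` exposes a `from_pretrained` constructor analogous to the policy base class):

```python
import torch

from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel

model = SARMRewardModel.from_pretrained("<hf_user>/<sarm_repo_id>")  # placeholder repo id
batch = {
    "observation_features": torch.randn(2, 8, 512),  # (B, T, 512) CLIP video features
    "language_embedding": torch.randn(2, 512),       # (B, 512) CLIP text features
    "observation.state": torch.randn(2, 8, 14),      # optional state, padded to max_state_dim
}
rewards = model.compute_reward(batch)  # dense progress values in [0, 1]
```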
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -60,16 +58,15 @@ from lerobot.processor import (
policy_action_to_transition,
transition_to_policy_action,
)
from lerobot.types import EnvTransition, PolicyAction, TransitionKey
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from .configuration_sarm import SARMConfig
from .sarm_utils import (
from lerobot.rewards.sarm.configuration_sarm import SARMConfig
from lerobot.rewards.sarm.sarm_utils import (
apply_rewind_augmentation,
compute_absolute_indices,
find_stage_and_tau,
pad_state_to_max_dim,
)
from lerobot.types import EnvTransition, PolicyAction, TransitionKey
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
class SARMEncodingProcessorStep(ProcessorStep):
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,14 +12,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
RA-BC (Reward-Aligned Behavior Cloning) sample weighting implementation.
This module implements the SampleWeighter protocol for RA-BC training,
which weights training samples based on their task progress as measured
by the SARM reward model.
The weights are computed based on progress deltas:
delta = progress[t + chunk_size] - progress[t]
High-quality samples (positive progress) get higher weights, while
samples with negative progress (going backwards) get zero weight.
See: https://arxiv.org/abs/2509.25358 for the SARM paper.
"""
import logging
from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np
import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from lerobot.utils.import_utils import _pandas_available
from lerobot.utils.sample_weighting import SampleWeighter
if TYPE_CHECKING or _pandas_available:
import pandas as pd
else:
pd = None # type: ignore[assignment]
def resolve_hf_path(path: str | Path) -> Path:
"""Resolve a path that may be a HuggingFace URL (hf://datasets/...) to a local path."""
@@ -34,23 +56,27 @@ def resolve_hf_path(path: str | Path) -> Path:
return Path(path)
class RABCWeights:
class RABCWeights(SampleWeighter):
"""
Load precomputed SARM progress values and compute RA-BC weights during training.
This class implements the SampleWeighter ABC for use with the generic
sample weighting infrastructure in lerobot.
Progress values are loaded from a parquet file (generated by compute_rabc_weights.py).
During training, computes:
- progress_delta = progress[t + chunk_size] - progress[t]
- rabc_weight based on the delta (paper Eq. 8-9)
Args:
progress_path: Path to parquet file with precomputed progress values
chunk_size: Number of frames ahead for computing progress delta
head_mode: Which SARM head to use ("sparse" or "dense")
kappa: Hard threshold for high-quality samples (default: 0.01)
epsilon: Small constant for numerical stability (default: 1e-6)
fallback_weight: Weight to use for frames without valid delta (default: 1.0)
device: Device to return tensors on
progress_path: Path to parquet file with precomputed progress values.
Supports HuggingFace URLs (hf://datasets/...).
chunk_size: Number of frames ahead for computing progress delta.
head_mode: Which SARM head to use ("sparse" or "dense").
kappa: Hard threshold for high-quality samples (default: 0.01).
epsilon: Small constant for numerical stability (default: 1e-6).
fallback_weight: Weight to use for frames without valid delta (default: 1.0).
device: Device to return tensors on.
"""
def __init__(
@@ -61,7 +87,7 @@ class RABCWeights:
kappa: float = 0.01,
epsilon: float = 1e-6,
fallback_weight: float = 1.0,
device: torch.device = None,
device: torch.device | None = None,
):
self.progress_path = resolve_hf_path(progress_path)
self.chunk_size = chunk_size
@@ -87,8 +113,8 @@ class RABCWeights:
logging.info(f"Using progress column: {self.progress_column}")
self.progress_lookup = {}
self.episode_lookup = {}
self.progress_lookup: dict[int, float] = {}
self.episode_lookup: dict[int, int] = {}
for _, row in self.df.iterrows():
global_idx = int(row["index"])
@@ -100,7 +126,7 @@ class RABCWeights:
self.episode_lookup[global_idx] = episode_idx
# Build episode boundaries for delta computation
self.episode_boundaries = {}
self.episode_boundaries: dict[int, dict[str, int]] = {}
for episode_idx in self.df["episode_index"].unique():
ep_df = self.df[self.df["episode_index"] == episode_idx]
self.episode_boundaries[int(episode_idx)] = {
@@ -114,7 +140,7 @@ class RABCWeights:
# Compute global statistics for weight computation
self._compute_global_stats()
def _compute_global_stats(self):
def _compute_global_stats(self) -> None:
"""Compute global mean and std of progress deltas for weight calculation."""
all_deltas = []
@@ -138,8 +164,8 @@ class RABCWeights:
all_deltas.append(delta)
if all_deltas:
self.delta_mean = max(np.mean(all_deltas), 0.0)
self.delta_std = max(np.std(all_deltas), self.epsilon)
self.delta_mean = max(float(np.mean(all_deltas)), 0.0)
self.delta_std = max(float(np.std(all_deltas)), self.epsilon)
logging.info(f"Progress delta stats: mean={self.delta_mean:.4f}, std={self.delta_std:.4f}")
else:
self.delta_mean = 0.0
@@ -157,18 +183,19 @@ class RABCWeights:
4. Compute weight using paper Eq. 8-9
Args:
batch: Training batch containing "index" key with global frame indices
batch: Training batch containing "index" key with global frame indices.
Returns:
Tuple of:
- Weights tensor (batch_size,) normalized to sum to batch_size
- Stats dict with raw_mean_weight, num_zero_weight, num_full_weight
- Weights tensor (batch_size,) normalized to sum to batch_size.
- Stats dict with weighting statistics for logging.
"""
indices = batch.get("index")
if indices is None:
logging.warning("RA-BC: Batch missing 'index' key, using uniform weights")
batch_size = self._get_batch_size(batch)
return torch.ones(batch_size, device=self.device), {"raw_mean_weight": 1.0}
stats = {"mean_weight": 1.0, "num_zero_weight": 0, "num_full_weight": batch_size}
return torch.ones(batch_size, device=self.device), stats
# Convert to list of ints
if isinstance(indices, torch.Tensor):
@@ -183,29 +210,29 @@ class RABCWeights:
delta = self._compute_delta(idx)
deltas.append(delta)
deltas = np.array(deltas, dtype=np.float32)
deltas_array = np.array(deltas, dtype=np.float32)
# Compute weights from deltas
weights = self._compute_weights(deltas)
weights = self._compute_weights(deltas_array)
# Compute stats before normalization for logging
raw_mean_weight = float(np.nanmean(weights))
num_zero_weight = int(np.sum(weights == 0))
num_full_weight = int(np.sum(weights == 1.0))
batch_stats = {
"raw_mean_weight": raw_mean_weight,
"mean_weight": raw_mean_weight,
"num_zero_weight": num_zero_weight,
"num_full_weight": num_full_weight,
}
weights = torch.tensor(weights, device=self.device, dtype=torch.float32)
weights_tensor = torch.tensor(weights, device=self.device, dtype=torch.float32)
# Normalize to sum to batch_size
batch_size = len(weights)
weight_sum = weights.sum() + self.epsilon
weights = weights * batch_size / weight_sum
batch_size = len(weights_tensor)
weight_sum = weights_tensor.sum() + self.epsilon
weights_tensor = weights_tensor * batch_size / weight_sum
return weights, batch_stats
return weights_tensor, batch_stats
def _compute_delta(self, global_idx: int) -> float:
"""Compute progress delta for a single frame."""
@@ -241,7 +268,7 @@ class RABCWeights:
- Final weight: w_i = 1{r_i > κ} + 1{0 ≤ r_i ≤ κ} · w̃_i
Returns:
Array of weights
Array of weights.
"""
valid_mask = ~np.isnan(deltas)
@@ -273,12 +300,13 @@ class RABCWeights:
if key in batch:
val = batch[key]
if isinstance(val, (torch.Tensor, np.ndarray)):
return val.shape[0]
return int(val.shape[0])
return 1
def get_stats(self) -> dict:
"""Get statistics."""
"""Get global statistics about the RA-BC weighting."""
return {
"type": "rabc",
"num_frames": len(self.progress_lookup),
"chunk_size": self.chunk_size,
"head_mode": self.head_mode,
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
+95 -82
@@ -47,6 +47,7 @@ from lerobot.datasets import EpisodeAwareSampler, make_dataset
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
from lerobot.rewards import make_reward_pre_post_processors
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed
@@ -70,8 +71,8 @@ def update_policy(
accelerator: "Accelerator",
lr_scheduler=None,
lock=None,
rabc_weights_provider=None,
) -> tuple[MetricsTracker, dict]:
sample_weighter=None,
) -> tuple[MetricsTracker, dict | None]:
"""
Performs a single training step to update the policy's weights.
@@ -87,7 +88,7 @@ def update_policy(
accelerator: The Accelerator instance for distributed training and mixed precision.
lr_scheduler: An optional learning rate scheduler.
lock: An optional lock for thread-safe optimizer updates.
rabc_weights_provider: Optional RABCWeights instance for sample weighting.
sample_weighter: Optional SampleWeighter instance for per-sample loss weighting.
Returns:
A tuple containing:
@@ -97,27 +98,31 @@ def update_policy(
start_time = time.perf_counter()
policy.train()
# Get RA-BC weights if enabled
rabc_batch_weights = None
rabc_batch_stats = None
if rabc_weights_provider is not None:
rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch)
# Compute sample weights if a weighter is provided
sample_weights = None
weight_stats = None
if sample_weighter is not None:
sample_weights, weight_stats = sample_weighter.compute_batch_weights(batch)
# Let accelerator handle mixed precision
with accelerator.autocast():
# Use per-sample loss when RA-BC is enabled for proper weighting
if rabc_batch_weights is not None:
# Get per-sample losses
if sample_weights is not None:
# Use per-sample loss for weighted training
# Note: Policies supporting sample weighting must implement forward(batch, reduction="none")
per_sample_loss, output_dict = policy.forward(batch, reduction="none")
# Apply RA-BC weights: L_RA-BC = Σ(w_i * l_i) / (Σw_i + ε)
# rabc_batch_weights is already normalized to sum to batch_size
# Weighted loss: each sample's contribution is scaled by its weight.
# We divide by weight sum (not batch size) so that if some weights are zero,
# the remaining samples contribute proportionally more, preserving gradient scale.
# Weights are pre-normalized to sum to batch_size for stable training dynamics.
epsilon = 1e-6
loss = (per_sample_loss * rabc_batch_weights).sum() / (rabc_batch_weights.sum() + epsilon)
# Log raw mean weight (before normalization) - this is the meaningful metric
output_dict["rabc_mean_weight"] = rabc_batch_stats["raw_mean_weight"]
output_dict["rabc_num_zero_weight"] = rabc_batch_stats["num_zero_weight"]
output_dict["rabc_num_full_weight"] = rabc_batch_stats["num_full_weight"]
loss = (per_sample_loss * sample_weights).sum() / (sample_weights.sum() + epsilon)
# Log weighting statistics
if output_dict is None:
output_dict = {}
for key, value in weight_stats.items():
output_dict[f"sample_weight_{key}"] = value
else:
loss, output_dict = policy.forward(batch)
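A small worked example (numbers invented) of why the weighted loss divides by the weight sum rather than the batch size:

```python
import torch

per_sample_loss = torch.tensor([1.0, 2.0, 3.0, 4.0])
raw = torch.tensor([0.0, 1.0, 1.0, 0.0])       # two samples filtered out
weights = raw * len(raw) / (raw.sum() + 1e-6)  # pre-normalized to sum to batch_size -> [0, 2, 2, 0]

loss = (per_sample_loss * weights).sum() / (weights.sum() + 1e-6)
# (2*2 + 3*2) / 4 = 2.5, i.e. the mean over the surviving samples, so the
# gradient scale is preserved even though half the batch was zeroed.
```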
@@ -188,8 +193,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
# Force the device to be CPU when policy.device is set to CPU.
force_cpu = cfg.policy.device == "cpu"
# Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training).
force_cpu = cfg.trainable_config.device == "cpu"
accelerator = Accelerator(
step_scheduler_with_optimizer=False,
kwargs_handlers=[ddp_kwargs],
@@ -245,26 +250,44 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
logging.info("Creating env")
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
if is_main_process:
logging.info("Creating policy")
policy = make_policy(
cfg=cfg.policy,
ds_meta=dataset.meta,
rename_map=cfg.rename_map,
)
if cfg.is_reward_model_training:
if is_main_process:
logging.info("Creating reward model")
from lerobot.rewards import make_reward_model
policy = make_reward_model(
cfg=cfg.reward_model,
dataset_stats=dataset.meta.stats,
dataset_meta=dataset.meta,
)
if not policy.is_trainable:
raise ValueError(
f"Reward model '{policy.name}' is zero-shot and cannot be trained via lerobot-train. "
"Use it directly for inference via compute_reward() (e.g. offline precompute)."
)
else:
if is_main_process:
logging.info("Creating policy")
policy = make_policy(
cfg=cfg.policy,
ds_meta=dataset.meta,
rename_map=cfg.rename_map,
)
if cfg.peft is not None:
if cfg.is_reward_model_training:
raise ValueError("PEFT is only supported for policy training. ")
logging.info("Using PEFT! Wrapping model.")
# Convert CLI peft config to dict for overrides
peft_cli_overrides = dataclasses.asdict(cfg.peft)
policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides)
# Wait for all processes to finish policy creation before continuing
# Wait for all processes to finish model creation before continuing
accelerator.wait_for_everyone()
processor_pretrained_path = cfg.policy.pretrained_path
active_cfg = cfg.trainable_config
processor_pretrained_path = active_cfg.pretrained_path
if (
getattr(cfg.policy, "use_relative_actions", False)
getattr(active_cfg, "use_relative_actions", False)
and processor_pretrained_path is not None
and not cfg.resume
):
@@ -274,18 +297,15 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
)
processor_pretrained_path = None
# Create processors - only provide dataset_stats if not resuming from saved processors
processor_kwargs = {}
postprocessor_kwargs = {}
if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
# Only provide dataset_stats when not resuming from saved processor state
processor_kwargs["dataset_stats"] = dataset.meta.stats
# For SARM, always provide dataset_meta for progress normalization
if cfg.policy.type == "sarm":
if cfg.is_reward_model_training:
processor_kwargs["dataset_meta"] = dataset.meta
if processor_pretrained_path is not None:
if not cfg.is_reward_model_training and processor_pretrained_path is not None:
processor_kwargs["preprocessor_overrides"] = {
"device_processor": {"device": device.type},
"normalizer_processor": {
@@ -305,38 +325,36 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
},
}
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=processor_pretrained_path,
**processor_kwargs,
**postprocessor_kwargs,
)
if cfg.is_reward_model_training:
preprocessor, postprocessor = make_reward_pre_post_processors(
cfg.reward_model,
**processor_kwargs,
)
else:
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=processor_pretrained_path,
**processor_kwargs,
**postprocessor_kwargs,
)
if is_main_process:
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
# Load precomputed SARM progress for RA-BC if enabled
# Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py
rabc_weights = None
if cfg.use_rabc:
from lerobot.utils.rabc import RABCWeights
# Create sample weighter if configured (e.g., for RA-BC training)
sample_weighter = None
if cfg.sample_weighting is not None:
from lerobot.utils.sample_weighting import make_sample_weighter
# Get chunk_size from policy config
chunk_size = getattr(policy.config, "chunk_size", None)
if chunk_size is None:
raise ValueError("Chunk size is not found in policy config")
head_mode = getattr(cfg, "rabc_head_mode", "sparse")
logging.info(f"Loading SARM progress for RA-BC from {cfg.rabc_progress_path}")
logging.info(f"Using chunk_size={chunk_size} from policy config, head_mode={head_mode}")
rabc_weights = RABCWeights(
progress_path=cfg.rabc_progress_path,
chunk_size=chunk_size,
head_mode=head_mode,
kappa=getattr(cfg, "rabc_kappa", 0.01),
epsilon=getattr(cfg, "rabc_epsilon", 1e-6),
device=device,
if is_main_process:
logging.info(f"Creating sample weighter: {cfg.sample_weighting.type}")
sample_weighter = make_sample_weighter(
cfg.sample_weighting,
policy,
device,
dataset_root=cfg.dataset.root,
dataset_repo_id=cfg.dataset.repo_id,
)
step = 0 # number of policy updates (forward + backward + optim)
@@ -365,13 +383,13 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# create dataloader for offline training
if hasattr(cfg.policy, "drop_n_last_frames"):
if hasattr(active_cfg, "drop_n_last_frames"):
shuffle = False
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=cfg.policy.drop_n_last_frames,
drop_n_last_frames=active_cfg.drop_n_last_frames,
shuffle=True,
)
else:
@@ -448,7 +466,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
cfg.optimizer.grad_clip_norm,
accelerator=accelerator,
lr_scheduler=lr_scheduler,
rabc_weights_provider=rabc_weights,
sample_weighter=sample_weighter,
)
# Note: eval and checkpoint happen *after* the `step`th training update has completed, so we
@@ -467,16 +485,10 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
wandb_log_dict = train_tracker.to_dict()
if output_dict:
wandb_log_dict.update(output_dict)
# Log RA-BC statistics if enabled
if rabc_weights is not None:
rabc_stats = rabc_weights.get_stats()
wandb_log_dict.update(
{
"rabc_delta_mean": rabc_stats["delta_mean"],
"rabc_delta_std": rabc_stats["delta_std"],
"rabc_num_frames": rabc_stats["num_frames"],
}
)
# Log sample weighting statistics if enabled
if sample_weighter is not None:
weighter_stats = sample_weighter.get_stats()
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
wandb_logger.log_dict(wandb_log_dict, step)
train_tracker.reset_averages()
@@ -558,14 +570,15 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
if is_main_process:
logging.info("End of training")
if cfg.policy.push_to_hub:
unwrapped_policy = accelerator.unwrap_model(policy)
if cfg.policy.use_peft:
unwrapped_policy.push_model_to_hub(cfg, peft_model=unwrapped_policy)
if getattr(active_cfg, "push_to_hub", False):
unwrapped_model = accelerator.unwrap_model(policy)
# PEFT only applies when training a policy — reward models use the plain path.
if not cfg.is_reward_model_training and cfg.policy.use_peft:
unwrapped_model.push_model_to_hub(cfg, peft_model=unwrapped_model)
else:
unwrapped_policy.push_model_to_hub(cfg)
preprocessor.push_to_hub(cfg.policy.repo_id)
postprocessor.push_to_hub(cfg.policy.repo_id)
unwrapped_model.push_model_to_hub(cfg)
preprocessor.push_to_hub(active_cfg.repo_id)
postprocessor.push_to_hub(active_cfg.repo_id)
# Properly clean up the distributed process group
accelerator.wait_for_everyone()
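The `cfg.trainable_config` and `cfg.is_reward_model_training` accessors used above are not shown in this diff; judging from the tests further below, they behave like the following hypothetical sketch (the real definitions on `TrainPipelineConfig` may differ):

```python
from dataclasses import dataclass

@dataclass
class _TrainPipelineConfigSketch:
    policy: object | None = None
    reward_model: object | None = None

    @property
    def is_reward_model_training(self) -> bool:
        return self.reward_model is not None

    @property
    def trainable_config(self):
        # Whichever config is being trained drives .device, .pretrained_path,
        # .repo_id, .drop_n_last_frames and .push_to_hub in the trainer.
        return self.reward_model if self.is_reward_model_training else self.policy
```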
@@ -0,0 +1,55 @@
---
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
# Doc / guide: https://huggingface.co/docs/hub/model-cards
# prettier-ignore
{{card_data}}
---
# Reward Model Card for {{ model_name | default("Reward Model ID", true) }}
<!-- Provide a quick summary of what the reward model is/does. -->
{% if model_name == "reward_classifier" %}
A reward classifier is a lightweight neural network that scores observations or trajectories for task success, providing a learned reward signal for training or offline evaluation when explicit rewards are unavailable.
{% elif model_name == "sarm" %}
A Stage-Aware Reward Model (SARM) predicts a dense task-progress signal from observations, typically used downstream for reinforcement learning or human-in-the-loop fine-tuning when task success is not directly observable.
{% else %}
_Reward model type not recognized — please update this template._
{% endif %}
This reward model has been trained and pushed to the Hub using [LeRobot](https://github.com/huggingface/lerobot).
See the full documentation at [LeRobot Docs](https://huggingface.co/docs/lerobot/index).
---
## How to Get Started with the Reward Model
### Train from scratch
```bash
lerobot-train \
--dataset.repo_id=${HF_USER}/<dataset> \
--reward_model.type={{ model_name | default("reward_classifier", true) }} \
--output_dir=outputs/train/<desired_reward_model_repo_id> \
--job_name=lerobot_reward_training \
--reward_model.device=cuda \
--reward_model.repo_id=${HF_USER}/<desired_reward_model_repo_id> \
--wandb.enable=true
```
_Writes checkpoints to `outputs/train/<desired_reward_model_repo_id>/checkpoints/`._
### Load the reward model in Python
```python
from lerobot.rewards import make_reward_model
reward_model = make_reward_model(pretrained_path="<hf_user>/<reward_model_repo_id>")
reward = reward_model.compute_reward(batch)  # batch: preprocessed dict of observation tensors
```
---
## Model Details
- **License:** {{ license | default("\[More Information Needed]", true) }}
+239
@@ -0,0 +1,239 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Sample weighting abstraction for training.
This module provides an abstract base class for sample weighting strategies (e.g., RA-BC)
that can be used during training without polluting the training script with
policy-specific code.
Example usage:
# In training config
sample_weighting:
type: rabc
progress_path: hf://datasets/my-dataset/sarm_progress.parquet
head_mode: sparse
kappa: 0.01
# In training script
sample_weighter = make_sample_weighter(cfg.sample_weighting, policy, device, dataset_root=cfg.dataset.root, dataset_repo_id=cfg.dataset.repo_id)
...
weights, stats = sample_weighter.compute_batch_weights(batch)
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING
import torch
if TYPE_CHECKING:
from lerobot.policies.pretrained import PreTrainedPolicy
class SampleWeighter(ABC):
"""
Implementations compute per-sample weights that can be used to weight
the loss during training. This enables techniques like:
- RA-BC (Reward-Aligned Behavior Cloning)
- Importance sampling
- Curriculum learning
- Quality-based filtering
"""
@abstractmethod
def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
"""
Compute per-sample weights for a training batch.
Args:
batch: Training batch dictionary containing at minimum an "index" key
with global frame indices.
"""
@abstractmethod
def get_stats(self) -> dict:
"""
Get global statistics about the weighting strategy.
"""
@dataclass
class SampleWeightingConfig:
"""
Configuration for sample weighting during training.
This is a generic config that supports multiple weighting strategies.
The `type` field determines which implementation to use, and `extra_params`
contains additional type-specific parameters.
Attributes:
type: Weighting strategy type ("rabc", "uniform", etc.)
progress_path: Path to precomputed progress values (for RABC)
head_mode: Which model head to use for progress ("sparse" or "dense")
kappa: Hard threshold for high-quality samples (RABC-specific)
epsilon: Small constant for numerical stability
extra_params: Additional type-specific parameters passed to the weighter
"""
type: str = "rabc"
progress_path: str | None = None
head_mode: str = "sparse"
kappa: float = 0.01
epsilon: float = 1e-6
# Additional type-specific params can be added here or passed via extra_params
extra_params: dict = field(default_factory=dict)
def make_sample_weighter(
config: SampleWeightingConfig | None,
policy: PreTrainedPolicy,
device: torch.device,
dataset_root: str | None = None,
dataset_repo_id: str | None = None,
) -> SampleWeighter | None:
"""
Factory function to create a SampleWeighter from config.
This keeps policy-specific initialization logic out of the training script.
Args:
config: Sample weighting configuration, or None to disable weighting.
policy: The policy being trained (used to extract chunk_size, etc.)
device: Device to place weight tensors on.
dataset_root: Local path to dataset root (for auto-detecting progress_path).
dataset_repo_id: HuggingFace repo ID (for auto-detecting progress_path).
"""
if config is None:
return None
if config.type == "rabc":
return _make_rabc_weighter(config, policy, device, dataset_root, dataset_repo_id)
if config.type == "uniform":
# No-op weighter that returns uniform weights
return UniformWeighter(device=device)
raise ValueError(f"Unknown sample weighting type: '{config.type}'. Supported types: 'rabc', 'uniform'")
def _make_rabc_weighter(
config: SampleWeightingConfig,
policy: PreTrainedPolicy,
device: torch.device,
dataset_root: str | None = None,
dataset_repo_id: str | None = None,
) -> SampleWeighter:
"""Create RABC weighter with policy-specific initialization.
Args:
config: Sample weighting configuration.
policy: The policy being trained (used to extract chunk_size).
device: Device to place weight tensors on.
dataset_root: Local path to dataset root (for auto-detecting progress_path).
dataset_repo_id: HuggingFace repo ID (for auto-detecting progress_path).
"""
# Import here to avoid circular imports and keep RABC code in SARM module
from lerobot.rewards.sarm.rabc import RABCWeights
# Extract chunk_size from policy config
chunk_size = getattr(policy.config, "chunk_size", None)
if chunk_size is None:
raise ValueError(
"RABC sample weighting requires a policy with 'chunk_size' in its config. "
"This is typically set for action-chunking policies like ACT, Diffusion, PI0, etc."
)
# Determine progress_path: use explicit config or auto-detect from dataset
progress_path = config.progress_path
if progress_path is None:
if dataset_root:
progress_path = str(Path(dataset_root) / "sarm_progress.parquet")
elif dataset_repo_id:
progress_path = f"hf://datasets/{dataset_repo_id}/sarm_progress.parquet"
else:
raise ValueError(
"RABC sample weighting requires 'progress_path' to be set, "
"or dataset_root/dataset_repo_id for auto-detection. "
"Generate progress values using: "
"python -m lerobot.rewards.sarm.compute_rabc_weights --help"
)
return RABCWeights(
progress_path=progress_path,
chunk_size=chunk_size,
head_mode=config.head_mode,
kappa=config.kappa,
epsilon=config.epsilon,
device=device,
**config.extra_params,
)
class UniformWeighter(SampleWeighter):
"""
No-op sample weighter that returns uniform weights.
Useful as a baseline or when you want to disable weighting without
changing the training code structure.
Note:
Batch size is determined by looking for tensor values in the batch
dictionary. The method checks common keys like "action", "index",
and "observation.state" first, then falls back to scanning all values.
"""
def __init__(self, device: torch.device):
self.device = device
def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
"""Return uniform weights (all ones)."""
batch_size = self._determine_batch_size(batch)
weights = torch.ones(batch_size, device=self.device)
stats = {"mean_weight": 1.0, "type": "uniform"}
return weights, stats
def _determine_batch_size(self, batch: dict) -> int:
"""
Determine batch size from the batch dictionary.
Checks common keys first, then scans all values for tensors.
Args:
batch: Training batch dictionary.
"""
if not batch:
raise ValueError("Cannot determine batch size from empty batch")
# Check common keys first
for key in ["action", "index", "observation.state"]:
if key in batch and isinstance(batch[key], torch.Tensor):
return batch[key].shape[0]
# Scan all values for any tensor
for value in batch.values():
if isinstance(value, torch.Tensor) and value.ndim >= 1:
return value.shape[0]
# Last resort: return 1 (this handles non-tensor batches)
return 1
def get_stats(self) -> dict:
"""Return empty stats for uniform weighting."""
return {"type": "uniform"}
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,8 +19,6 @@ import pytest
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor
from lerobot.processor import (
DataProcessorPipeline,
DeviceProcessorStep,
@@ -31,6 +27,8 @@ from lerobot.processor import (
TransitionKey,
)
from lerobot.processor.converters import create_transition, transition_to_batch
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
from lerobot.rewards.classifier.processor_classifier import make_classifier_processor
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
@@ -42,7 +40,7 @@ def create_default_config():
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"reward": PolicyFeature(type=FeatureType.ACTION, shape=(1,)), # Classifier output
"reward": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
}
config.normalization_mapping = {
FeatureType.STATE: NormalizationMode.MEAN_STD,
@@ -90,17 +88,14 @@ def test_classifier_processor_normalization():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_classifier_processor(
config,
stats,
)
preprocessor, postprocessor = make_classifier_processor(config, stats)
# Create test data
observation = {
OBS_STATE: torch.randn(10),
OBS_IMAGE: torch.randn(3, 224, 224),
}
action = torch.randn(1) # Dummy action/reward
action = torch.randn(1)
transition = create_transition(observation, action)
batch = transition_to_batch(transition)
@@ -120,10 +115,7 @@ def test_classifier_processor_cuda():
config.device = "cuda"
stats = create_default_stats()
preprocessor, postprocessor = make_classifier_processor(
config,
stats,
)
preprocessor, postprocessor = make_classifier_processor(config, stats)
# Create CPU data
observation = {
@@ -132,7 +124,6 @@ def test_classifier_processor_cuda():
}
action = torch.randn(1)
transition = create_transition(observation, action)
batch = transition_to_batch(transition)
# Process through preprocessor
@@ -158,10 +149,7 @@ def test_classifier_processor_accelerate_scenario():
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_classifier_processor(
config,
stats,
)
preprocessor, postprocessor = make_classifier_processor(config, stats)
# Simulate Accelerate: data already on GPU
device = torch.device("cuda:0")
@@ -171,7 +159,6 @@ def test_classifier_processor_accelerate_scenario():
}
action = torch.randn(1).to(device)
transition = create_transition(observation, action)
batch = transition_to_batch(transition)
# Process through preprocessor
@@ -201,7 +188,6 @@ def test_classifier_processor_multi_gpu():
}
action = torch.randn(1).to(device)
transition = create_transition(observation, action)
batch = transition_to_batch(transition)
# Process through preprocessor
@@ -231,7 +217,6 @@ def test_classifier_processor_without_stats():
}
action = torch.randn(1)
transition = create_transition(observation, action)
batch = transition_to_batch(transition)
processed = preprocessor(batch)
@@ -294,7 +279,6 @@ def test_classifier_processor_mixed_precision():
}
action = torch.randn(1, dtype=torch.float32)
transition = create_transition(observation, action)
batch = transition_to_batch(transition)
# Process through preprocessor
@@ -312,10 +296,7 @@ def test_classifier_processor_batch_data():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_classifier_processor(
config,
stats,
)
preprocessor, postprocessor = make_classifier_processor(config, stats)
# Test with batched data
batch_size = 16
@@ -325,7 +306,6 @@ def test_classifier_processor_batch_data():
}
action = torch.randn(batch_size, 1)
transition = create_transition(observation, action)
batch = transition_to_batch(transition)
# Process through preprocessor
@@ -343,15 +323,11 @@ def test_classifier_processor_postprocessor_identity():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_classifier_processor(
config,
stats,
)
preprocessor, postprocessor = make_classifier_processor(config, stats)
# Create test data for postprocessor
reward = torch.tensor([[0.8], [0.3], [0.9]]) # Batch of rewards/predictions
reward = torch.tensor([[0.8], [0.3], [0.9]])
transition = create_transition(action=reward)
_ = transition_to_batch(transition)
# Process through postprocessor
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,8 +16,8 @@ import pytest
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
from lerobot.rewards.classifier.modeling_classifier import ClassifierOutput
from lerobot.utils.constants import OBS_IMAGE, REWARD
from tests.utils import skip_if_package_missing
@@ -42,7 +40,7 @@ def test_classifier_output():
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_binary_classifier_with_default_params():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.rewards.classifier.modeling_classifier import Classifier
config = RewardClassifierConfig()
config.input_features = {
@@ -86,7 +84,7 @@ def test_binary_classifier_with_default_params():
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_multiclass_classifier():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.rewards.classifier.modeling_classifier import Classifier
num_classes = 5
config = RewardClassifierConfig()
@@ -128,11 +126,15 @@ def test_multiclass_classifier():
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_default_device():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.rewards.classifier.modeling_classifier import Classifier
config = RewardClassifierConfig()
assert config.device == "cpu"
assert config.device is None or config.device == "cpu"
config.input_features = {
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.num_cameras = 1
classifier = Classifier(config)
for p in classifier.parameters():
assert p.device == torch.device("cpu")
@@ -143,11 +145,15 @@ def test_default_device():
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_explicit_device_setup():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.rewards.classifier.modeling_classifier import Classifier
config = RewardClassifierConfig(device="cpu")
assert config.device == "cpu"
config.input_features = {
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.num_cameras = 1
classifier = Classifier(config)
for p in classifier.parameters():
assert p.device == torch.device("cpu")
+373
@@ -0,0 +1,373 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the reward model base classes and registry."""
from dataclasses import dataclass
from pathlib import Path
from types import SimpleNamespace
import pytest
import torch
from lerobot.configs.rewards import RewardModelConfig
from lerobot.optim.optimizers import AdamWConfig
from lerobot.rewards.pretrained import PreTrainedRewardModel
@RewardModelConfig.register_subclass(name="_dummy_hub_reward")
@dataclass
class _DummyHubRewardConfig(RewardModelConfig):
def get_optimizer_preset(self):
return AdamWConfig(lr=1e-4)
class _DummyHubReward(PreTrainedRewardModel):
config_class = _DummyHubRewardConfig
name = "_dummy_hub_reward"
def __init__(self, config):
super().__init__(config)
self.bias = torch.nn.Parameter(torch.zeros(1))
def compute_reward(self, batch):
return self.bias.expand(1)
def test_reward_model_config_registry():
"""Verify that classifier and sarm are registered."""
known = RewardModelConfig.get_known_choices()
assert "reward_classifier" in known
assert "sarm" in known
def test_reward_model_config_lookup():
"""Verify that we can look up configs by name."""
cls = RewardModelConfig.get_choice_class("reward_classifier")
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
assert cls is RewardClassifierConfig
def test_factory_get_reward_model_class():
"""Test the get_reward_model_class factory."""
from lerobot.rewards.factory import get_reward_model_class
cls = get_reward_model_class("sarm")
from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel
assert cls is SARMRewardModel
def test_factory_unknown_raises():
"""Unknown name should raise ValueError."""
from lerobot.rewards.factory import get_reward_model_class
with pytest.raises(ValueError, match="not available"):
get_reward_model_class("nonexistent_reward_model")
def test_pretrained_reward_model_requires_config_class():
"""Subclass without config_class should fail."""
with pytest.raises(TypeError, match="must define 'config_class'"):
class BadModel(PreTrainedRewardModel):
name = "bad"
def compute_reward(self, batch):
pass
def test_pretrained_reward_model_requires_name():
"""Subclass without name should fail."""
with pytest.raises(TypeError, match="must define 'name'"):
class BadModel(PreTrainedRewardModel):
config_class = RewardModelConfig
def compute_reward(self, batch):
pass
def test_non_trainable_forward_raises():
"""Non-trainable model should raise on forward()."""
from dataclasses import dataclass
from lerobot.optim.optimizers import AdamWConfig
@dataclass
class DummyConfig(RewardModelConfig):
def get_optimizer_preset(self):
return AdamWConfig(lr=1e-4)
class DummyReward(PreTrainedRewardModel):
config_class = DummyConfig
name = "dummy_test"
def compute_reward(self, batch):
return torch.zeros(1)
config = DummyConfig()
model = DummyReward(config)
with pytest.raises(NotImplementedError, match="not trainable"):
model.forward({"x": torch.zeros(1)})
# ---------------------------------------------------------------------------
# Trainable vs zero-shot (general-purpose) reward models.
# The proposal explicitly supports models like TOPReward that wrap a pretrained
# VLM and produce a reward signal without any training step. These tests pin
# the contract that lets such models coexist with trainable ones.
# ---------------------------------------------------------------------------
def test_is_trainable_false_when_forward_not_overridden():
"""A reward model that only implements ``compute_reward`` is zero-shot."""
model, _ = _make_dummy_reward_model()
assert model.is_trainable is False
def test_is_trainable_true_when_forward_overridden():
"""Overriding ``forward`` flips ``is_trainable`` to True."""
class _TrainableReward(_DummyHubReward):
name = "_trainable_dummy_reward"
def forward(self, batch):
loss = (self.bias**2).sum()
return loss, {}
# Register a fresh config subclass so the subclass check passes.
@RewardModelConfig.register_subclass(name="_trainable_dummy_reward")
@dataclass
class _TrainableConfig(_DummyHubRewardConfig):
pass
_TrainableReward.config_class = _TrainableConfig
model = _TrainableReward(_TrainableConfig())
assert model.is_trainable is True
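These two tests pin the trainability contract down to something like the following hypothetical sketch (the actual implementation in `lerobot/rewards/pretrained.py` may differ in detail):

```python
import torch

class PreTrainedRewardModelSketch(torch.nn.Module):
    name = "sketch"

    @property
    def is_trainable(self) -> bool:
        # Trainable iff the subclass overrides forward() with a loss-returning version.
        return type(self).forward is not PreTrainedRewardModelSketch.forward

    def forward(self, batch):
        raise NotImplementedError(f"Reward model '{self.name}' is not trainable.")

class _ZeroShot(PreTrainedRewardModelSketch):
    def compute_reward(self, batch):
        return torch.zeros(1)

assert _ZeroShot().is_trainable is False
```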
# ---------------------------------------------------------------------------
# RewardModelConfig.from_pretrained
# ---------------------------------------------------------------------------
def test_reward_model_config_from_pretrained_raises_when_config_missing(tmp_path):
"""``from_pretrained`` must surface a clear ``FileNotFoundError`` when the
target directory exists but does not contain ``config.json``, instead of
crashing later inside ``draccus.parse``.
"""
# tmp_path exists but has no config.json
with pytest.raises(FileNotFoundError, match="config.json not found"):
RewardModelConfig.from_pretrained(tmp_path)
def test_reward_model_config_from_pretrained_roundtrip(tmp_path):
"""Round-trip: save a RewardClassifierConfig, reload it, fields must match."""
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
original = RewardClassifierConfig(
num_classes=3,
hidden_dim=128,
latent_dim=64,
num_cameras=1,
learning_rate=5e-4,
)
original._save_pretrained(tmp_path)
loaded = RewardModelConfig.from_pretrained(tmp_path)
assert isinstance(loaded, RewardClassifierConfig)
assert loaded.num_classes == 3
assert loaded.hidden_dim == 128
assert loaded.latent_dim == 64
assert loaded.num_cameras == 1
assert loaded.learning_rate == 5e-4
# ---------------------------------------------------------------------------
# TrainPipelineConfig — reward model training path
# ---------------------------------------------------------------------------
def test_train_pipeline_config_path_fields_includes_reward_model():
"""``--reward_model.path=local/dir`` requires ``reward_model`` to be listed
as a draccus path-field on ``TrainPipelineConfig``."""
from lerobot.configs.train import TrainPipelineConfig
fields = TrainPipelineConfig.__get_path_fields__()
assert "policy" in fields
assert "reward_model" in fields
def test_train_pipeline_config_trainable_config_returns_reward_model_when_set():
"""When only ``reward_model`` is set, ``trainable_config`` (used by the
trainer for e.g. ``.device``) must return it rather than the unset ``policy``."""
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
reward_cfg = RewardClassifierConfig(device="cpu")
cfg = TrainPipelineConfig(
dataset=DatasetConfig(repo_id="user/repo"),
reward_model=reward_cfg,
)
assert cfg.is_reward_model_training is True
assert cfg.trainable_config is reward_cfg
# This is what lerobot_train.py uses to decide force_cpu; ``cfg.policy.device``
# would raise an AttributeError here because policy is None.
assert cfg.trainable_config.device == "cpu"
def test_train_pipeline_config_trainable_config_returns_policy_when_set():
"""Mirror of the reward-model case: when only ``policy`` is set,
``trainable_config`` must return it."""
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
policy_cfg = DiffusionConfig(device="cpu")
cfg = TrainPipelineConfig(
dataset=DatasetConfig(repo_id="user/repo"),
policy=policy_cfg,
)
assert cfg.is_reward_model_training is False
assert cfg.trainable_config is policy_cfg
assert cfg.trainable_config.device == "cpu"
# ---------------------------------------------------------------------------
# PreTrainedRewardModel hub upload: push_model_to_hub + generate_model_card.
# We test the generation side (offline) fully, and the upload side with HfApi
# mocked so nothing actually hits the network.
# ---------------------------------------------------------------------------
def _make_dummy_reward_model(**config_kwargs):
return _DummyHubReward(_DummyHubRewardConfig(**config_kwargs)), _DummyHubRewardConfig
@pytest.fixture
def _offline_model_card(monkeypatch):
"""``ModelCard.validate`` does a live ``POST`` to huggingface.co — bypass it
so tests can run offline."""
from huggingface_hub import ModelCard
monkeypatch.setattr(ModelCard, "validate", lambda self, *a, **kw: None)
def test_reward_model_generate_model_card_renders_expected_fields(_offline_model_card):
"""``generate_model_card`` must produce a card with the right metadata and
body, using the dedicated reward-model template."""
model, _ = _make_dummy_reward_model(
license="mit",
tags=["robot", "sim"],
)
card = model.generate_model_card(
dataset_repo_id="user/my_dataset",
model_type=model.config.type,
license=model.config.license,
tags=model.config.tags,
)
# Metadata (YAML header) — ModelCardData fields.
assert card.data.license == "mit"
assert card.data.library_name == "lerobot"
assert card.data.pipeline_tag == "robotics"
assert "reward-model" in card.data.tags
assert model.config.type in card.data.tags
assert card.data.model_name == model.config.type
assert card.data.datasets == "user/my_dataset"
# Body — specific to the reward-model template, NOT the policy one.
body = str(card)
assert "Reward Model Card" in body
assert "This reward model has been trained" in body
assert "--reward_model.type=" in body # reward-model-specific usage block
def test_reward_model_generate_model_card_uses_default_license(_offline_model_card):
"""When config.license is None the card falls back to apache-2.0."""
model, _ = _make_dummy_reward_model()
card = model.generate_model_card(
dataset_repo_id="user/my_dataset",
model_type=model.config.type,
license=model.config.license,
tags=None,
)
assert card.data.license == "apache-2.0"
def test_reward_model_push_model_to_hub_uploads_expected_files(monkeypatch, _offline_model_card):
"""``push_model_to_hub`` must:
1. create the repo,
2. assemble a temp folder with weights + config.json + train_config.json + README.md,
3. call ``api.upload_folder`` on that folder.
All network calls are mocked.
"""
from huggingface_hub.constants import CONFIG_NAME
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TRAIN_CONFIG_NAME, TrainPipelineConfig
model, _ = _make_dummy_reward_model(
repo_id="user/my_reward",
license="apache-2.0",
)
# Point the reward model's train config at a dummy dataset repo.
train_cfg = TrainPipelineConfig(
dataset=DatasetConfig(repo_id="user/my_dataset"),
reward_model=model.config,
)
uploaded: dict = {}
fake_commit_info = SimpleNamespace(repo_url=SimpleNamespace(url="https://huggingface.co/user/my_reward"))
class _FakeHfApi:
def create_repo(self, repo_id, private=None, exist_ok=False):
uploaded["create_repo_id"] = repo_id
uploaded["create_private"] = private
return SimpleNamespace(repo_id=repo_id)
def upload_folder(self, *, repo_id, repo_type, folder_path, commit_message, **_kwargs):
uploaded["upload_repo_id"] = repo_id
uploaded["upload_repo_type"] = repo_type
uploaded["commit_message"] = commit_message
# Snapshot files assembled in the temp folder — this is the real
# contract we care about.
uploaded["files"] = sorted(p.name for p in Path(folder_path).iterdir())
return fake_commit_info
from lerobot.rewards import pretrained as reward_pretrained
monkeypatch.setattr(reward_pretrained, "HfApi", lambda *a, **kw: _FakeHfApi())
model.push_model_to_hub(train_cfg)
assert uploaded["create_repo_id"] == "user/my_reward"
assert uploaded["upload_repo_id"] == "user/my_reward"
assert uploaded["upload_repo_type"] == "model"
assert uploaded["commit_message"] == "Upload reward model weights, train config and readme"
# Minimum required files that must be uploaded with a reward model.
assert CONFIG_NAME in uploaded["files"] # config.json
assert TRAIN_CONFIG_NAME in uploaded["files"] # train_config.json
assert "README.md" in uploaded["files"]
assert any(name.endswith(".safetensors") for name in uploaded["files"])
@@ -104,8 +104,8 @@ class TestSARMEncodingProcessorStepEndToEnd:
def mock_clip_model(self):
"""Mock CLIP model to avoid loading real weights."""
with (
patch("lerobot.policies.sarm.processor_sarm.CLIPModel") as mock_model_cls,
patch("lerobot.policies.sarm.processor_sarm.CLIPProcessor") as mock_processor_cls,
patch("lerobot.rewards.sarm.processor_sarm.CLIPModel") as mock_model_cls,
patch("lerobot.rewards.sarm.processor_sarm.CLIPProcessor") as mock_processor_cls,
):
# Mock the CLIP model - return embeddings based on input batch size
mock_model = MagicMock()
@@ -142,7 +142,7 @@ class TestSARMEncodingProcessorStepEndToEnd:
@pytest.fixture
def processor_with_mocks(self, mock_clip_model):
"""Create a processor with mocked CLIP and dataset metadata for dual mode."""
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep
# Dual mode config with both sparse and dense annotations
config = MockConfig(
@@ -256,7 +256,7 @@ class TestSARMEncodingProcessorStepEndToEnd:
def test_call_with_batched_input(self, mock_clip_model):
"""Test processor __call__ with a batched input (multiple frames) in dual mode."""
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep
config = MockConfig(
n_obs_steps=8,
@@ -332,7 +332,7 @@ class TestSARMEncodingProcessorStepEndToEnd:
def test_targets_increase_with_progress(self, mock_clip_model):
"""Test that both sparse and dense targets increase as frame index progresses."""
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep
config = MockConfig(
n_obs_steps=8,
@@ -404,7 +404,7 @@ class TestSARMEncodingProcessorStepEndToEnd:
def test_progress_labels_exact_values(self, mock_clip_model):
"""Test that progress labels (stage.tau) are computed correctly for known positions."""
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep
# Simple setup: 2 sparse stages, 4 dense stages, 100 frame episode
config = MockConfig(
@@ -495,7 +495,7 @@ class TestSARMEncodingProcessorStepEndToEnd:
"""Test that rewind augmentation correctly extends sequence and generates targets."""
import random
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep
config = MockConfig(
n_obs_steps=8,
@@ -587,8 +587,8 @@ class TestSARMEncodingProcessorStepEndToEnd:
def test_full_sequence_target_consistency(self, mock_clip_model):
"""Test that the full sequence of targets is consistent with frame positions."""
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
from lerobot.policies.sarm.sarm_utils import find_stage_and_tau
from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep
from lerobot.rewards.sarm.sarm_utils import find_stage_and_tau
config = MockConfig(
n_obs_steps=8,
@@ -18,7 +18,7 @@ import numpy as np
import pytest
import torch
from lerobot.policies.sarm.sarm_utils import (
from lerobot.rewards.sarm.sarm_utils import (
apply_rewind_augmentation,
compute_absolute_indices,
compute_tau,
+401
@@ -0,0 +1,401 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the sample weighting infrastructure."""
from unittest.mock import Mock
import pytest
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
import torch
from lerobot.utils.sample_weighting import (
SampleWeighter,
SampleWeightingConfig,
UniformWeighter,
make_sample_weighter,
)
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture
def sample_progress_parquet(tmp_path):
"""Create a sample progress parquet file for testing."""
import pandas as pd
# Create sample progress data for 2 episodes with 10 frames each
data = {
"index": list(range(20)),
"episode_index": [0] * 10 + [1] * 10,
"frame_index": list(range(10)) * 2,
"progress_sparse": [i / 10.0 for i in range(10)] * 2,
}
df = pd.DataFrame(data)
parquet_path = tmp_path / "sarm_progress.parquet"
df.to_parquet(parquet_path)
return parquet_path
# =============================================================================
# SampleWeightingConfig Tests
# =============================================================================
def test_config_default_values():
"""Test default configuration values."""
config = SampleWeightingConfig()
assert config.type == "rabc"
assert config.progress_path is None
assert config.head_mode == "sparse"
assert config.kappa == 0.01
assert config.epsilon == 1e-6
assert config.extra_params == {}
def test_config_custom_values():
"""Test configuration with custom values."""
config = SampleWeightingConfig(
type="rabc",
progress_path="/path/to/progress.parquet",
head_mode="dense",
kappa=0.05,
epsilon=1e-8,
extra_params={"fallback_weight": 0.5},
)
assert config.type == "rabc"
assert config.progress_path == "/path/to/progress.parquet"
assert config.head_mode == "dense"
assert config.kappa == 0.05
assert config.epsilon == 1e-8
assert config.extra_params == {"fallback_weight": 0.5}
def test_config_uniform_type():
"""Test configuration for uniform weighting."""
config = SampleWeightingConfig(type="uniform")
assert config.type == "uniform"


# =============================================================================
# UniformWeighter Tests
# =============================================================================


def test_uniform_weighter_inherits_from_sample_weighter():
    """Test that UniformWeighter is a SampleWeighter."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    assert isinstance(weighter, SampleWeighter)


def test_uniform_weighter_compute_batch_weights_with_action_key():
    """Test weight computation with 'action' key in batch."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    batch = {"action": torch.randn(8, 10)}
    weights, stats = weighter.compute_batch_weights(batch)
    assert weights.shape == (8,)
    assert torch.allclose(weights, torch.ones(8))
    assert stats["mean_weight"] == 1.0
    assert stats["type"] == "uniform"


def test_uniform_weighter_compute_batch_weights_with_index_key():
    """Test weight computation with 'index' key in batch."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    batch = {"index": torch.arange(16)}
    weights, _ = weighter.compute_batch_weights(batch)
    assert weights.shape == (16,)
    assert torch.allclose(weights, torch.ones(16))


def test_uniform_weighter_compute_batch_weights_no_tensor_keys():
    """Test weight computation with no tensor keys (fallback to size 1)."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    batch = {"other_key": "some_value"}
    weights, _ = weighter.compute_batch_weights(batch)
    assert weights.shape == (1,)
    assert torch.allclose(weights, torch.ones(1))


def test_uniform_weighter_compute_batch_weights_empty_batch_raises():
    """Test that empty batch raises ValueError."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    batch = {}
    with pytest.raises(ValueError, match="empty batch"):
        weighter.compute_batch_weights(batch)


def test_uniform_weighter_compute_batch_weights_scans_all_keys():
    """Test that batch size is determined by scanning all tensor values."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    # Batch with non-standard key containing a tensor
    batch = {"custom_tensor": torch.randn(7, 3)}
    weights, _ = weighter.compute_batch_weights(batch)
    assert weights.shape == (7,)
    assert torch.allclose(weights, torch.ones(7))


def test_uniform_weighter_compute_batch_weights_on_cuda():
    """Test that weights are placed on the correct device."""
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    weighter = UniformWeighter(device=torch.device("cuda"))
    batch = {"action": torch.randn(4, 10)}
    weights, _ = weighter.compute_batch_weights(batch)
    assert weights.device.type == "cuda"


def test_uniform_weighter_get_stats():
    """Test get_stats returns expected structure."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    stats = weighter.get_stats()
    assert stats == {"type": "uniform"}
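
# Illustrative sketch (not a test): how a SampleWeighter is meant to be
# consumed in a training step. per_sample_loss is a hypothetical stand-in for
# an unreduced loss; only compute_batch_weights is the real API exercised above.
#
#     weighter = UniformWeighter(device=torch.device("cpu"))
#     weights, stats = weighter.compute_batch_weights(batch)
#     loss = (weights * per_sample_loss).mean()  # uniform weights leave the loss unchanged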


# =============================================================================
# make_sample_weighter Factory Tests
# =============================================================================


def test_factory_returns_none_for_none_config():
    """Test that None config returns None weighter."""
    policy = Mock()
    device = torch.device("cpu")
    result = make_sample_weighter(None, policy, device)
    assert result is None


def test_factory_creates_uniform_weighter():
    """Test creation of UniformWeighter."""
    config = SampleWeightingConfig(type="uniform")
    policy = Mock()
    device = torch.device("cpu")
    weighter = make_sample_weighter(config, policy, device)
    assert isinstance(weighter, UniformWeighter)
    assert isinstance(weighter, SampleWeighter)


def test_factory_raises_for_unknown_type():
    """Test that unknown type raises ValueError."""
    config = SampleWeightingConfig(type="unknown_type")
    policy = Mock()
    device = torch.device("cpu")
    with pytest.raises(ValueError, match="Unknown sample weighting type"):
        make_sample_weighter(config, policy, device)


def test_factory_rabc_requires_chunk_size():
    """Test that RABC weighter requires chunk_size in policy config."""
    config = SampleWeightingConfig(
        type="rabc",
        progress_path="/path/to/progress.parquet",
    )
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = None  # No chunk_size
    device = torch.device("cpu")
    with pytest.raises(ValueError, match="chunk_size"):
        make_sample_weighter(config, policy, device)


def test_factory_rabc_requires_progress_path_or_dataset_info():
    """Test that RABC weighter requires progress_path or dataset info for auto-detection."""
    config = SampleWeightingConfig(
        type="rabc",
        progress_path=None,  # No progress path
    )
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = 50
    device = torch.device("cpu")
    # Should fail when no progress_path AND no dataset info
    with pytest.raises(ValueError, match="progress_path"):
        make_sample_weighter(config, policy, device)


def test_factory_rabc_auto_detects_from_dataset_root(sample_progress_parquet):
    """Test that RABC weighter auto-detects progress_path from dataset_root."""
    config = SampleWeightingConfig(
        type="rabc",
        progress_path=None,  # Not provided, should auto-detect
    )
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = 5
    device = torch.device("cpu")
    # The parquet file is at sample_progress_parquet; pass its parent directory
    dataset_root = sample_progress_parquet.parent
    weighter = make_sample_weighter(
        config,
        policy,
        device,
        dataset_root=str(dataset_root),
    )
    assert weighter is not None
    from lerobot.rewards.sarm.rabc import RABCWeights

    assert isinstance(weighter, RABCWeights)


def test_factory_rabc_auto_detects_from_repo_id():
    """Test that RABC weighter constructs HF path from repo_id."""
    config = SampleWeightingConfig(
        type="rabc",
        progress_path=None,  # Not provided, should auto-detect
    )
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = 50
    device = torch.device("cpu")
    # This will construct the path but fail when trying to load it (the file
    # doesn't exist). We just verify it doesn't raise the "progress_path
    # required" error.
    with pytest.raises(Exception) as exc_info:
        make_sample_weighter(
            config,
            policy,
            device,
            dataset_repo_id="test-user/test-dataset",
        )
    # Should NOT be the "progress_path required" error: it should try to load the file
    assert (
        "progress_path" not in str(exc_info.value).lower()
        or "auto-detection" in str(exc_info.value).lower()
    )
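
# NOTE (assumption based on the two auto-detection tests above): when
# progress_path is None, the factory presumably resolves a
# "sarm_progress.parquet" next to the dataset files, under dataset_root for a
# local dataset or at a Hub path derived from dataset_repo_id otherwise; an
# explicitly provided progress_path takes priority.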


# =============================================================================
# Integration Tests with RABCWeights
# =============================================================================


def test_rabc_weights_is_sample_weighter(sample_progress_parquet):
    """Test that RABCWeights inherits from SampleWeighter."""
    from lerobot.rewards.sarm.rabc import RABCWeights

    weighter = RABCWeights(
        progress_path=sample_progress_parquet,
        chunk_size=5,
        head_mode="sparse",
    )
    assert isinstance(weighter, SampleWeighter)


def test_rabc_compute_batch_weights(sample_progress_parquet):
    """Test RABCWeights.compute_batch_weights returns correct structure."""
    from lerobot.rewards.sarm.rabc import RABCWeights

    weighter = RABCWeights(
        progress_path=sample_progress_parquet,
        chunk_size=5,
        head_mode="sparse",
        device=torch.device("cpu"),
    )
    batch = {"index": torch.tensor([0, 1, 2, 3])}
    weights, stats = weighter.compute_batch_weights(batch)
    assert isinstance(weights, torch.Tensor)
    assert weights.shape == (4,)
    assert isinstance(stats, dict)
    assert "mean_weight" in stats


def test_rabc_get_stats(sample_progress_parquet):
    """Test RABCWeights.get_stats returns expected structure."""
    from lerobot.rewards.sarm.rabc import RABCWeights

    weighter = RABCWeights(
        progress_path=sample_progress_parquet,
        chunk_size=5,
        head_mode="sparse",
    )
    stats = weighter.get_stats()
    assert stats["type"] == "rabc"
    assert "num_frames" in stats
    assert "chunk_size" in stats
    assert stats["chunk_size"] == 5
    assert "head_mode" in stats
    assert stats["head_mode"] == "sparse"
    assert "delta_mean" in stats
    assert "delta_std" in stats


def test_factory_creates_rabc_weighter(sample_progress_parquet):
    """Test factory creates RABCWeights with valid config."""
    from lerobot.rewards.sarm.rabc import RABCWeights

    config = SampleWeightingConfig(
        type="rabc",
        progress_path=str(sample_progress_parquet),
        head_mode="sparse",
        kappa=0.01,
    )
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = 5
    device = torch.device("cpu")
    weighter = make_sample_weighter(config, policy, device)
    assert isinstance(weighter, RABCWeights)
    assert isinstance(weighter, SampleWeighter)


def test_rabc_weights_normalization(sample_progress_parquet):
    """Test that RABCWeights normalizes weights to sum to batch_size."""
    from lerobot.rewards.sarm.rabc import RABCWeights

    weighter = RABCWeights(
        progress_path=sample_progress_parquet,
        chunk_size=5,
        head_mode="sparse",
        device=torch.device("cpu"),
    )
    batch = {"index": torch.tensor([0, 1, 2, 3])}
    weights, _ = weighter.compute_batch_weights(batch)
    # Weights should be normalized to sum approximately to batch_size
    batch_size = 4
    assert abs(weights.sum().item() - batch_size) < 0.1
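
# Normalization sketch (inferred from the test above, not the verbatim
# implementation): rescaling raw weights so their batch mean is 1 preserves
# relative sample importance while keeping the loss scale stable:
#
#     raw = torch.tensor([0.2, 0.4, 0.6, 0.8])
#     normalized = raw * raw.numel() / (raw.sum() + 1e-6)
#     # normalized ≈ [0.4, 0.8, 1.2, 1.6]; normalized.sum() ≈ 4.0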