diff --git a/docs/source/sarm.mdx b/docs/source/sarm.mdx index cd488fe1f..b5e63a07e 100644 --- a/docs/source/sarm.mdx +++ b/docs/source/sarm.mdx @@ -46,7 +46,7 @@ This ensures identical task states map to consistent progress values, even acros ## Inputs and Targets (What the new code expects) -SARM is trained through its processor (`src/lerobot/policies/sarm/processor_sarm.py`), which: +SARM is trained through its processor (`src/lerobot/rewards/sarm/processor_sarm.py`), which: - **Encodes** images and task text with CLIP (ViT-B/32) into `video_features` and `text_features` - **Pads/truncates** robot state into `state_features` (up to `max_state_dim`) @@ -347,7 +347,7 @@ Use `compute_rabc_weights.py` with `--visualize-only` to visualize model predict ```bash -python src/lerobot/policies/sarm/compute_rabc_weights.py \ +python -m lerobot.rewards.sarm.compute_rabc_weights \ --dataset-repo-id your-username/your-dataset \ --reward-model-path your-username/sarm-model \ --visualize-only \ @@ -360,7 +360,7 @@ python src/lerobot/policies/sarm/compute_rabc_weights.py \ ```bash -python src/lerobot/policies/sarm/compute_rabc_weights.py \ +python -m lerobot.rewards.sarm.compute_rabc_weights \ --dataset-repo-id your-username/your-dataset \ --reward-model-path your-username/sarm-model \ --visualize-only \ @@ -373,7 +373,7 @@ python src/lerobot/policies/sarm/compute_rabc_weights.py \ ```bash -python src/lerobot/policies/sarm/compute_rabc_weights.py \ +python -m lerobot.rewards.sarm.compute_rabc_weights \ --dataset-repo-id your-username/your-dataset \ --reward-model-path your-username/sarm-model \ --visualize-only \ @@ -429,7 +429,7 @@ The weighting follows **Equations 8-9** from the paper: First, run the SARM model on all frames in your dataset to compute progress values: ```bash -python src/lerobot/policies/sarm/compute_rabc_weights.py \ +python -m lerobot.rewards.sarm.compute_rabc_weights \ --dataset-repo-id your-username/your-dataset \ --reward-model-path your-username/sarm-model \ --head-mode sparse \ @@ -465,15 +465,15 @@ This script: ### Step 5b: Train Policy with RA-BC -Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC: +Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`) if not explicitly provided. 
Currently PI0, PI0.5 and SmolVLA are supported with RA-BC: ```bash lerobot-train \ --dataset.repo_id=your-username/your-dataset \ --policy.type=pi0 \ - --use_rabc=true \ - --rabc_head_mode=sparse \ - --rabc_kappa=0.01 \ + --sample_weighting.type=rabc \ + --sample_weighting.head_mode=sparse \ + --sample_weighting.kappa=0.01 \ --output_dir=outputs/train/policy_rabc \ --batch_size=32 \ --steps=40000 @@ -488,12 +488,13 @@ The training script automatically: **RA-BC Arguments:** -| Argument | Description | Default | -| ---------------------- | ---------------------------------------------------------- | ---------------------------------- | -| `--use_rabc` | Enable RA-BC sample weighting | `false` | -| `--rabc_progress_path` | Path to progress parquet file (auto-detected from dataset) | `sarm_progress.parquet` in dataset | -| `--rabc_head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` | -| `--rabc_kappa` | Threshold κ for high-quality samples | `0.01` | +| Argument | Description | Default | +| ---------------------------------- | ------------------------------------------------------ | ----------------------- | +| `--sample_weighting.type` | Weighting strategy type (`rabc` or `uniform`) | `rabc` | +| `--sample_weighting.progress_path` | Path to progress parquet file | `sarm_progress.parquet` | +| `--sample_weighting.head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` | +| `--sample_weighting.kappa` | Threshold κ for high-quality samples | `0.01` | +| `--sample_weighting.epsilon` | Small constant for numerical stability | `1e-6` | ### Tuning RA-BC Kappa @@ -511,30 +512,30 @@ The `kappa` parameter is the threshold that determines which samples get full we Monitor these WandB metrics during training: -| Metric | Healthy Range | Problem Indicator | -| ------------------ | ------------- | ------------------------- | -| `rabc_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low | -| `rabc_delta_mean` | > 0 | Should be positive | -| `rabc_delta_std` | > 0 | Variance in data quality | +| Metric | Healthy Range | Problem Indicator | +| ----------------------------- | ------------- | ------------------------- | +| `sample_weight_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low | +| `sample_weighting/delta_mean` | > 0 | Should be positive | +| `sample_weighting/delta_std` | > 0 | Variance in data quality | -**If `rabc_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC. +**If `sample_weight_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC. **Setting kappa based on your data:** -The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `rabc_delta_mean` and `rabc_delta_std`: +The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). 
For your dataset, check the logged `sample_weighting/delta_mean` and `sample_weighting/delta_std`: ``` # If delta_mean ≈ 0.03 and delta_std ≈ 0.02: # Most deltas fall in range [0.01, 0.05] # Option 1: Set kappa = delta_mean (medium selectivity) ---rabc_kappa=0.03 +--sample_weighting.kappa=0.03 # Option 2: Set kappa = delta_mean + delta_std (high selectivity) ---rabc_kappa=0.05 +--sample_weighting.kappa=0.05 # Option 3: Set kappa = delta_mean + 2*delta_std (very selective) ---rabc_kappa=0.07 +--sample_weighting.kappa=0.07 ``` **When RA-BC may not help:** @@ -550,8 +551,8 @@ accelerate launch \ src/lerobot/scripts/lerobot_train.py \ --dataset.repo_id=your-username/your-dataset \ --policy.type=pi0 \ - --use_rabc=true \ - --rabc_kappa=0.01 \ + --sample_weighting.type=rabc \ + --sample_weighting.kappa=0.01 \ --output_dir=outputs/train/policy_rabc \ --batch_size=32 \ --steps=40000 @@ -576,7 +577,7 @@ accelerate launch \ ### RA-BC 1. **Train SARM first**: RA-BC quality depends entirely on SARM quality -2. **Monitor `rabc_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa)) +2. **Monitor `sample_weight_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa)) --- diff --git a/examples/dataset/slurm_compute_rabc.py b/examples/dataset/slurm_compute_rabc.py index c71d8a0fc..6d7f75f6f 100644 --- a/examples/dataset/slurm_compute_rabc.py +++ b/examples/dataset/slurm_compute_rabc.py @@ -69,7 +69,7 @@ class ComputeProgressShards(PipelineStep): import torch from tqdm import tqdm - from lerobot.policies.sarm.compute_rabc_weights import ( + from lerobot.rewards.sarm.compute_rabc_weights import ( generate_all_frame_indices, interpolate_progress, load_sarm_resources, diff --git a/examples/tutorial/rl/hilserl_example.py b/examples/tutorial/rl/hilserl_example.py index 8a08d6d56..71b50e97c 100644 --- a/examples/tutorial/rl/hilserl_example.py +++ b/examples/tutorial/rl/hilserl_example.py @@ -10,7 +10,7 @@ from lerobot.datasets import LeRobotDataset from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig from lerobot.policies import SACConfig from lerobot.policies.sac.modeling_sac import SACPolicy -from lerobot.policies.sac.reward_model.modeling_classifier import Classifier +from lerobot.rewards.classifier.modeling_classifier import Classifier from lerobot.rl.buffer import ReplayBuffer from lerobot.rl.gym_manipulator import make_robot_env from lerobot.robots.so_follower import SO100FollowerConfig diff --git a/examples/tutorial/rl/reward_classifier_example.py b/examples/tutorial/rl/reward_classifier_example.py index b386bf4db..ddecfbcfc 100644 --- a/examples/tutorial/rl/reward_classifier_example.py +++ b/examples/tutorial/rl/reward_classifier_example.py @@ -1,7 +1,7 @@ import torch from lerobot.datasets import LeRobotDataset -from lerobot.policies import RewardClassifierConfig, make_policy, make_pre_post_processors +from lerobot.rewards import RewardClassifierConfig, make_reward_model, make_reward_pre_post_processors def main(): @@ -22,10 +22,10 @@ def main(): model_name="microsoft/resnet-18", ) - # Make policy, preprocessor, and optimizer - policy = make_policy(config, ds_meta=dataset.meta) - optimizer = config.get_optimizer_preset().build(policy.parameters()) - preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats) + # Make reward model, preprocessor, and optimizer + reward_model = make_reward_model(config, dataset_stats=dataset.meta.stats) + optimizer = 
config.get_optimizer_preset().build(reward_model.parameters()) + preprocessor, _ = make_reward_pre_post_processors(config, dataset_stats=dataset.meta.stats) classifier_id = "/reward_classifier_hil_serl_example" @@ -42,7 +42,7 @@ def main(): batch = preprocessor(batch) # Forward pass - loss, output_dict = policy.forward(batch) + loss, output_dict = reward_model.forward(batch) # Backward pass and optimization optimizer.zero_grad() @@ -58,8 +58,8 @@ def main(): print("Training finished!") - # You can now save the trained policy. - policy.push_to_hub(classifier_id) + # You can now save the trained reward model. + reward_model.push_to_hub(classifier_id) if __name__ == "__main__": diff --git a/src/lerobot/common/wandb_utils.py b/src/lerobot/common/wandb_utils.py index e3190b6ce..b782cd751 100644 --- a/src/lerobot/common/wandb_utils.py +++ b/src/lerobot/common/wandb_utils.py @@ -41,8 +41,12 @@ def cfg_to_group( return tag return tag[:max_tag_length] + if cfg.is_reward_model_training: + trainable_tag = f"reward_model:{cfg.reward_model.type}" + else: + trainable_tag = f"policy:{cfg.policy.type}" lst = [ - f"policy:{cfg.policy.type}", + trainable_tag, f"seed:{cfg.seed}", ] if cfg.dataset is not None: diff --git a/src/lerobot/configs/rewards.py b/src/lerobot/configs/rewards.py new file mode 100644 index 000000000..d495160bf --- /dev/null +++ b/src/lerobot/configs/rewards.py @@ -0,0 +1,163 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import builtins +import json +import logging +import os +import tempfile +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, TypeVar + +import draccus +from huggingface_hub import hf_hub_download +from huggingface_hub.constants import CONFIG_NAME +from huggingface_hub.errors import HfHubHTTPError + +from lerobot.configs.types import PolicyFeature +from lerobot.optim.optimizers import OptimizerConfig +from lerobot.optim.schedulers import LRSchedulerConfig +from lerobot.utils.device_utils import auto_select_torch_device, is_torch_device_available +from lerobot.utils.hub import HubMixin + +T = TypeVar("T", bound="RewardModelConfig") +logger = logging.getLogger(__name__) + + +@dataclass +class RewardModelConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): + """Base configuration for reward models. + + Args: + input_features: A dictionary defining the PolicyFeature of the input data for the reward. The key represents + the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + output_features: A dictionary defining the PolicyFeature of the output data for the reward. The key represents + the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. 
+ """ + + # Reuses PolicyFeature + input_features: dict[str, PolicyFeature] = field(default_factory=dict) + output_features: dict[str, PolicyFeature] = field(default_factory=dict) + + device: str | None = None + + pretrained_path: str | None = None + + push_to_hub: bool = False + repo_id: str | None = None + + # Hub metadata + license: str | None = None + tags: list[str] | None = None + private: bool | None = None + + def __post_init__(self) -> None: + if not self.device or not is_torch_device_available(self.device): + auto_device = auto_select_torch_device() + logger.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.") + self.device = auto_device.type + + @property + def type(self) -> str: + choice_name = self.get_choice_name(self.__class__) + if not isinstance(choice_name, str): + raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}") + return choice_name + + @property + def observation_delta_indices(self) -> list | None: # type: ignore[type-arg] + return None + + @property + def action_delta_indices(self) -> list | None: # type: ignore[type-arg] + return None + + @property + def reward_delta_indices(self) -> list | None: # type: ignore[type-arg] + return None + + @abc.abstractmethod + def get_optimizer_preset(self) -> OptimizerConfig: + raise NotImplementedError + + def get_scheduler_preset(self) -> LRSchedulerConfig | None: + return None + + def validate_features(self) -> None: + pass + + def _save_pretrained(self, save_directory: Path) -> None: + with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"): + draccus.dump(self, f, indent=4) + + @classmethod + def from_pretrained( + cls: builtins.type[T], + pretrained_name_or_path: str | Path, + *, + force_download: bool = False, + resume_download: bool | None = None, + proxies: dict[Any, Any] | None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + **reward_kwargs: Any, + ) -> T: + model_id = str(pretrained_name_or_path) + config_file: str | None = None + if Path(model_id).is_dir(): + if CONFIG_NAME in os.listdir(model_id): + config_file = os.path.join(model_id, CONFIG_NAME) + else: + logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}") + else: + try: + config_file = hf_hub_download( + repo_id=model_id, + filename=CONFIG_NAME, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + except HfHubHTTPError as e: + raise FileNotFoundError( + f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}" + ) from e + + if config_file is None: + raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}") + + # HACK: Parse the original config to get the config subclass, so that we can + # apply cli overrides. 
+ with draccus.config_type("json"): + orig_config = draccus.parse(cls, config_file, args=[]) + + with open(config_file) as f: + config = json.load(f) + + config.pop("type", None) + with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f: + json.dump(config, f) + config_file = f.name + + cli_overrides = reward_kwargs.pop("cli_overrides", []) + with draccus.config_type("json"): + return draccus.parse(orig_config.__class__, config_file, args=cli_overrides) diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index 924bcf5bb..3f78cc07b 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -26,9 +26,11 @@ from lerobot import envs from lerobot.configs import parser from lerobot.optim import LRSchedulerConfig, OptimizerConfig from lerobot.utils.hub import HubMixin +from lerobot.utils.sample_weighting import SampleWeightingConfig from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig from .policies import PreTrainedConfig +from .rewards import RewardModelConfig TRAIN_CONFIG_NAME = "train_config.json" @@ -38,6 +40,7 @@ class TrainPipelineConfig(HubMixin): dataset: DatasetConfig env: envs.EnvConfig | None = None policy: PreTrainedConfig | None = None + reward_model: RewardModelConfig | None = None # Set `dir` to where you would like to save all of the run outputs. If you run another training session # with the same value for `dir` its contents will be overwritten unless you set `resume` to true. output_dir: Path | None = None @@ -72,27 +75,41 @@ class TrainPipelineConfig(HubMixin): wandb: WandBConfig = field(default_factory=WandBConfig) peft: PeftConfig | None = None - # RA-BC (Reward-Aligned Behavior Cloning) parameters - use_rabc: bool = False # Enable reward-weighted training - rabc_progress_path: str | None = None # Path to precomputed SARM progress parquet file - rabc_kappa: float = 0.01 # Hard threshold for high-quality samples - rabc_epsilon: float = 1e-6 # Small constant for numerical stability - rabc_head_mode: str | None = "sparse" # For dual-head models: "sparse" or "dense" + # Sample weighting configuration (e.g., for RA-BC training) + sample_weighting: SampleWeightingConfig | None = None # Rename map for the observation to override the image and state keys rename_map: dict[str, str] = field(default_factory=dict) checkpoint_path: Path | None = field(init=False, default=None) + @property + def is_reward_model_training(self) -> bool: + """True when the config targets a reward model rather than a policy.""" + return self.reward_model is not None + + @property + def trainable_config(self) -> PreTrainedConfig | RewardModelConfig: + """Return whichever config (policy or reward_model) is active.""" + if self.is_reward_model_training: + return self.reward_model # type: ignore[return-value] + return self.policy # type: ignore[return-value] + def validate(self) -> None: # HACK: We parse again the cli args here to get the pretrained paths if there was some. 
policy_path = parser.get_path_arg("policy") - if policy_path: - # Only load the policy config + reward_model_path = parser.get_path_arg("reward_model") + + if reward_model_path: + cli_overrides = parser.get_cli_overrides("reward_model") + self.reward_model = RewardModelConfig.from_pretrained( + reward_model_path, cli_overrides=cli_overrides + ) + self.reward_model.pretrained_path = str(Path(reward_model_path)) + elif policy_path: cli_overrides = parser.get_cli_overrides("policy") self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) self.policy.pretrained_path = Path(policy_path) elif self.resume: - # The entire train config is already loaded, we just need to get the checkpoint dir config_path = parser.parse_arg("config_path") if not config_path: raise ValueError( @@ -108,18 +125,22 @@ class TrainPipelineConfig(HubMixin): policy_dir = Path(config_path).parent if self.policy is not None: self.policy.pretrained_path = policy_dir + if self.reward_model is not None: + self.reward_model.pretrained_path = str(policy_dir) self.checkpoint_path = policy_dir.parent - if self.policy is None: + if self.policy is None and self.reward_model is None: raise ValueError( - "Policy is not configured. Please specify a pretrained policy with `--policy.path`." + "Neither policy nor reward_model is configured. " + "Please specify one with `--policy.path` or `--reward_model.path`." ) + active_cfg = self.trainable_config if not self.job_name: if self.env is None: - self.job_name = f"{self.policy.type}" + self.job_name = f"{active_cfg.type}" else: - self.job_name = f"{self.env.type}_{self.policy.type}" + self.job_name = f"{self.env.type}_{active_cfg.type}" if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir(): raise FileExistsError( @@ -137,26 +158,16 @@ class TrainPipelineConfig(HubMixin): if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None): raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.") elif self.use_policy_training_preset and not self.resume: - self.optimizer = self.policy.get_optimizer_preset() - self.scheduler = self.policy.get_scheduler_preset() + self.optimizer = active_cfg.get_optimizer_preset() + self.scheduler = active_cfg.get_scheduler_preset() - if self.policy.push_to_hub and not self.policy.repo_id: - raise ValueError( - "'policy.repo_id' argument missing. Please specify it to push the model to the hub." - ) - - if self.use_rabc and not self.rabc_progress_path: - # Auto-detect from dataset path - repo_id = self.dataset.repo_id - if self.dataset.root: - self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet") - else: - self.rabc_progress_path = f"hf://datasets/{repo_id}/sarm_progress.parquet" + if hasattr(active_cfg, "push_to_hub") and active_cfg.push_to_hub and not active_cfg.repo_id: + raise ValueError("'repo_id' argument missing. 
Please specify it to push the model to the hub.") @classmethod def __get_path_fields__(cls) -> list[str]: - """This enables the parser to load config from the policy using `--policy.path=local/dir`""" - return ["policy"] + """Keys for draccus pretrained-path loading.""" + return ["policy", "reward_model"] def to_dict(self) -> dict[str, Any]: return draccus.encode(self) # type: ignore[no-any-return] # because of the third-party library draccus uses Any as the return type diff --git a/src/lerobot/datasets/factory.py b/src/lerobot/datasets/factory.py index 73df3f04b..cbbe83dc8 100644 --- a/src/lerobot/datasets/factory.py +++ b/src/lerobot/datasets/factory.py @@ -19,6 +19,7 @@ from pprint import pformat import torch from lerobot.configs import PreTrainedConfig +from lerobot.configs.rewards import RewardModelConfig from lerobot.configs.train import TrainPipelineConfig from lerobot.transforms import ImageTransforms from lerobot.utils.constants import ACTION, IMAGENET_STATS, OBS_PREFIX, REWARD @@ -30,12 +31,14 @@ from .streaming_dataset import StreamingLeRobotDataset def resolve_delta_timestamps( - cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata + cfg: PreTrainedConfig | RewardModelConfig, ds_meta: LeRobotDatasetMetadata ) -> dict[str, list] | None: - """Resolves delta_timestamps by reading from the 'delta_indices' properties of the PreTrainedConfig. + """Resolves delta_timestamps by reading from the 'delta_indices' properties of the config. Args: - cfg (PreTrainedConfig): The PreTrainedConfig to read delta_indices from. + cfg (PreTrainedConfig | RewardModelConfig): The config to read delta_indices from. Both + ``PreTrainedConfig`` and concrete ``RewardModelConfig`` subclasses expose the + ``{observation,action,reward}_delta_indices`` properties used below. ds_meta (LeRobotDatasetMetadata): The dataset from which features and fps are used to build delta_timestamps against. 
@@ -82,7 +85,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas ds_meta = LeRobotDatasetMetadata( cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision ) - delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta) + delta_timestamps = resolve_delta_timestamps(cfg.trainable_config, ds_meta) if not cfg.dataset.streaming: dataset = LeRobotDataset( cfg.dataset.repo_id, diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py index 905276642..e811eef28 100644 --- a/src/lerobot/policies/__init__.py +++ b/src/lerobot/policies/__init__.py @@ -24,8 +24,6 @@ from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig from .pi05.configuration_pi05 import PI05Config as PI05Config from .pretrained import PreTrainedPolicy as PreTrainedPolicy from .sac.configuration_sac import SACConfig as SACConfig -from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig -from .sarm.configuration_sarm import SARMConfig as SARMConfig from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig from .utils import make_robot_action, prepare_observation_for_inference @@ -46,9 +44,7 @@ __all__ = [ "PI0Config", "PI0FastConfig", "PI05Config", - "RewardClassifierConfig", "SACConfig", - "SARMConfig", "SmolVLAConfig", "TDMPCConfig", "VQBeTConfig", diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 611a6e9bc..5be3bca43 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -52,8 +52,6 @@ from .pi0.configuration_pi0 import PI0Config from .pi05.configuration_pi05 import PI05Config from .pretrained import PreTrainedPolicy from .sac.configuration_sac import SACConfig -from .sac.reward_model.configuration_classifier import RewardClassifierConfig -from .sarm.configuration_sarm import SARMConfig from .smolvla.configuration_smolvla import SmolVLAConfig from .tdmpc.configuration_tdmpc import TDMPCConfig from .utils import validate_visual_features_consistency @@ -89,7 +87,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: Args: name: The name of the policy. Supported names are "tdmpc", "diffusion", "act", - "multi_task_dit", "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x". + "multi_task_dit", "vqbet", "pi0", "pi05", "sac", "smolvla", "wall_x". Returns: The policy class corresponding to the given name. @@ -132,18 +130,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: from .sac.modeling_sac import SACPolicy return SACPolicy - elif name == "reward_classifier": - from .sac.reward_model.modeling_classifier import Classifier - - return Classifier elif name == "smolvla": from .smolvla.modeling_smolvla import SmolVLAPolicy return SmolVLAPolicy - elif name == "sarm": - from .sarm.modeling_sarm import SARMRewardModel - - return SARMRewardModel elif name == "groot": from .groot.modeling_groot import GrootPolicy @@ -173,7 +163,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: Args: policy_type: The type of the policy. Supported types include "tdmpc", "multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac", - "smolvla", "reward_classifier", "wall_x". + "smolvla", "wall_x". **kwargs: Keyword arguments to be passed to the configuration class constructor. 
Returns: @@ -200,8 +190,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return SACConfig(**kwargs) elif policy_type == "smolvla": return SmolVLAConfig(**kwargs) - elif policy_type == "reward_classifier": - return RewardClassifierConfig(**kwargs) elif policy_type == "groot": return GrootConfig(**kwargs) elif policy_type == "xvla": @@ -378,14 +366,6 @@ def make_pre_post_processors( dataset_stats=kwargs.get("dataset_stats"), ) - elif isinstance(policy_cfg, RewardClassifierConfig): - from .sac.reward_model.processor_classifier import make_classifier_processor - - processors = make_classifier_processor( - config=policy_cfg, - dataset_stats=kwargs.get("dataset_stats"), - ) - elif isinstance(policy_cfg, SmolVLAConfig): from .smolvla.processor_smolvla import make_smolvla_pre_post_processors @@ -394,14 +374,6 @@ def make_pre_post_processors( dataset_stats=kwargs.get("dataset_stats"), ) - elif isinstance(policy_cfg, SARMConfig): - from .sarm.processor_sarm import make_sarm_pre_post_processors - - processors = make_sarm_pre_post_processors( - config=policy_cfg, - dataset_stats=kwargs.get("dataset_stats"), - dataset_meta=kwargs.get("dataset_meta"), - ) elif isinstance(policy_cfg, GrootConfig): from .groot.processor_groot import make_groot_pre_post_processors diff --git a/src/lerobot/policies/sarm/README.md b/src/lerobot/policies/sarm/README.md deleted file mode 120000 index 18495860d..000000000 --- a/src/lerobot/policies/sarm/README.md +++ /dev/null @@ -1 +0,0 @@ -../../../../docs/source/policy_sarm_README.md \ No newline at end of file diff --git a/src/lerobot/policies/sarm/__init__.py b/src/lerobot/policies/sarm/__init__.py deleted file mode 100644 index b164c87ef..000000000 --- a/src/lerobot/policies/sarm/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .configuration_sarm import SARMConfig -from .modeling_sarm import SARMRewardModel - -__all__ = ["SARMConfig", "SARMRewardModel"] diff --git a/src/lerobot/processor/hil_processor.py b/src/lerobot/processor/hil_processor.py index c6f98c689..49dbb8106 100644 --- a/src/lerobot/processor/hil_processor.py +++ b/src/lerobot/processor/hil_processor.py @@ -557,7 +557,7 @@ class RewardClassifierProcessorStep(ProcessorStep): def __post_init__(self): """Initializes the reward classifier model after the dataclass is created.""" if self.pretrained_path is not None: - from lerobot.policies.sac.reward_model.modeling_classifier import Classifier + from lerobot.rewards.classifier.modeling_classifier import Classifier self.reward_classifier = Classifier.from_pretrained(self.pretrained_path) self.reward_classifier.to(self.device) diff --git a/src/lerobot/rewards/__init__.py b/src/lerobot/rewards/__init__.py new file mode 100644 index 000000000..203fe2ee1 --- /dev/null +++ b/src/lerobot/rewards/__init__.py @@ -0,0 +1,36 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .classifier.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig +from .factory import ( + get_reward_model_class as get_reward_model_class, + make_reward_model as make_reward_model, + make_reward_model_config as make_reward_model_config, + make_reward_pre_post_processors as make_reward_pre_post_processors, +) +from .pretrained import PreTrainedRewardModel as PreTrainedRewardModel +from .sarm.configuration_sarm import SARMConfig as SARMConfig + +__all__ = [ + # Configuration classes + "RewardClassifierConfig", + "SARMConfig", + # Base class + "PreTrainedRewardModel", + # Factory functions + "get_reward_model_class", + "make_reward_model", + "make_reward_model_config", + "make_reward_pre_post_processors", +] diff --git a/src/lerobot/policies/sac/reward_model/__init__.py b/src/lerobot/rewards/classifier/__init__.py similarity index 100% rename from src/lerobot/policies/sac/reward_model/__init__.py rename to src/lerobot/rewards/classifier/__init__.py diff --git a/src/lerobot/policies/sac/reward_model/configuration_classifier.py b/src/lerobot/rewards/classifier/configuration_classifier.py similarity index 92% rename from src/lerobot/policies/sac/reward_model/configuration_classifier.py rename to src/lerobot/rewards/classifier/configuration_classifier.py index 3a5bfa424..a618a2cf7 100644 --- a/src/lerobot/policies/sac/reward_model/configuration_classifier.py +++ b/src/lerobot/rewards/classifier/configuration_classifier.py @@ -1,5 +1,3 @@ -# !/usr/bin/env python - # Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,14 +13,15 @@ # limitations under the License. from dataclasses import dataclass, field -from lerobot.configs import NormalizationMode, PreTrainedConfig +from lerobot.configs import NormalizationMode +from lerobot.configs.rewards import RewardModelConfig from lerobot.optim import AdamWConfig, LRSchedulerConfig, OptimizerConfig from lerobot.utils.constants import OBS_IMAGE -@PreTrainedConfig.register_subclass(name="reward_classifier") +@RewardModelConfig.register_subclass(name="reward_classifier") @dataclass -class RewardClassifierConfig(PreTrainedConfig): +class RewardClassifierConfig(RewardModelConfig): """Configuration for the Reward Classifier model.""" name: str = "reward_classifier" diff --git a/src/lerobot/policies/sac/reward_model/modeling_classifier.py b/src/lerobot/rewards/classifier/modeling_classifier.py similarity index 88% rename from src/lerobot/policies/sac/reward_model/modeling_classifier.py rename to src/lerobot/rewards/classifier/modeling_classifier.py index c8b7efe58..bedfffbe9 100644 --- a/src/lerobot/policies/sac/reward_model/modeling_classifier.py +++ b/src/lerobot/rewards/classifier/modeling_classifier.py @@ -1,5 +1,3 @@ -# !/usr/bin/env python - # Copyright 2025 The HuggingFace Inc. team. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,11 +17,10 @@ import logging import torch from torch import Tensor, nn +from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig +from lerobot.rewards.pretrained import PreTrainedRewardModel from lerobot.utils.constants import OBS_IMAGE, REWARD -from ...pretrained import PreTrainedPolicy -from .configuration_classifier import RewardClassifierConfig - class ClassifierOutput: """Wrapper for classifier outputs with additional metadata.""" @@ -99,7 +96,7 @@ class SpatialLearnedEmbeddings(nn.Module): return output -class Classifier(PreTrainedPolicy): +class Classifier(PreTrainedRewardModel): """Image classifier built on top of a pre-trained encoder.""" name = "reward_classifier" @@ -235,6 +232,16 @@ class Classifier(PreTrainedPolicy): return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs) + def compute_reward(self, batch: dict[str, Tensor]) -> Tensor: + """Returns 1.0 for success, 0.0 for failure based on image observations.""" + images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)] + output = self.predict(images) + + if self.config.num_classes == 2: + return (output.probabilities > 0.5).float() + else: + return torch.argmax(output.probabilities, dim=1).float() + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]: """Standard forward pass for training compatible with train.py.""" # Extract images and labels @@ -269,10 +276,6 @@ class Classifier(PreTrainedPolicy): def predict_reward(self, batch, threshold=0.5): """Eval method. Returns predicted reward with the decision threshold as argument.""" - # Check for both OBS_IMAGE and OBS_IMAGES prefixes - batch = self.normalize_inputs(batch) - batch = self.normalize_targets(batch) - # Extract images from batch dict images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)] @@ -282,28 +285,3 @@ class Classifier(PreTrainedPolicy): return (probs > threshold).float() else: return torch.argmax(self.predict(images).probabilities, dim=1) - - def get_optim_params(self): - """Return optimizer parameters for the policy.""" - return self.parameters() - - def select_action(self, batch: dict[str, Tensor]) -> Tensor: - """ - This method is required by PreTrainedPolicy but not used for reward classifiers. - The reward classifier is not an actor and does not select actions. - """ - raise NotImplementedError("Reward classifiers do not select actions") - - def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: - """ - This method is required by PreTrainedPolicy but not used for reward classifiers. - The reward classifier is not an actor and does not produce action chunks. - """ - raise NotImplementedError("Reward classifiers do not predict action chunks") - - def reset(self): - """ - This method is required by PreTrainedPolicy but not used for reward classifiers. - The reward classifier is not an actor and does not select actions. 
- """ - pass diff --git a/src/lerobot/policies/sac/reward_model/processor_classifier.py b/src/lerobot/rewards/classifier/processor_classifier.py similarity index 91% rename from src/lerobot/policies/sac/reward_model/processor_classifier.py rename to src/lerobot/rewards/classifier/processor_classifier.py index 1f7a66e58..056d7e91b 100644 --- a/src/lerobot/policies/sac/reward_model/processor_classifier.py +++ b/src/lerobot/rewards/classifier/processor_classifier.py @@ -1,5 +1,3 @@ -# !/usr/bin/env python - # Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -27,8 +25,7 @@ from lerobot.processor import ( policy_action_to_transition, transition_to_policy_action, ) - -from .configuration_classifier import RewardClassifierConfig +from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig def make_classifier_processor( @@ -52,8 +49,6 @@ def make_classifier_processor( Args: config: The configuration object for the RewardClassifier. dataset_stats: A dictionary of statistics for normalization. - preprocessor_kwargs: Additional arguments for the pre-processor pipeline. - postprocessor_kwargs: Additional arguments for the post-processor pipeline. Returns: A tuple containing the configured pre-processor and post-processor pipelines. diff --git a/src/lerobot/rewards/factory.py b/src/lerobot/rewards/factory.py new file mode 100644 index 000000000..f6716f3fb --- /dev/null +++ b/src/lerobot/rewards/factory.py @@ -0,0 +1,238 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import logging +from typing import Any + +import torch + +from lerobot.configs.rewards import RewardModelConfig +from lerobot.processor import PolicyAction, PolicyProcessorPipeline +from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig +from lerobot.rewards.pretrained import PreTrainedRewardModel +from lerobot.rewards.sarm.configuration_sarm import SARMConfig + + +def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]: + """ + Retrieves a reward model class by its registered name. + + This function uses dynamic imports to avoid loading all reward model classes into + memory at once, improving startup time and reducing dependencies. + + Args: + name: The name of the reward model. Supported names are "reward_classifier", + "sarm". + + Returns: + The reward model class corresponding to the given name. + + Raises: + ValueError: If the reward model name is not recognized. 
+ """ + if name == "reward_classifier": + from lerobot.rewards.classifier.modeling_classifier import Classifier + + return Classifier + elif name == "sarm": + from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel + + return SARMRewardModel + else: + try: + return _get_reward_model_cls_from_name(name=name) + except Exception as e: + raise ValueError(f"Reward model type '{name}' is not available.") from e + + +def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig: + """ + Instantiates a reward model configuration object based on the reward type. + + This factory function simplifies the creation of reward model configuration objects + by mapping a string identifier to the corresponding config class. + + Args: + reward_type: The type of the reward model. Supported types include + "reward_classifier", "sarm". + **kwargs: Keyword arguments to be passed to the configuration class constructor. + + Returns: + An instance of a `RewardModelConfig` subclass. + + Raises: + ValueError: If the `reward_type` is not recognized. + """ + if reward_type == "reward_classifier": + return RewardClassifierConfig(**kwargs) + elif reward_type == "sarm": + return SARMConfig(**kwargs) + else: + try: + config_cls = RewardModelConfig.get_choice_class(reward_type) + return config_cls(**kwargs) + except Exception as e: + raise ValueError(f"Reward model type '{reward_type}' is not available.") from e + + +def make_reward_model(cfg: RewardModelConfig, **kwargs) -> PreTrainedRewardModel: + """ + Instantiate a reward model from its configuration. + + Args: + cfg: The configuration for the reward model to be created. If + `cfg.pretrained_path` is set, the model will be loaded with weights + from that path. + **kwargs: Additional keyword arguments forwarded to the model constructor + (e.g., ``dataset_stats``, ``dataset_meta``). + + Returns: + An instantiated and device-placed reward model. + """ + reward_cls = get_reward_model_class(cfg.type) + + kwargs["config"] = cfg + + if cfg.pretrained_path: + kwargs["pretrained_name_or_path"] = cfg.pretrained_path + reward_model = reward_cls.from_pretrained(**kwargs) + else: + reward_model = reward_cls(**kwargs) + + reward_model.to(cfg.device) + assert isinstance(reward_model, torch.nn.Module) + + return reward_model + + +def make_reward_pre_post_processors( + reward_cfg: RewardModelConfig, + **kwargs, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Create pre- and post-processor pipelines for a given reward model. + + Each reward model type has a dedicated factory function for its processors. + + Args: + reward_cfg: The configuration of the reward model for which to create processors. + **kwargs: Additional keyword arguments passed to the processor factory + (e.g., ``dataset_stats``, ``dataset_meta``). + + Returns: + A tuple containing the input (pre-processor) and output (post-processor) pipelines. + + Raises: + ValueError: If a processor factory is not implemented for the given reward + model configuration type. 
+ """ + # Create a new processor based on reward model type + if isinstance(reward_cfg, RewardClassifierConfig): + from lerobot.rewards.classifier.processor_classifier import make_classifier_processor + + return make_classifier_processor( + config=reward_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + + elif isinstance(reward_cfg, SARMConfig): + from lerobot.rewards.sarm.processor_sarm import make_sarm_pre_post_processors + + return make_sarm_pre_post_processors( + config=reward_cfg, + dataset_stats=kwargs.get("dataset_stats"), + dataset_meta=kwargs.get("dataset_meta"), + ) + + else: + try: + processors = _make_processors_from_reward_model_config( + config=reward_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + except Exception as e: + raise ValueError( + f"Processor for reward model type '{reward_cfg.type}' is not implemented." + ) from e + return processors + + +def _get_reward_model_cls_from_name(name: str) -> type[PreTrainedRewardModel]: + """Get reward model class from its registered name using dynamic imports. + + This is used as a helper function to import reward models from 3rd party lerobot + plugins. + + Args: + name: The name of the reward model. + + Returns: + The reward model class corresponding to the given name. + """ + if name not in RewardModelConfig.get_known_choices(): + raise ValueError( + f"Unknown reward model name '{name}'. " + f"Available reward models: {RewardModelConfig.get_known_choices()}" + ) + + config_cls = RewardModelConfig.get_choice_class(name) + config_cls_name = config_cls.__name__ + + model_name = config_cls_name.removesuffix("Config") + if model_name == config_cls_name: + raise ValueError( + f"The config class name '{config_cls_name}' does not follow the expected naming convention. " + f"Make sure it ends with 'Config'!" + ) + + cls_name = model_name + "RewardModel" + module_path = config_cls.__module__.replace("configuration_", "modeling_") + + module = importlib.import_module(module_path) + reward_cls = getattr(module, cls_name) + return reward_cls + + +def _make_processors_from_reward_model_config( + config: RewardModelConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[Any, Any]: + """Create pre- and post-processors from a reward model configuration using dynamic imports. + + This is used as a helper function to import processor factories from 3rd party + lerobot reward model plugins. + + Args: + config: The reward model configuration object. + dataset_stats: Dataset statistics for normalization. + + Returns: + A tuple containing the input (pre-processor) and output (post-processor) pipelines. + """ + reward_type = config.type + function_name = f"make_{reward_type}_pre_post_processors" + module_path = config.__class__.__module__.replace("configuration_", "processor_") + logging.debug( + f"Instantiating reward pre/post processors using function '{function_name}' " + f"from module '{module_path}'" + ) + module = importlib.import_module(module_path) + function = getattr(module, function_name) + return function(config, dataset_stats=dataset_stats) diff --git a/src/lerobot/rewards/pretrained.py b/src/lerobot/rewards/pretrained.py new file mode 100644 index 000000000..d44b31733 --- /dev/null +++ b/src/lerobot/rewards/pretrained.py @@ -0,0 +1,244 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import builtins +import logging +import os +from importlib.resources import files +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import TYPE_CHECKING, Any, TypeVar + +import packaging +import safetensors +from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download +from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE +from huggingface_hub.errors import HfHubHTTPError +from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor +from torch import Tensor, nn + +from lerobot.configs.rewards import RewardModelConfig +from lerobot.utils.hub import HubMixin + +if TYPE_CHECKING: + from lerobot.configs.train import TrainPipelineConfig + +T = TypeVar("T", bound="PreTrainedRewardModel") + + +class PreTrainedRewardModel(nn.Module, HubMixin, abc.ABC): + """Base class for reward models.""" + + config_class: None + name: None + + def __init__(self, config: RewardModelConfig, *inputs, **kwargs): + super().__init__() + if not isinstance(config, RewardModelConfig): + raise ValueError( + f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class " + "`RewardModelConfig`. To create a model from a pretrained model use " + f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`" + ) + self.config = config + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + if not getattr(cls, "config_class", None): + raise TypeError(f"Class {cls.__name__} must define 'config_class'") + if not getattr(cls, "name", None): + raise TypeError(f"Class {cls.__name__} must define 'name'") + + def _save_pretrained(self, save_directory: Path) -> None: + self.config._save_pretrained(save_directory) + model_to_save = self.module if hasattr(self, "module") else self + save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE)) + + @classmethod + def from_pretrained( + cls: builtins.type[T], + pretrained_name_or_path: str | Path, + *, + config: RewardModelConfig | None = None, + force_download: bool = False, + resume_download: bool | None = None, + proxies: dict | None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + strict: bool = False, + **kwargs, + ) -> T: + """ + The reward model is set in evaluation mode by default using `reward.eval()` (dropout modules are + deactivated). To train it, you should first set it back in training mode with `reward.train()`. 
+ """ + if config is None: + config = RewardModelConfig.from_pretrained( + pretrained_name_or_path=pretrained_name_or_path, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + **kwargs, + ) + model_id = str(pretrained_name_or_path) + instance = cls(config, **kwargs) + if os.path.isdir(model_id): + print("Loading weights from local directory") + model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE) + reward = cls._load_as_safetensor(instance, model_file, config.device or "cpu", strict) + else: + try: + model_file = hf_hub_download( + repo_id=model_id, + filename=SAFETENSORS_SINGLE_FILE, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + reward = cls._load_as_safetensor(instance, model_file, config.device or "cpu", strict) + except HfHubHTTPError as e: + raise FileNotFoundError( + f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}" + ) from e + + reward.to(config.device) + reward.eval() + return reward + + @classmethod + def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T: + # Create base kwargs + kwargs = {"strict": strict} + + # Add device parameter for newer versions that support it + if packaging.version.parse(safetensors.__version__) >= packaging.version.parse("0.4.3"): + kwargs["device"] = map_location + + # Load the model with appropriate kwargs + missing_keys, unexpected_keys = load_model_as_safetensor(model, model_file, **kwargs) + if missing_keys: + logging.warning(f"Missing key(s) when loading model: {missing_keys}") + if unexpected_keys: + logging.warning(f"Unexpected key(s) when loading model: {unexpected_keys}") + + # For older versions, manually move to device if needed + if "device" not in kwargs and map_location != "cpu": + logging.warning( + "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors." + " This means that the model is loaded on 'cpu' first and then copied to the device." + " This leads to a slower loading time." + " Please update safetensors to version 0.4.3 or above for improved performance." + ) + model.to(map_location) + return model + + def get_optim_params(self): + """ + Returns the reward-model-specific parameters dict to be passed on to the optimizer. + """ + return self.parameters() + + def reset(self) -> None: + """Reset any internal state.""" + pass + + @abc.abstractmethod + def compute_reward(self, batch: dict[str, Tensor]) -> Tensor: + """Compute a scalar reward signal for a batch of observations. + + Args: + batch: Dictionary containing at minimum observation tensors. + May also contain "action", "next_observation.*", etc. + + Returns: + Tensor of shape ``(batch_size,)`` with reward values. + """ + ... + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Any]]: + """Training forward pass — override for trainable reward models.""" + raise NotImplementedError( + f"{self.__class__.__name__} is not trainable. Only use compute_reward() for inference." + ) + + @property + def is_trainable(self) -> bool: + """Whether this reward model can be trained via ``lerobot-train``. + + Trainable reward models override :meth:`forward`; zero-shot models + inherit the base implementation that raises ``NotImplementedError``. 
+ """ + return type(self).forward is not PreTrainedRewardModel.forward + + def push_model_to_hub(self, cfg: "TrainPipelineConfig"): + api = HfApi() + repo_id = api.create_repo( + repo_id=self.config.repo_id, private=self.config.private, exist_ok=True + ).repo_id + + # Push the files to the repo in a single commit + with TemporaryDirectory(ignore_cleanup_errors=True) as tmp: + saved_path = Path(tmp) / repo_id + + self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors + + card = self.generate_model_card( + cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags + ) + card.save(str(saved_path / "README.md")) + + cfg.save_pretrained(saved_path) # Calls _save_pretrained and stores train config + + commit_info = api.upload_folder( + repo_id=repo_id, + repo_type="model", + folder_path=saved_path, + commit_message="Upload reward model weights, train config and readme", + allow_patterns=["*.safetensors", "*.json", "*.yaml", "*.md"], + ignore_patterns=["*.tmp", "*.log"], + ) + + logging.info(f"Model pushed to {commit_info.repo_url.url}") + + def generate_model_card( + self, dataset_repo_id: str, model_type: str, license: str | None, tags: list[str] | None + ) -> ModelCard: + card_data = ModelCardData( + license=license or "apache-2.0", + library_name="lerobot", + pipeline_tag="robotics", + tags=list(set(tags or []).union({"robotics", "lerobot", "reward-model", model_type})), + model_name=model_type, + datasets=dataset_repo_id, + ) + + template_card = ( + files("lerobot.templates") + .joinpath("lerobot_rewardmodel_modelcard_template.md") + .read_text(encoding="utf-8") + ) + card = ModelCard.from_template(card_data, template_str=template_card) + card.validate() + return card diff --git a/src/lerobot/rewards/sarm/__init__.py b/src/lerobot/rewards/sarm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/lerobot/policies/sarm/compute_rabc_weights.py b/src/lerobot/rewards/sarm/compute_rabc_weights.py similarity index 98% rename from src/lerobot/policies/sarm/compute_rabc_weights.py rename to src/lerobot/rewards/sarm/compute_rabc_weights.py index 07d0780b5..b1bf2e1f5 100644 --- a/src/lerobot/policies/sarm/compute_rabc_weights.py +++ b/src/lerobot/rewards/sarm/compute_rabc_weights.py @@ -25,18 +25,18 @@ need ~num_frames/30 queries instead of one per frame (~30x speedup). 
Usage: # Full RA-BC computation with visualizations - python src/lerobot/policies/sarm/compute_rabc_weights.py \\ + python src/lerobot/rewards/sarm/compute_rabc_weights.py \\ --dataset-repo-id lerobot/aloha_sim_insertion_human \\ --reward-model-path /sarm_single_uni4 # Faster computation with stride (compute every 5 frames, interpolate the rest) - python src/lerobot/policies/sarm/compute_rabc_weights.py \\ + python src/lerobot/rewards/sarm/compute_rabc_weights.py \\ --dataset-repo-id lerobot/aloha_sim_insertion_human \\ --reward-model-path /sarm_single_uni4 \\ --stride 5 # Visualize predictions only (no RA-BC computation) - python src/lerobot/policies/sarm/compute_rabc_weights.py \\ + python src/lerobot/rewards/sarm/compute_rabc_weights.py \\ --dataset-repo-id lerobot/aloha_sim_insertion_human \\ --reward-model-path /sarm_single_uni4 \\ --visualize-only \\ @@ -58,10 +58,9 @@ import torch from tqdm import tqdm from lerobot.datasets import LeRobotDataset - -from .modeling_sarm import SARMRewardModel -from .processor_sarm import make_sarm_pre_post_processors -from .sarm_utils import normalize_stage_tau +from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel +from lerobot.rewards.sarm.processor_sarm import make_sarm_pre_post_processors +from lerobot.rewards.sarm.sarm_utils import normalize_stage_tau def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None: @@ -713,12 +712,12 @@ def main(): epilog=""" Examples: # Full RA-BC computation with visualizations - python src/lerobot/policies/sarm/compute_rabc_weights.py \\ + python src/lerobot/rewards/sarm/compute_rabc_weights.py \\ --dataset-repo-id lerobot/aloha_sim_insertion_human \\ --reward-model-path /sarm_single_uni4 # Visualize predictions only (no RA-BC computation) - python src/lerobot/policies/sarm/compute_rabc_weights.py \\ + python src/lerobot/rewards/sarm/compute_rabc_weights.py \\ --dataset-repo-id lerobot/aloha_sim_insertion_human \\ --reward-model-path /sarm_single_uni4 \\ --visualize-only \\ diff --git a/src/lerobot/policies/sarm/configuration_sarm.py b/src/lerobot/rewards/sarm/configuration_sarm.py similarity index 98% rename from src/lerobot/policies/sarm/configuration_sarm.py rename to src/lerobot/rewards/sarm/configuration_sarm.py index fc8daa055..0d1f727f7 100644 --- a/src/lerobot/policies/sarm/configuration_sarm.py +++ b/src/lerobot/rewards/sarm/configuration_sarm.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python - # Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu # and The HuggingFace Inc. team. All rights reserved. # @@ -22,14 +20,15 @@ Paper: https://arxiv.org/abs/2509.25358 from dataclasses import dataclass, field -from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig +from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature +from lerobot.configs.rewards import RewardModelConfig from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig from lerobot.utils.constants import OBS_IMAGES, OBS_STATE -@PreTrainedConfig.register_subclass("sarm") +@RewardModelConfig.register_subclass("sarm") @dataclass -class SARMConfig(PreTrainedConfig): +class SARMConfig(RewardModelConfig): """Configuration class for SARM (Stage-Aware Reward Modeling). 
Supports three annotation modes: @@ -110,7 +109,6 @@ class SARMConfig(PreTrainedConfig): def __post_init__(self): super().__post_init__() - if self.annotation_mode not in ["single_stage", "dense_only", "dual"]: raise ValueError( f"annotation_mode must be 'single_stage', 'dense_only', or 'dual', got {self.annotation_mode}" diff --git a/src/lerobot/policies/sarm/modeling_sarm.py b/src/lerobot/rewards/sarm/modeling_sarm.py similarity index 96% rename from src/lerobot/policies/sarm/modeling_sarm.py rename to src/lerobot/rewards/sarm/modeling_sarm.py index 710554e4b..365f519b2 100644 --- a/src/lerobot/policies/sarm/modeling_sarm.py +++ b/src/lerobot/rewards/sarm/modeling_sarm.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python - # Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu # and The HuggingFace Inc. team. All rights reserved. # @@ -34,14 +32,13 @@ import torch.nn as nn import torch.nn.functional as F # noqa: N812 from torch import Tensor -from lerobot.utils.constants import OBS_STR - -from ..pretrained import PreTrainedPolicy -from .configuration_sarm import SARMConfig -from .sarm_utils import ( +from lerobot.rewards.pretrained import PreTrainedRewardModel +from lerobot.rewards.sarm.configuration_sarm import SARMConfig +from lerobot.rewards.sarm.sarm_utils import ( normalize_stage_tau, pad_state_to_max_dim, ) +from lerobot.utils.constants import OBS_STR class StageTransformer(nn.Module): @@ -353,7 +350,7 @@ def gen_stage_emb(num_classes: int, targets: torch.Tensor) -> torch.Tensor: return stage_onehot -class SARMRewardModel(PreTrainedPolicy): +class SARMRewardModel(PreTrainedRewardModel): """ SARM Reward Model for stage-aware task completion rewards. @@ -471,6 +468,23 @@ class SARMRewardModel(PreTrainedPolicy): self.subtask_model.to(device) return self + def compute_reward(self, batch: dict[str, Tensor]) -> Tensor: + """Compute dense progress reward in [0, 1] from batch. 
+ + Expects batch to contain: + - "observation_features" or video embeddings: (B, T, 512) + - "language_embedding" or text embeddings: (B, 512) + - optionally "observation.state": (B, T, state_dim) + """ + text_emb = batch.get("language_embedding", batch.get("text_features")) + video_emb = batch.get("observation_features", batch.get("video_features")) + state = batch.get("observation.state", batch.get("state_features")) + + rewards = self.calculate_rewards(text_emb, video_emb, state) + if isinstance(rewards, np.ndarray): + rewards = torch.from_numpy(rewards).float() + return rewards + @torch.no_grad() def calculate_rewards( self, @@ -631,17 +645,9 @@ class SARMRewardModel(PreTrainedPolicy): return self.parameters() def reset(self): - """Required by PreTrainedPolicy but not used for reward models.""" + """SARM has no episode-level state to reset.""" pass - def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: - """Required by PreTrainedPolicy but not used for reward models.""" - raise NotImplementedError("SARM model does not predict action chunks") - - def select_action(self, batch: dict[str, Tensor]) -> Tensor: - """Required by PreTrainedPolicy but not used for SARM.""" - raise NotImplementedError("SARM model does not select actions") - def _train_step( self, img_emb: torch.Tensor, # (B, N, T, D) diff --git a/src/lerobot/policies/sarm/processor_sarm.py b/src/lerobot/rewards/sarm/processor_sarm.py similarity index 99% rename from src/lerobot/policies/sarm/processor_sarm.py rename to src/lerobot/rewards/sarm/processor_sarm.py index b60271b49..eaa5f66f5 100644 --- a/src/lerobot/policies/sarm/processor_sarm.py +++ b/src/lerobot/rewards/sarm/processor_sarm.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python - # Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -60,16 +58,15 @@ from lerobot.processor import ( policy_action_to_transition, transition_to_policy_action, ) -from lerobot.types import EnvTransition, PolicyAction, TransitionKey -from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME - -from .configuration_sarm import SARMConfig -from .sarm_utils import ( +from lerobot.rewards.sarm.configuration_sarm import SARMConfig +from lerobot.rewards.sarm.sarm_utils import ( apply_rewind_augmentation, compute_absolute_indices, find_stage_and_tau, pad_state_to_max_dim, ) +from lerobot.types import EnvTransition, PolicyAction, TransitionKey +from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME class SARMEncodingProcessorStep(ProcessorStep): diff --git a/src/lerobot/utils/rabc.py b/src/lerobot/rewards/sarm/rabc.py similarity index 79% rename from src/lerobot/utils/rabc.py rename to src/lerobot/rewards/sarm/rabc.py index dc0c61c69..8d7ce6bde 100644 --- a/src/lerobot/utils/rabc.py +++ b/src/lerobot/rewards/sarm/rabc.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python - # Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,14 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +RA-BC (Reward-Aligned Behavior Cloning) sample weighting implementation. + +This module implements the SampleWeighter protocol for RA-BC training, +which weights training samples based on their task progress as measured +by the SARM reward model. 
+ +The weights are computed based on progress deltas: + delta = progress[t + chunk_size] - progress[t] + +High-quality samples (positive progress) get higher weights, while +samples with negative progress (going backwards) get zero weight. + +See: https://arxiv.org/abs/2509.25358 for the SARM paper. +""" + import logging from pathlib import Path +from typing import TYPE_CHECKING import numpy as np -import pandas as pd import torch from huggingface_hub import hf_hub_download +from lerobot.utils.import_utils import _pandas_available +from lerobot.utils.sample_weighting import SampleWeighter + +if TYPE_CHECKING or _pandas_available: + import pandas as pd +else: + pd = None # type: ignore[assignment] + def resolve_hf_path(path: str | Path) -> Path: """Resolve a path that may be a HuggingFace URL (hf://datasets/...) to a local path.""" @@ -34,23 +56,27 @@ def resolve_hf_path(path: str | Path) -> Path: return Path(path) -class RABCWeights: +class RABCWeights(SampleWeighter): """ Load precomputed SARM progress values and compute RA-BC weights during training. + This class implements the SampleWeighter ABC for use with the generic + sample weighting infrastructure in lerobot. + Progress values are loaded from a parquet file (generated by compute_rabc_weights.py). During training, computes: - progress_delta = progress[t + chunk_size] - progress[t] - rabc_weight based on the delta (paper Eq. 8-9) Args: - progress_path: Path to parquet file with precomputed progress values - chunk_size: Number of frames ahead for computing progress delta - head_mode: Which SARM head to use ("sparse" or "dense") - kappa: Hard threshold for high-quality samples (default: 0.01) - epsilon: Small constant for numerical stability (default: 1e-6) - fallback_weight: Weight to use for frames without valid delta (default: 1.0) - device: Device to return tensors on + progress_path: Path to parquet file with precomputed progress values. + Supports HuggingFace URLs (hf://datasets/...). + chunk_size: Number of frames ahead for computing progress delta. + head_mode: Which SARM head to use ("sparse" or "dense"). + kappa: Hard threshold for high-quality samples (default: 0.01). + epsilon: Small constant for numerical stability (default: 1e-6). + fallback_weight: Weight to use for frames without valid delta (default: 1.0). + device: Device to return tensors on. 
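+
+    Example (illustrative sketch; assumes a precomputed ``sarm_progress.parquet`` and a
+    policy whose action chunk spans 50 frames):
+
+        weighter = RABCWeights(
+            progress_path="outputs/sarm_progress.parquet",
+            chunk_size=50,
+            head_mode="sparse",
+        )
+        weights, stats = weighter.compute_batch_weights(batch)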
""" def __init__( @@ -61,7 +87,7 @@ class RABCWeights: kappa: float = 0.01, epsilon: float = 1e-6, fallback_weight: float = 1.0, - device: torch.device = None, + device: torch.device | None = None, ): self.progress_path = resolve_hf_path(progress_path) self.chunk_size = chunk_size @@ -87,8 +113,8 @@ class RABCWeights: logging.info(f"Using progress column: {self.progress_column}") - self.progress_lookup = {} - self.episode_lookup = {} + self.progress_lookup: dict[int, float] = {} + self.episode_lookup: dict[int, int] = {} for _, row in self.df.iterrows(): global_idx = int(row["index"]) @@ -100,7 +126,7 @@ class RABCWeights: self.episode_lookup[global_idx] = episode_idx # Build episode boundaries for delta computation - self.episode_boundaries = {} + self.episode_boundaries: dict[int, dict[str, int]] = {} for episode_idx in self.df["episode_index"].unique(): ep_df = self.df[self.df["episode_index"] == episode_idx] self.episode_boundaries[int(episode_idx)] = { @@ -114,7 +140,7 @@ class RABCWeights: # Compute global statistics for weight computation self._compute_global_stats() - def _compute_global_stats(self): + def _compute_global_stats(self) -> None: """Compute global mean and std of progress deltas for weight calculation.""" all_deltas = [] @@ -138,8 +164,8 @@ class RABCWeights: all_deltas.append(delta) if all_deltas: - self.delta_mean = max(np.mean(all_deltas), 0.0) - self.delta_std = max(np.std(all_deltas), self.epsilon) + self.delta_mean = max(float(np.mean(all_deltas)), 0.0) + self.delta_std = max(float(np.std(all_deltas)), self.epsilon) logging.info(f"Progress delta stats: mean={self.delta_mean:.4f}, std={self.delta_std:.4f}") else: self.delta_mean = 0.0 @@ -157,18 +183,19 @@ class RABCWeights: 4. Compute weight using paper Eq. 8-9 Args: - batch: Training batch containing "index" key with global frame indices + batch: Training batch containing "index" key with global frame indices. Returns: Tuple of: - - Weights tensor (batch_size,) normalized to sum to batch_size - - Stats dict with raw_mean_weight, num_zero_weight, num_full_weight + - Weights tensor (batch_size,) normalized to sum to batch_size. + - Stats dict with weighting statistics for logging. 
""" indices = batch.get("index") if indices is None: logging.warning("RA-BC: Batch missing 'index' key, using uniform weights") batch_size = self._get_batch_size(batch) - return torch.ones(batch_size, device=self.device), {"raw_mean_weight": 1.0} + stats = {"mean_weight": 1.0, "num_zero_weight": 0, "num_full_weight": batch_size} + return torch.ones(batch_size, device=self.device), stats # Convert to list of ints if isinstance(indices, torch.Tensor): @@ -183,29 +210,29 @@ class RABCWeights: delta = self._compute_delta(idx) deltas.append(delta) - deltas = np.array(deltas, dtype=np.float32) + deltas_array = np.array(deltas, dtype=np.float32) # Compute weights from deltas - weights = self._compute_weights(deltas) + weights = self._compute_weights(deltas_array) # Compute stats before normalization for logging raw_mean_weight = float(np.nanmean(weights)) num_zero_weight = int(np.sum(weights == 0)) num_full_weight = int(np.sum(weights == 1.0)) batch_stats = { - "raw_mean_weight": raw_mean_weight, + "mean_weight": raw_mean_weight, "num_zero_weight": num_zero_weight, "num_full_weight": num_full_weight, } - weights = torch.tensor(weights, device=self.device, dtype=torch.float32) + weights_tensor = torch.tensor(weights, device=self.device, dtype=torch.float32) # Normalize to sum to batch_size - batch_size = len(weights) - weight_sum = weights.sum() + self.epsilon - weights = weights * batch_size / weight_sum + batch_size = len(weights_tensor) + weight_sum = weights_tensor.sum() + self.epsilon + weights_tensor = weights_tensor * batch_size / weight_sum - return weights, batch_stats + return weights_tensor, batch_stats def _compute_delta(self, global_idx: int) -> float: """Compute progress delta for a single frame.""" @@ -241,7 +268,7 @@ class RABCWeights: - Final weight: wi = 1{ri > κ} + 1{0 ≤ ri ≤ κ}˜wi Returns: - Array of weights + Array of weights. """ valid_mask = ~np.isnan(deltas) @@ -273,12 +300,13 @@ class RABCWeights: if key in batch: val = batch[key] if isinstance(val, (torch.Tensor, np.ndarray)): - return val.shape[0] + return int(val.shape[0]) return 1 def get_stats(self) -> dict: - """Get statistics.""" + """Get global statistics about the RA-BC weighting.""" return { + "type": "rabc", "num_frames": len(self.progress_lookup), "chunk_size": self.chunk_size, "head_mode": self.head_mode, diff --git a/src/lerobot/policies/sarm/sarm_utils.py b/src/lerobot/rewards/sarm/sarm_utils.py similarity index 99% rename from src/lerobot/policies/sarm/sarm_utils.py rename to src/lerobot/rewards/sarm/sarm_utils.py index 5b6955d38..d2cd92cff 100644 --- a/src/lerobot/policies/sarm/sarm_utils.py +++ b/src/lerobot/rewards/sarm/sarm_utils.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python - # Copyright 2025 The HuggingFace Inc. team. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 856006507..9d7330e6a 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -47,6 +47,7 @@ from lerobot.datasets import EpisodeAwareSampler, make_dataset from lerobot.envs import close_envs, make_env, make_env_pre_post_processors from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors +from lerobot.rewards import make_reward_pre_post_processors from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.logging_utils import AverageMeter, MetricsTracker from lerobot.utils.random_utils import set_seed @@ -70,8 +71,8 @@ def update_policy( accelerator: "Accelerator", lr_scheduler=None, lock=None, - rabc_weights_provider=None, -) -> tuple[MetricsTracker, dict]: + sample_weighter=None, +) -> tuple[MetricsTracker, dict | None]: """ Performs a single training step to update the policy's weights. @@ -87,7 +88,7 @@ def update_policy( accelerator: The Accelerator instance for distributed training and mixed precision. lr_scheduler: An optional learning rate scheduler. lock: An optional lock for thread-safe optimizer updates. - rabc_weights_provider: Optional RABCWeights instance for sample weighting. + sample_weighter: Optional SampleWeighter instance for per-sample loss weighting. Returns: A tuple containing: @@ -97,27 +98,31 @@ def update_policy( start_time = time.perf_counter() policy.train() - # Get RA-BC weights if enabled - rabc_batch_weights = None - rabc_batch_stats = None - if rabc_weights_provider is not None: - rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch) + # Compute sample weights if a weighter is provided + sample_weights = None + weight_stats = None + if sample_weighter is not None: + sample_weights, weight_stats = sample_weighter.compute_batch_weights(batch) # Let accelerator handle mixed precision with accelerator.autocast(): - # Use per-sample loss when RA-BC is enabled for proper weighting - if rabc_batch_weights is not None: - # Get per-sample losses + if sample_weights is not None: + # Use per-sample loss for weighted training + # Note: Policies supporting sample weighting must implement forward(batch, reduction="none") per_sample_loss, output_dict = policy.forward(batch, reduction="none") - # Apply RA-BC weights: L_RA-BC = Σ(w_i * l_i) / (Σw_i + ε) - # rabc_batch_weights is already normalized to sum to batch_size + # Weighted loss: each sample's contribution is scaled by its weight. + # We divide by weight sum (not batch size) so that if some weights are zero, + # the remaining samples contribute proportionally more, preserving gradient scale. + # Weights are pre-normalized to sum to batch_size for stable training dynamics. 
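+            # For example (hypothetical numbers): weights [0.0, 1.0, 2.0, 1.0] sum to the
+            # batch size of 4, so the zero-weight sample drops out of the loss while the
+            # third sample contributes twice as much as an unweighted one.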
epsilon = 1e-6 - loss = (per_sample_loss * rabc_batch_weights).sum() / (rabc_batch_weights.sum() + epsilon) - # Log raw mean weight (before normalization) - this is the meaningful metric - output_dict["rabc_mean_weight"] = rabc_batch_stats["raw_mean_weight"] - output_dict["rabc_num_zero_weight"] = rabc_batch_stats["num_zero_weight"] - output_dict["rabc_num_full_weight"] = rabc_batch_stats["num_full_weight"] + loss = (per_sample_loss * sample_weights).sum() / (sample_weights.sum() + epsilon) + + # Log weighting statistics + if output_dict is None: + output_dict = {} + for key, value in weight_stats.items(): + output_dict[f"sample_weight_{key}"] = value else: loss, output_dict = policy.forward(batch) @@ -188,8 +193,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) # Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting. - # Force the device to be CPU when policy.device is set to CPU. - force_cpu = cfg.policy.device == "cpu" + # Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training). + force_cpu = cfg.trainable_config.device == "cpu" accelerator = Accelerator( step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs], @@ -245,26 +250,44 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): logging.info("Creating env") eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) - if is_main_process: - logging.info("Creating policy") - policy = make_policy( - cfg=cfg.policy, - ds_meta=dataset.meta, - rename_map=cfg.rename_map, - ) + if cfg.is_reward_model_training: + if is_main_process: + logging.info("Creating reward model") + from lerobot.rewards import make_reward_model + + policy = make_reward_model( + cfg=cfg.reward_model, + dataset_stats=dataset.meta.stats, + dataset_meta=dataset.meta, + ) + if not policy.is_trainable: + raise ValueError( + f"Reward model '{policy.name}' is zero-shot and cannot be trained via lerobot-train. " + "Use it directly for inference via compute_reward() (e.g. offline precompute)." + ) + else: + if is_main_process: + logging.info("Creating policy") + policy = make_policy( + cfg=cfg.policy, + ds_meta=dataset.meta, + rename_map=cfg.rename_map, + ) if cfg.peft is not None: + if cfg.is_reward_model_training: + raise ValueError("PEFT is only supported for policy training. ") logging.info("Using PEFT! 
Wrapping model.") - # Convert CLI peft config to dict for overrides peft_cli_overrides = dataclasses.asdict(cfg.peft) policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides) - # Wait for all processes to finish policy creation before continuing + # Wait for all processes to finish model creation before continuing accelerator.wait_for_everyone() - processor_pretrained_path = cfg.policy.pretrained_path + active_cfg = cfg.trainable_config + processor_pretrained_path = active_cfg.pretrained_path if ( - getattr(cfg.policy, "use_relative_actions", False) + getattr(active_cfg, "use_relative_actions", False) and processor_pretrained_path is not None and not cfg.resume ): @@ -274,18 +297,15 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): ) processor_pretrained_path = None - # Create processors - only provide dataset_stats if not resuming from saved processors processor_kwargs = {} postprocessor_kwargs = {} if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path: - # Only provide dataset_stats when not resuming from saved processor state processor_kwargs["dataset_stats"] = dataset.meta.stats - # For SARM, always provide dataset_meta for progress normalization - if cfg.policy.type == "sarm": + if cfg.is_reward_model_training: processor_kwargs["dataset_meta"] = dataset.meta - if processor_pretrained_path is not None: + if not cfg.is_reward_model_training and processor_pretrained_path is not None: processor_kwargs["preprocessor_overrides"] = { "device_processor": {"device": device.type}, "normalizer_processor": { @@ -305,38 +325,36 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): }, } - preprocessor, postprocessor = make_pre_post_processors( - policy_cfg=cfg.policy, - pretrained_path=processor_pretrained_path, - **processor_kwargs, - **postprocessor_kwargs, - ) + if cfg.is_reward_model_training: + preprocessor, postprocessor = make_reward_pre_post_processors( + cfg.reward_model, + **processor_kwargs, + ) + else: + preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=cfg.policy, + pretrained_path=processor_pretrained_path, + **processor_kwargs, + **postprocessor_kwargs, + ) if is_main_process: logging.info("Creating optimizer and scheduler") optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) - # Load precomputed SARM progress for RA-BC if enabled - # Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py - rabc_weights = None - if cfg.use_rabc: - from lerobot.utils.rabc import RABCWeights + # Create sample weighter if configured (e.g., for RA-BC training) + sample_weighter = None + if cfg.sample_weighting is not None: + from lerobot.utils.sample_weighting import make_sample_weighter - # Get chunk_size from policy config - chunk_size = getattr(policy.config, "chunk_size", None) - if chunk_size is None: - raise ValueError("Chunk size is not found in policy config") - - head_mode = getattr(cfg, "rabc_head_mode", "sparse") - logging.info(f"Loading SARM progress for RA-BC from {cfg.rabc_progress_path}") - logging.info(f"Using chunk_size={chunk_size} from policy config, head_mode={head_mode}") - rabc_weights = RABCWeights( - progress_path=cfg.rabc_progress_path, - chunk_size=chunk_size, - head_mode=head_mode, - kappa=getattr(cfg, "rabc_kappa", 0.01), - epsilon=getattr(cfg, "rabc_epsilon", 1e-6), - device=device, + if is_main_process: + logging.info(f"Creating sample weighter: {cfg.sample_weighting.type}") + sample_weighter = make_sample_weighter( + 
cfg.sample_weighting, + policy, + device, + dataset_root=cfg.dataset.root, + dataset_repo_id=cfg.dataset.repo_id, ) step = 0 # number of policy updates (forward + backward + optim) @@ -365,13 +383,13 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") # create dataloader for offline training - if hasattr(cfg.policy, "drop_n_last_frames"): + if hasattr(active_cfg, "drop_n_last_frames"): shuffle = False sampler = EpisodeAwareSampler( dataset.meta.episodes["dataset_from_index"], dataset.meta.episodes["dataset_to_index"], episode_indices_to_use=dataset.episodes, - drop_n_last_frames=cfg.policy.drop_n_last_frames, + drop_n_last_frames=active_cfg.drop_n_last_frames, shuffle=True, ) else: @@ -448,7 +466,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): cfg.optimizer.grad_clip_norm, accelerator=accelerator, lr_scheduler=lr_scheduler, - rabc_weights_provider=rabc_weights, + sample_weighter=sample_weighter, ) # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we @@ -467,16 +485,10 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): wandb_log_dict = train_tracker.to_dict() if output_dict: wandb_log_dict.update(output_dict) - # Log RA-BC statistics if enabled - if rabc_weights is not None: - rabc_stats = rabc_weights.get_stats() - wandb_log_dict.update( - { - "rabc_delta_mean": rabc_stats["delta_mean"], - "rabc_delta_std": rabc_stats["delta_std"], - "rabc_num_frames": rabc_stats["num_frames"], - } - ) + # Log sample weighting statistics if enabled + if sample_weighter is not None: + weighter_stats = sample_weighter.get_stats() + wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()}) wandb_logger.log_dict(wandb_log_dict, step) train_tracker.reset_averages() @@ -558,14 +570,15 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): if is_main_process: logging.info("End of training") - if cfg.policy.push_to_hub: - unwrapped_policy = accelerator.unwrap_model(policy) - if cfg.policy.use_peft: - unwrapped_policy.push_model_to_hub(cfg, peft_model=unwrapped_policy) + if getattr(active_cfg, "push_to_hub", False): + unwrapped_model = accelerator.unwrap_model(policy) + # PEFT only applies when training a policy — reward models use the plain path. 
+            if not cfg.is_reward_model_training and cfg.policy.use_peft:
+                unwrapped_model.push_model_to_hub(cfg, peft_model=unwrapped_model)
             else:
-                unwrapped_policy.push_model_to_hub(cfg)
-            preprocessor.push_to_hub(cfg.policy.repo_id)
-            postprocessor.push_to_hub(cfg.policy.repo_id)
+                unwrapped_model.push_model_to_hub(cfg)
+            preprocessor.push_to_hub(active_cfg.repo_id)
+            postprocessor.push_to_hub(active_cfg.repo_id)
 
     # Properly clean up the distributed process group
     accelerator.wait_for_everyone()
diff --git a/src/lerobot/templates/lerobot_rewardmodel_modelcard_template.md b/src/lerobot/templates/lerobot_rewardmodel_modelcard_template.md
new file mode 100644
index 000000000..933bf7586
--- /dev/null
+++ b/src/lerobot/templates/lerobot_rewardmodel_modelcard_template.md
@@ -0,0 +1,55 @@
+---
+# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
+# Doc / guide: https://huggingface.co/docs/hub/model-cards
+# prettier-ignore
+{{card_data}}
+---
+
+# Reward Model Card for {{ model_name | default("Reward Model ID", true) }}
+
+
+
+{% if model_name == "reward_classifier" %}
+A reward classifier is a lightweight neural network that scores observations or trajectories for task success, providing a learned reward signal or offline evaluation when explicit rewards are unavailable.
+{% elif model_name == "sarm" %}
+A Stage-Aware Reward Model (SARM) predicts a dense reward signal from observations, typically used downstream for reinforcement learning or human-in-the-loop fine-tuning when task success is not directly observable.
+{% else %}
+_Reward model type not recognized — please update this template._
+{% endif %}
+
+This reward model has been trained and pushed to the Hub using [LeRobot](https://github.com/huggingface/lerobot).
+See the full documentation at [LeRobot Docs](https://huggingface.co/docs/lerobot/index).
+
+---
+
+## How to Get Started with the Reward Model
+
+### Train from scratch
+
+```bash
+lerobot-train \
+  --dataset.repo_id=${HF_USER}/<dataset_repo_id> \
+  --reward_model.type={{ model_name | default("reward_classifier", true) }} \
+  --output_dir=outputs/train/<job_name> \
+  --job_name=lerobot_reward_training \
+  --reward_model.device=cuda \
+  --reward_model.repo_id=${HF_USER}/<reward_model_repo_id> \
+  --wandb.enable=true
+```
+
+_Writes checkpoints to `outputs/train/<job_name>/checkpoints/`._
+
+### Load the reward model in Python
+
+```python
+from lerobot.rewards import make_reward_model
+
+reward_model = make_reward_model(pretrained_path="<hf_user>/<reward_model_repo_id>")
+reward = reward_model.compute_reward(batch)
+```
+
+---
+
+## Model Details
+
+- **License:** {{ license | default("\[More Information Needed]", true) }}
diff --git a/src/lerobot/utils/sample_weighting.py b/src/lerobot/utils/sample_weighting.py
new file mode 100644
index 000000000..83eec3126
--- /dev/null
+++ b/src/lerobot/utils/sample_weighting.py
@@ -0,0 +1,239 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Sample weighting abstraction for training.
+ +This module provides an abstract base class for sample weighting strategies (e.g., RA-BC) +that can be used during training without polluting the training script with +policy-specific code. + +Example usage: + # In training config + sample_weighting: + type: rabc + progress_path: hf://datasets/my-dataset/sarm_progress.parquet + head_mode: sparse + kappa: 0.01 + + # In training script + sample_weighter = make_sample_weighter(cfg.sample_weighting, policy, device, dataset_root=cfg.dataset.root, dataset_repo_id=cfg.dataset.repo_id) + ... + weights, stats = sample_weighter.compute_batch_weights(batch) +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING + +import torch + +if TYPE_CHECKING: + from lerobot.policies.pretrained import PreTrainedPolicy + + +class SampleWeighter(ABC): + """ + Implementations compute per-sample weights that can be used to weight + the loss during training. This enables techniques like: + - RA-BC (Reward-Aligned Behavior Cloning) + - Importance sampling + - Curriculum learning + - Quality-based filtering + """ + + @abstractmethod + def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]: + """ + Compute per-sample weights for a training batch. + + Args: + batch: Training batch dictionary containing at minimum an "index" key + with global frame indices. + """ + + @abstractmethod + def get_stats(self) -> dict: + """ + Get global statistics about the weighting strategy. + """ + + +@dataclass +class SampleWeightingConfig: + """ + Configuration for sample weighting during training. + + This is a generic config that supports multiple weighting strategies. + The `type` field determines which implementation to use, and `extra_params` + contains additional type-specific parameters. + + Attributes: + type: Weighting strategy type ("rabc", "uniform", etc.) + progress_path: Path to precomputed progress values (for RABC) + head_mode: Which model head to use for progress ("sparse" or "dense") + kappa: Hard threshold for high-quality samples (RABC-specific) + epsilon: Small constant for numerical stability + extra_params: Additional type-specific parameters passed to the weighter + """ + + type: str = "rabc" + progress_path: str | None = None + head_mode: str = "sparse" + kappa: float = 0.01 + epsilon: float = 1e-6 + # Additional type-specific params can be added here or passed via extra_params + extra_params: dict = field(default_factory=dict) + + +def make_sample_weighter( + config: SampleWeightingConfig | None, + policy: PreTrainedPolicy, + device: torch.device, + dataset_root: str | None = None, + dataset_repo_id: str | None = None, +) -> SampleWeighter | None: + """ + Factory function to create a SampleWeighter from config. + + This keeps policy-specific initialization logic out of the training script. + + Args: + config: Sample weighting configuration, or None to disable weighting. + policy: The policy being trained (used to extract chunk_size, etc.) + device: Device to place weight tensors on. + dataset_root: Local path to dataset root (for auto-detecting progress_path). + dataset_repo_id: HuggingFace repo ID (for auto-detecting progress_path). 
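+
+    Returns:
+        A ``SampleWeighter`` instance for the configured type, or ``None`` when ``config`` is None.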
+ """ + if config is None: + return None + + if config.type == "rabc": + return _make_rabc_weighter(config, policy, device, dataset_root, dataset_repo_id) + + if config.type == "uniform": + # No-op weighter that returns uniform weights + return UniformWeighter(device=device) + + raise ValueError(f"Unknown sample weighting type: '{config.type}'. Supported types: 'rabc', 'uniform'") + + +def _make_rabc_weighter( + config: SampleWeightingConfig, + policy: PreTrainedPolicy, + device: torch.device, + dataset_root: str | None = None, + dataset_repo_id: str | None = None, +) -> SampleWeighter: + """Create RABC weighter with policy-specific initialization. + + Args: + config: Sample weighting configuration. + policy: The policy being trained (used to extract chunk_size). + device: Device to place weight tensors on. + dataset_root: Local path to dataset root (for auto-detecting progress_path). + dataset_repo_id: HuggingFace repo ID (for auto-detecting progress_path). + """ + # Import here to avoid circular imports and keep RABC code in SARM module + from lerobot.rewards.sarm.rabc import RABCWeights + + # Extract chunk_size from policy config + chunk_size = getattr(policy.config, "chunk_size", None) + if chunk_size is None: + raise ValueError( + "RABC sample weighting requires a policy with 'chunk_size' in its config. " + "This is typically set for action-chunking policies like ACT, Diffusion, PI0, etc." + ) + + # Determine progress_path: use explicit config or auto-detect from dataset + progress_path = config.progress_path + if progress_path is None: + if dataset_root: + progress_path = str(Path(dataset_root) / "sarm_progress.parquet") + elif dataset_repo_id: + progress_path = f"hf://datasets/{dataset_repo_id}/sarm_progress.parquet" + else: + raise ValueError( + "RABC sample weighting requires 'progress_path' to be set, " + "or dataset_root/dataset_repo_id for auto-detection. " + "Generate progress values using: " + "python -m lerobot.rewards.sarm.compute_rabc_weights --help" + ) + + return RABCWeights( + progress_path=progress_path, + chunk_size=chunk_size, + head_mode=config.head_mode, + kappa=config.kappa, + epsilon=config.epsilon, + device=device, + **config.extra_params, + ) + + +class UniformWeighter(SampleWeighter): + """ + No-op sample weighter that returns uniform weights. + + Useful as a baseline or when you want to disable weighting without + changing the training code structure. + + Note: + Batch size is determined by looking for tensor values in the batch + dictionary. The method checks common keys like "action", "index", + and "observation.state" first, then falls back to scanning all values. + """ + + def __init__(self, device: torch.device): + self.device = device + + def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]: + """Return uniform weights (all ones).""" + batch_size = self._determine_batch_size(batch) + + weights = torch.ones(batch_size, device=self.device) + stats = {"mean_weight": 1.0, "type": "uniform"} + return weights, stats + + def _determine_batch_size(self, batch: dict) -> int: + """ + Determine batch size from the batch dictionary. + + Checks common keys first, then scans all values for tensors. + + Args: + batch: Training batch dictionary. 
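+
+        Returns:
+            The inferred batch size, falling back to 1 when no tensor values are found.
+
+        Raises:
+            ValueError: If the batch dictionary is empty.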
+ """ + if not batch: + raise ValueError("Cannot determine batch size from empty batch") + + # Check common keys first + for key in ["action", "index", "observation.state"]: + if key in batch and isinstance(batch[key], torch.Tensor): + return batch[key].shape[0] + + # Scan all values for any tensor + for value in batch.values(): + if isinstance(value, torch.Tensor) and value.ndim >= 1: + return value.shape[0] + + # Last resort: return 1 (this handles non-tensor batches) + return 1 + + def get_stats(self) -> dict: + """Return empty stats for uniform weighting.""" + return {"type": "uniform"} diff --git a/tests/rewards/__init__.py b/tests/rewards/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/processor/test_classifier_processor.py b/tests/rewards/test_classifier_processor.py similarity index 92% rename from tests/processor/test_classifier_processor.py rename to tests/rewards/test_classifier_processor.py index e1567bf29..c54d80b0e 100644 --- a/tests/processor/test_classifier_processor.py +++ b/tests/rewards/test_classifier_processor.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python - # Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,8 +19,6 @@ import pytest import torch from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig -from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor from lerobot.processor import ( DataProcessorPipeline, DeviceProcessorStep, @@ -31,6 +27,8 @@ from lerobot.processor import ( TransitionKey, ) from lerobot.processor.converters import create_transition, transition_to_batch +from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig +from lerobot.rewards.classifier.processor_classifier import make_classifier_processor from lerobot.utils.constants import OBS_IMAGE, OBS_STATE @@ -42,7 +40,7 @@ def create_default_config(): OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), } config.output_features = { - "reward": PolicyFeature(type=FeatureType.ACTION, shape=(1,)), # Classifier output + "reward": PolicyFeature(type=FeatureType.ACTION, shape=(1,)), } config.normalization_mapping = { FeatureType.STATE: NormalizationMode.MEAN_STD, @@ -90,17 +88,14 @@ def test_classifier_processor_normalization(): config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_classifier_processor( - config, - stats, - ) + preprocessor, postprocessor = make_classifier_processor(config, stats) # Create test data observation = { OBS_STATE: torch.randn(10), OBS_IMAGE: torch.randn(3, 224, 224), } - action = torch.randn(1) # Dummy action/reward + action = torch.randn(1) transition = create_transition(observation, action) batch = transition_to_batch(transition) @@ -120,10 +115,7 @@ def test_classifier_processor_cuda(): config.device = "cuda" stats = create_default_stats() - preprocessor, postprocessor = make_classifier_processor( - config, - stats, - ) + preprocessor, postprocessor = make_classifier_processor(config, stats) # Create CPU data observation = { @@ -132,7 +124,6 @@ def test_classifier_processor_cuda(): } action = torch.randn(1) transition = create_transition(observation, action) - batch = transition_to_batch(transition) # Process through preprocessor @@ -158,10 +149,7 @@ def test_classifier_processor_accelerate_scenario(): config.device 
= "cuda:0" stats = create_default_stats() - preprocessor, postprocessor = make_classifier_processor( - config, - stats, - ) + preprocessor, postprocessor = make_classifier_processor(config, stats) # Simulate Accelerate: data already on GPU device = torch.device("cuda:0") @@ -171,7 +159,6 @@ def test_classifier_processor_accelerate_scenario(): } action = torch.randn(1).to(device) transition = create_transition(observation, action) - batch = transition_to_batch(transition) # Process through preprocessor @@ -201,7 +188,6 @@ def test_classifier_processor_multi_gpu(): } action = torch.randn(1).to(device) transition = create_transition(observation, action) - batch = transition_to_batch(transition) # Process through preprocessor @@ -231,7 +217,6 @@ def test_classifier_processor_without_stats(): } action = torch.randn(1) transition = create_transition(observation, action) - batch = transition_to_batch(transition) processed = preprocessor(batch) @@ -294,7 +279,6 @@ def test_classifier_processor_mixed_precision(): } action = torch.randn(1, dtype=torch.float32) transition = create_transition(observation, action) - batch = transition_to_batch(transition) # Process through preprocessor @@ -312,10 +296,7 @@ def test_classifier_processor_batch_data(): config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_classifier_processor( - config, - stats, - ) + preprocessor, postprocessor = make_classifier_processor(config, stats) # Test with batched data batch_size = 16 @@ -325,7 +306,6 @@ def test_classifier_processor_batch_data(): } action = torch.randn(batch_size, 1) transition = create_transition(observation, action) - batch = transition_to_batch(transition) # Process through preprocessor @@ -343,15 +323,11 @@ def test_classifier_processor_postprocessor_identity(): config = create_default_config() stats = create_default_stats() - preprocessor, postprocessor = make_classifier_processor( - config, - stats, - ) + preprocessor, postprocessor = make_classifier_processor(config, stats) # Create test data for postprocessor - reward = torch.tensor([[0.8], [0.3], [0.9]]) # Batch of rewards/predictions + reward = torch.tensor([[0.8], [0.3], [0.9]]) transition = create_transition(action=reward) - _ = transition_to_batch(transition) # Process through postprocessor diff --git a/tests/policies/hilserl/test_modeling_classifier.py b/tests/rewards/test_modeling_classifier.py similarity index 86% rename from tests/policies/hilserl/test_modeling_classifier.py rename to tests/rewards/test_modeling_classifier.py index 6d262c01b..08f6121a1 100644 --- a/tests/policies/hilserl/test_modeling_classifier.py +++ b/tests/rewards/test_modeling_classifier.py @@ -1,5 +1,3 @@ -# !/usr/bin/env python - # Copyright 2025 The HuggingFace Inc. team. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,8 +16,8 @@ import pytest import torch from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig -from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput +from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig +from lerobot.rewards.classifier.modeling_classifier import ClassifierOutput from lerobot.utils.constants import OBS_IMAGE, REWARD from tests.utils import skip_if_package_missing @@ -42,7 +40,7 @@ def test_classifier_output(): reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers" ) def test_binary_classifier_with_default_params(): - from lerobot.policies.sac.reward_model.modeling_classifier import Classifier + from lerobot.rewards.classifier.modeling_classifier import Classifier config = RewardClassifierConfig() config.input_features = { @@ -86,7 +84,7 @@ def test_binary_classifier_with_default_params(): reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers" ) def test_multiclass_classifier(): - from lerobot.policies.sac.reward_model.modeling_classifier import Classifier + from lerobot.rewards.classifier.modeling_classifier import Classifier num_classes = 5 config = RewardClassifierConfig() @@ -128,11 +126,15 @@ def test_multiclass_classifier(): reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers" ) def test_default_device(): - from lerobot.policies.sac.reward_model.modeling_classifier import Classifier + from lerobot.rewards.classifier.modeling_classifier import Classifier config = RewardClassifierConfig() - assert config.device == "cpu" + assert config.device is None or config.device == "cpu" + config.input_features = { + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.num_cameras = 1 classifier = Classifier(config) for p in classifier.parameters(): assert p.device == torch.device("cpu") @@ -143,11 +145,15 @@ def test_default_device(): reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers" ) def test_explicit_device_setup(): - from lerobot.policies.sac.reward_model.modeling_classifier import Classifier + from lerobot.rewards.classifier.modeling_classifier import Classifier config = RewardClassifierConfig(device="cpu") assert config.device == "cpu" + config.input_features = { + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.num_cameras = 1 classifier = Classifier(config) for p in classifier.parameters(): assert p.device == torch.device("cpu") diff --git a/tests/rewards/test_reward_model_base.py b/tests/rewards/test_reward_model_base.py new file mode 100644 index 000000000..c8755a0fa --- /dev/null +++ b/tests/rewards/test_reward_model_base.py @@ -0,0 +1,373 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the reward model base classes and registry.""" + +from dataclasses import dataclass +from pathlib import Path +from types import SimpleNamespace + +import pytest +import torch + +from lerobot.configs.rewards import RewardModelConfig +from lerobot.optim.optimizers import AdamWConfig +from lerobot.rewards.pretrained import PreTrainedRewardModel + + +@RewardModelConfig.register_subclass(name="_dummy_hub_reward") +@dataclass +class _DummyHubRewardConfig(RewardModelConfig): + def get_optimizer_preset(self): + return AdamWConfig(lr=1e-4) + + +class _DummyHubReward(PreTrainedRewardModel): + config_class = _DummyHubRewardConfig + name = "_dummy_hub_reward" + + def __init__(self, config): + super().__init__(config) + self.bias = torch.nn.Parameter(torch.zeros(1)) + + def compute_reward(self, batch): + return self.bias.expand(1) + + +def test_reward_model_config_registry(): + """Verify that classifier and sarm are registered.""" + known = RewardModelConfig.get_known_choices() + assert "reward_classifier" in known + assert "sarm" in known + + +def test_reward_model_config_lookup(): + """Verify that we can look up configs by name.""" + cls = RewardModelConfig.get_choice_class("reward_classifier") + from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig + + assert cls is RewardClassifierConfig + + +def test_factory_get_reward_model_class(): + """Test the get_reward_model_class factory.""" + from lerobot.rewards.factory import get_reward_model_class + + cls = get_reward_model_class("sarm") + from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel + + assert cls is SARMRewardModel + + +def test_factory_unknown_raises(): + """Unknown name should raise ValueError.""" + from lerobot.rewards.factory import get_reward_model_class + + with pytest.raises(ValueError, match="not available"): + get_reward_model_class("nonexistent_reward_model") + + +def test_pretrained_reward_model_requires_config_class(): + """Subclass without config_class should fail.""" + with pytest.raises(TypeError, match="must define 'config_class'"): + + class BadModel(PreTrainedRewardModel): + name = "bad" + + def compute_reward(self, batch): + pass + + +def test_pretrained_reward_model_requires_name(): + """Subclass without name should fail.""" + with pytest.raises(TypeError, match="must define 'name'"): + + class BadModel(PreTrainedRewardModel): + config_class = RewardModelConfig + + def compute_reward(self, batch): + pass + + +def test_non_trainable_forward_raises(): + """Non-trainable model should raise on forward().""" + from dataclasses import dataclass + + from lerobot.optim.optimizers import AdamWConfig + + @dataclass + class DummyConfig(RewardModelConfig): + def get_optimizer_preset(self): + return AdamWConfig(lr=1e-4) + + class DummyReward(PreTrainedRewardModel): + config_class = DummyConfig + name = "dummy_test" + + def compute_reward(self, batch): + return torch.zeros(1) + + config = DummyConfig() + model = DummyReward(config) + + with pytest.raises(NotImplementedError, match="not trainable"): + model.forward({"x": torch.zeros(1)}) + + +# --------------------------------------------------------------------------- +# Trainable vs zero-shot (general-purpose) reward models. +# The proposal explicitly supports models like TOPReward that wrap a pretrained +# VLM and produce a reward signal without any training step. 
These tests pin +# the contract that lets such models coexist with trainable ones. +# --------------------------------------------------------------------------- + + +def test_is_trainable_false_when_forward_not_overridden(): + """A reward model that only implements ``compute_reward`` is zero-shot.""" + model, _ = _make_dummy_reward_model() + assert model.is_trainable is False + + +def test_is_trainable_true_when_forward_overridden(): + """Overriding ``forward`` flips ``is_trainable`` to True.""" + + class _TrainableReward(_DummyHubReward): + name = "_trainable_dummy_reward" + + def forward(self, batch): + loss = (self.bias**2).sum() + return loss, {} + + # Register a fresh config subclass so the subclass check passes. + @RewardModelConfig.register_subclass(name="_trainable_dummy_reward") + @dataclass + class _TrainableConfig(_DummyHubRewardConfig): + pass + + _TrainableReward.config_class = _TrainableConfig + model = _TrainableReward(_TrainableConfig()) + assert model.is_trainable is True + + +# --------------------------------------------------------------------------- +# RewardModelConfig.from_pretrained +# --------------------------------------------------------------------------- + + +def test_reward_model_config_from_pretrained_raises_when_config_missing(tmp_path): + """``from_pretrained`` must surface a clear ``FileNotFoundError`` when the + target directory exists but does not contain ``config.json``, instead of + crashing later inside ``draccus.parse``. + """ + # tmp_path exists but has no config.json + with pytest.raises(FileNotFoundError, match="config.json not found"): + RewardModelConfig.from_pretrained(tmp_path) + + +def test_reward_model_config_from_pretrained_roundtrip(tmp_path): + """Round-trip: save a RewardClassifierConfig, reload it, fields must match.""" + from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig + + original = RewardClassifierConfig( + num_classes=3, + hidden_dim=128, + latent_dim=64, + num_cameras=1, + learning_rate=5e-4, + ) + original._save_pretrained(tmp_path) + + loaded = RewardModelConfig.from_pretrained(tmp_path) + + assert isinstance(loaded, RewardClassifierConfig) + assert loaded.num_classes == 3 + assert loaded.hidden_dim == 128 + assert loaded.latent_dim == 64 + assert loaded.num_cameras == 1 + assert loaded.learning_rate == 5e-4 + + +# --------------------------------------------------------------------------- +# TrainPipelineConfig — reward model training path +# --------------------------------------------------------------------------- + + +def test_train_pipeline_config_path_fields_includes_reward_model(): + """``--reward_model.path=local/dir`` requires ``reward_model`` to be listed + as a draccus path-field on ``TrainPipelineConfig``.""" + from lerobot.configs.train import TrainPipelineConfig + + fields = TrainPipelineConfig.__get_path_fields__() + assert "policy" in fields + assert "reward_model" in fields + + +def test_train_pipeline_config_trainable_config_returns_reward_model_when_set(): + """When only ``reward_model`` is set, ``trainable_config`` (used by the + trainer for e.g. 
``.device``) must return it — not ``None`` from ``policy``.""" + from lerobot.configs.default import DatasetConfig + from lerobot.configs.train import TrainPipelineConfig + from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig + + reward_cfg = RewardClassifierConfig(device="cpu") + cfg = TrainPipelineConfig( + dataset=DatasetConfig(repo_id="user/repo"), + reward_model=reward_cfg, + ) + + assert cfg.is_reward_model_training is True + assert cfg.trainable_config is reward_cfg + # This is what lerobot_train.py uses to decide force_cpu; ``cfg.policy.device`` + # would AttributeError here because policy is None. + assert cfg.trainable_config.device == "cpu" + + +def test_train_pipeline_config_trainable_config_returns_policy_when_set(): + """Mirror of the reward-model case: when only ``policy`` is set, + ``trainable_config`` must return it.""" + from lerobot.configs.default import DatasetConfig + from lerobot.configs.train import TrainPipelineConfig + from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig + + policy_cfg = DiffusionConfig(device="cpu") + cfg = TrainPipelineConfig( + dataset=DatasetConfig(repo_id="user/repo"), + policy=policy_cfg, + ) + + assert cfg.is_reward_model_training is False + assert cfg.trainable_config is policy_cfg + assert cfg.trainable_config.device == "cpu" + + +# --------------------------------------------------------------------------- +# PreTrainedRewardModel hub upload: push_model_to_hub + generate_model_card. +# We test the generation side (offline) fully, and the upload side with HfApi +# mocked so nothing actually hits the network. +# --------------------------------------------------------------------------- + + +def _make_dummy_reward_model(**config_kwargs): + return _DummyHubReward(_DummyHubRewardConfig(**config_kwargs)), _DummyHubRewardConfig + + +@pytest.fixture +def _offline_model_card(monkeypatch): + """``ModelCard.validate`` does a live ``POST`` to huggingface.co — bypass it + so tests can run offline.""" + from huggingface_hub import ModelCard + + monkeypatch.setattr(ModelCard, "validate", lambda self, *a, **kw: None) + + +def test_reward_model_generate_model_card_renders_expected_fields(_offline_model_card): + """``generate_model_card`` must produce a card with the right metadata and + body, using the dedicated reward-model template.""" + model, _ = _make_dummy_reward_model( + license="mit", + tags=["robot", "sim"], + ) + + card = model.generate_model_card( + dataset_repo_id="user/my_dataset", + model_type=model.config.type, + license=model.config.license, + tags=model.config.tags, + ) + + # Metadata (YAML header) — ModelCardData fields. + assert card.data.license == "mit" + assert card.data.library_name == "lerobot" + assert card.data.pipeline_tag == "robotics" + assert "reward-model" in card.data.tags + assert model.config.type in card.data.tags + assert card.data.model_name == model.config.type + assert card.data.datasets == "user/my_dataset" + + # Body — specific to the reward-model template, NOT the policy one. 
+ body = str(card) + assert "Reward Model Card" in body + assert "This reward model has been trained" in body + assert "--reward_model.type=" in body # reward-model-specific usage block + + +def test_reward_model_generate_model_card_uses_default_license(_offline_model_card): + """When config.license is None the card falls back to apache-2.0.""" + model, _ = _make_dummy_reward_model() + + card = model.generate_model_card( + dataset_repo_id="user/my_dataset", + model_type=model.config.type, + license=model.config.license, + tags=None, + ) + + assert card.data.license == "apache-2.0" + + +def test_reward_model_push_model_to_hub_uploads_expected_files(monkeypatch, _offline_model_card): + """``push_model_to_hub`` must: + 1. create the repo, + 2. assemble a temp folder with weights + config.json + train_config.json + README.md, + 3. call ``api.upload_folder`` on that folder. + All network calls are mocked. + """ + from huggingface_hub.constants import CONFIG_NAME + + from lerobot.configs.default import DatasetConfig + from lerobot.configs.train import TRAIN_CONFIG_NAME, TrainPipelineConfig + + model, _ = _make_dummy_reward_model( + repo_id="user/my_reward", + license="apache-2.0", + ) + # Point the reward model's train config at a dummy dataset repo. + train_cfg = TrainPipelineConfig( + dataset=DatasetConfig(repo_id="user/my_dataset"), + reward_model=model.config, + ) + + uploaded: dict = {} + fake_commit_info = SimpleNamespace(repo_url=SimpleNamespace(url="https://huggingface.co/user/my_reward")) + + class _FakeHfApi: + def create_repo(self, repo_id, private=None, exist_ok=False): + uploaded["create_repo_id"] = repo_id + uploaded["create_private"] = private + return SimpleNamespace(repo_id=repo_id) + + def upload_folder(self, *, repo_id, repo_type, folder_path, commit_message, **_kwargs): + uploaded["upload_repo_id"] = repo_id + uploaded["upload_repo_type"] = repo_type + uploaded["commit_message"] = commit_message + # Snapshot files assembled in the temp folder — this is the real + # contract we care about. + uploaded["files"] = sorted(p.name for p in Path(folder_path).iterdir()) + return fake_commit_info + + from lerobot.rewards import pretrained as reward_pretrained + + monkeypatch.setattr(reward_pretrained, "HfApi", lambda *a, **kw: _FakeHfApi()) + + model.push_model_to_hub(train_cfg) + + assert uploaded["create_repo_id"] == "user/my_reward" + assert uploaded["upload_repo_id"] == "user/my_reward" + assert uploaded["upload_repo_type"] == "model" + assert uploaded["commit_message"] == "Upload reward model weights, train config and readme" + # Minimum required files that must be uploaded with a reward model. 
+ assert CONFIG_NAME in uploaded["files"] # config.json + assert TRAIN_CONFIG_NAME in uploaded["files"] # train_config.json + assert "README.md" in uploaded["files"] + assert any(name.endswith(".safetensors") for name in uploaded["files"]) diff --git a/tests/policies/test_sarm_processor.py b/tests/rewards/test_sarm_processor.py similarity index 97% rename from tests/policies/test_sarm_processor.py rename to tests/rewards/test_sarm_processor.py index 5b90784a6..65f70d396 100644 --- a/tests/policies/test_sarm_processor.py +++ b/tests/rewards/test_sarm_processor.py @@ -104,8 +104,8 @@ class TestSARMEncodingProcessorStepEndToEnd: def mock_clip_model(self): """Mock CLIP model to avoid loading real weights.""" with ( - patch("lerobot.policies.sarm.processor_sarm.CLIPModel") as mock_model_cls, - patch("lerobot.policies.sarm.processor_sarm.CLIPProcessor") as mock_processor_cls, + patch("lerobot.rewards.sarm.processor_sarm.CLIPModel") as mock_model_cls, + patch("lerobot.rewards.sarm.processor_sarm.CLIPProcessor") as mock_processor_cls, ): # Mock the CLIP model - return embeddings based on input batch size mock_model = MagicMock() @@ -142,7 +142,7 @@ class TestSARMEncodingProcessorStepEndToEnd: @pytest.fixture def processor_with_mocks(self, mock_clip_model): """Create a processor with mocked CLIP and dataset metadata for dual mode.""" - from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep + from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep # Dual mode config with both sparse and dense annotations config = MockConfig( @@ -256,7 +256,7 @@ class TestSARMEncodingProcessorStepEndToEnd: def test_call_with_batched_input(self, mock_clip_model): """Test processor __call__ with a batched input (multiple frames) in dual mode.""" - from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep + from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep config = MockConfig( n_obs_steps=8, @@ -332,7 +332,7 @@ class TestSARMEncodingProcessorStepEndToEnd: def test_targets_increase_with_progress(self, mock_clip_model): """Test that both sparse and dense targets increase as frame index progresses.""" - from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep + from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep config = MockConfig( n_obs_steps=8, @@ -404,7 +404,7 @@ class TestSARMEncodingProcessorStepEndToEnd: def test_progress_labels_exact_values(self, mock_clip_model): """Test that progress labels (stage.tau) are computed correctly for known positions.""" - from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep + from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep # Simple setup: 2 sparse stages, 4 dense stages, 100 frame episode config = MockConfig( @@ -495,7 +495,7 @@ class TestSARMEncodingProcessorStepEndToEnd: """Test that rewind augmentation correctly extends sequence and generates targets.""" import random - from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep + from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep config = MockConfig( n_obs_steps=8, @@ -587,8 +587,8 @@ class TestSARMEncodingProcessorStepEndToEnd: def test_full_sequence_target_consistency(self, mock_clip_model): """Test that the full sequence of targets is consistent with frame positions.""" - from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep - from lerobot.policies.sarm.sarm_utils import find_stage_and_tau + from 
lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep + from lerobot.rewards.sarm.sarm_utils import find_stage_and_tau config = MockConfig( n_obs_steps=8, diff --git a/tests/policies/test_sarm_utils.py b/tests/rewards/test_sarm_utils.py similarity index 99% rename from tests/policies/test_sarm_utils.py rename to tests/rewards/test_sarm_utils.py index 510477ec8..9ee542909 100644 --- a/tests/policies/test_sarm_utils.py +++ b/tests/rewards/test_sarm_utils.py @@ -18,7 +18,7 @@ import numpy as np import pytest import torch -from lerobot.policies.sarm.sarm_utils import ( +from lerobot.rewards.sarm.sarm_utils import ( apply_rewind_augmentation, compute_absolute_indices, compute_tau, diff --git a/tests/utils/test_sample_weighting.py b/tests/utils/test_sample_weighting.py new file mode 100644 index 000000000..4507d7f18 --- /dev/null +++ b/tests/utils/test_sample_weighting.py @@ -0,0 +1,401 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the sample weighting infrastructure.""" + +from unittest.mock import Mock + +import pytest + +pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])") + +import torch + +from lerobot.utils.sample_weighting import ( + SampleWeighter, + SampleWeightingConfig, + UniformWeighter, + make_sample_weighter, +) + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture +def sample_progress_parquet(tmp_path): + """Create a sample progress parquet file for testing.""" + import pandas as pd + + # Create sample progress data for 2 episodes with 10 frames each + data = { + "index": list(range(20)), + "episode_index": [0] * 10 + [1] * 10, + "frame_index": list(range(10)) * 2, + "progress_sparse": [i / 10.0 for i in range(10)] * 2, + } + df = pd.DataFrame(data) + parquet_path = tmp_path / "sarm_progress.parquet" + df.to_parquet(parquet_path) + return parquet_path + + +# ============================================================================= +# SampleWeightingConfig Tests +# ============================================================================= + + +def test_config_default_values(): + """Test default configuration values.""" + config = SampleWeightingConfig() + assert config.type == "rabc" + assert config.progress_path is None + assert config.head_mode == "sparse" + assert config.kappa == 0.01 + assert config.epsilon == 1e-6 + assert config.extra_params == {} + + +def test_config_custom_values(): + """Test configuration with custom values.""" + config = SampleWeightingConfig( + type="rabc", + progress_path="/path/to/progress.parquet", + head_mode="dense", + kappa=0.05, + epsilon=1e-8, + extra_params={"fallback_weight": 0.5}, + ) + assert config.type == "rabc" + assert config.progress_path == "/path/to/progress.parquet" + assert config.head_mode == "dense" + assert config.kappa == 0.05 + 
assert config.epsilon == 1e-8 + assert config.extra_params == {"fallback_weight": 0.5} + + +def test_config_uniform_type(): + """Test configuration for uniform weighting.""" + config = SampleWeightingConfig(type="uniform") + assert config.type == "uniform" + + +# ============================================================================= +# UniformWeighter Tests +# ============================================================================= + + +def test_uniform_weighter_inherits_from_sample_weighter(): + """Test that UniformWeighter is a SampleWeighter.""" + weighter = UniformWeighter(device=torch.device("cpu")) + assert isinstance(weighter, SampleWeighter) + + +def test_uniform_weighter_compute_batch_weights_with_action_key(): + """Test weight computation with 'action' key in batch.""" + weighter = UniformWeighter(device=torch.device("cpu")) + batch = {"action": torch.randn(8, 10)} + + weights, stats = weighter.compute_batch_weights(batch) + + assert weights.shape == (8,) + assert torch.allclose(weights, torch.ones(8)) + assert stats["mean_weight"] == 1.0 + assert stats["type"] == "uniform" + + +def test_uniform_weighter_compute_batch_weights_with_index_key(): + """Test weight computation with 'index' key in batch.""" + weighter = UniformWeighter(device=torch.device("cpu")) + batch = {"index": torch.arange(16)} + + weights, stats = weighter.compute_batch_weights(batch) + + assert weights.shape == (16,) + assert torch.allclose(weights, torch.ones(16)) + + +def test_uniform_weighter_compute_batch_weights_no_tensor_keys(): + """Test weight computation with no tensor keys (fallback to size 1).""" + weighter = UniformWeighter(device=torch.device("cpu")) + batch = {"other_key": "some_value"} + + weights, stats = weighter.compute_batch_weights(batch) + + assert weights.shape == (1,) + assert torch.allclose(weights, torch.ones(1)) + + +def test_uniform_weighter_compute_batch_weights_empty_batch_raises(): + """Test that empty batch raises ValueError.""" + weighter = UniformWeighter(device=torch.device("cpu")) + batch = {} + + with pytest.raises(ValueError, match="empty batch"): + weighter.compute_batch_weights(batch) + + +def test_uniform_weighter_compute_batch_weights_scans_all_keys(): + """Test that batch size is determined by scanning all tensor values.""" + weighter = UniformWeighter(device=torch.device("cpu")) + # Batch with non-standard key containing a tensor + batch = {"custom_tensor": torch.randn(7, 3)} + + weights, stats = weighter.compute_batch_weights(batch) + + assert weights.shape == (7,) + assert torch.allclose(weights, torch.ones(7)) + + +def test_uniform_weighter_compute_batch_weights_on_cuda(): + """Test that weights are placed on the correct device.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + weighter = UniformWeighter(device=torch.device("cuda")) + batch = {"action": torch.randn(4, 10)} + + weights, _ = weighter.compute_batch_weights(batch) + + assert weights.device.type == "cuda" + + +def test_uniform_weighter_get_stats(): + """Test get_stats returns expected structure.""" + weighter = UniformWeighter(device=torch.device("cpu")) + stats = weighter.get_stats() + + assert stats == {"type": "uniform"} + + +# ============================================================================= +# make_sample_weighter Factory Tests +# ============================================================================= + + +def test_factory_returns_none_for_none_config(): + """Test that None config returns None weighter.""" + policy = Mock() + device = 
torch.device("cpu") + + result = make_sample_weighter(None, policy, device) + + assert result is None + + +def test_factory_creates_uniform_weighter(): + """Test creation of UniformWeighter.""" + config = SampleWeightingConfig(type="uniform") + policy = Mock() + device = torch.device("cpu") + + weighter = make_sample_weighter(config, policy, device) + + assert isinstance(weighter, UniformWeighter) + assert isinstance(weighter, SampleWeighter) + + +def test_factory_raises_for_unknown_type(): + """Test that unknown type raises ValueError.""" + config = SampleWeightingConfig(type="unknown_type") + policy = Mock() + device = torch.device("cpu") + + with pytest.raises(ValueError, match="Unknown sample weighting type"): + make_sample_weighter(config, policy, device) + + +def test_factory_rabc_requires_chunk_size(): + """Test that RABC weighter requires chunk_size in policy config.""" + config = SampleWeightingConfig( + type="rabc", + progress_path="/path/to/progress.parquet", + ) + policy = Mock() + policy.config = Mock() + policy.config.chunk_size = None # No chunk_size + device = torch.device("cpu") + + with pytest.raises(ValueError, match="chunk_size"): + make_sample_weighter(config, policy, device) + + +def test_factory_rabc_requires_progress_path_or_dataset_info(): + """Test that RABC weighter requires progress_path or dataset info for auto-detection.""" + config = SampleWeightingConfig( + type="rabc", + progress_path=None, # No progress path + ) + policy = Mock() + policy.config = Mock() + policy.config.chunk_size = 50 + device = torch.device("cpu") + + # Should fail when no progress_path AND no dataset info + with pytest.raises(ValueError, match="progress_path"): + make_sample_weighter(config, policy, device) + + +def test_factory_rabc_auto_detects_from_dataset_root(sample_progress_parquet): + """Test that RABC weighter auto-detects progress_path from dataset_root.""" + config = SampleWeightingConfig( + type="rabc", + progress_path=None, # Not provided, should auto-detect + ) + policy = Mock() + policy.config = Mock() + policy.config.chunk_size = 5 + device = torch.device("cpu") + + # The parquet file is at sample_progress_parquet, get its parent directory + dataset_root = sample_progress_parquet.parent + weighter = make_sample_weighter( + config, + policy, + device, + dataset_root=str(dataset_root), + ) + + assert weighter is not None + from lerobot.rewards.sarm.rabc import RABCWeights + + assert isinstance(weighter, RABCWeights) + + +def test_factory_rabc_auto_detects_from_repo_id(): + """Test that RABC weighter constructs HF path from repo_id.""" + config = SampleWeightingConfig( + type="rabc", + progress_path=None, # Not provided, should auto-detect + ) + policy = Mock() + policy.config = Mock() + policy.config.chunk_size = 50 + device = torch.device("cpu") + + # This will construct the path but fail when trying to load (file doesn't exist) + # We just verify it doesn't raise the "progress_path required" error + with pytest.raises(Exception) as exc_info: + make_sample_weighter( + config, + policy, + device, + dataset_repo_id="test-user/test-dataset", + ) + # Should NOT be the "progress_path required" error - it should try to load the file + assert ( + "progress_path" not in str(exc_info.value).lower() or "auto-detection" in str(exc_info.value).lower() + ) + + +# ============================================================================= +# Integration Tests with RABCWeights +# ============================================================================= + + +def 
test_rabc_weights_is_sample_weighter(sample_progress_parquet): + """Test that RABCWeights inherits from SampleWeighter.""" + from lerobot.rewards.sarm.rabc import RABCWeights + + weighter = RABCWeights( + progress_path=sample_progress_parquet, + chunk_size=5, + head_mode="sparse", + ) + assert isinstance(weighter, SampleWeighter) + + +def test_rabc_compute_batch_weights(sample_progress_parquet): + """Test RABCWeights.compute_batch_weights returns correct structure.""" + from lerobot.rewards.sarm.rabc import RABCWeights + + weighter = RABCWeights( + progress_path=sample_progress_parquet, + chunk_size=5, + head_mode="sparse", + device=torch.device("cpu"), + ) + + batch = {"index": torch.tensor([0, 1, 2, 3])} + weights, stats = weighter.compute_batch_weights(batch) + + assert isinstance(weights, torch.Tensor) + assert weights.shape == (4,) + assert isinstance(stats, dict) + assert "mean_weight" in stats + + +def test_rabc_get_stats(sample_progress_parquet): + """Test RABCWeights.get_stats returns expected structure.""" + from lerobot.rewards.sarm.rabc import RABCWeights + + weighter = RABCWeights( + progress_path=sample_progress_parquet, + chunk_size=5, + head_mode="sparse", + ) + + stats = weighter.get_stats() + + assert stats["type"] == "rabc" + assert "num_frames" in stats + assert "chunk_size" in stats + assert stats["chunk_size"] == 5 + assert "head_mode" in stats + assert stats["head_mode"] == "sparse" + assert "delta_mean" in stats + assert "delta_std" in stats + + +def test_factory_creates_rabc_weighter(sample_progress_parquet): + """Test factory creates RABCWeights with valid config.""" + from lerobot.rewards.sarm.rabc import RABCWeights + + config = SampleWeightingConfig( + type="rabc", + progress_path=str(sample_progress_parquet), + head_mode="sparse", + kappa=0.01, + ) + policy = Mock() + policy.config = Mock() + policy.config.chunk_size = 5 + device = torch.device("cpu") + + weighter = make_sample_weighter(config, policy, device) + + assert isinstance(weighter, RABCWeights) + assert isinstance(weighter, SampleWeighter) + + +def test_rabc_weights_normalization(sample_progress_parquet): + """Test that RABCWeights normalizes weights to sum to batch_size.""" + from lerobot.rewards.sarm.rabc import RABCWeights + + weighter = RABCWeights( + progress_path=sample_progress_parquet, + chunk_size=5, + head_mode="sparse", + device=torch.device("cpu"), + ) + + batch = {"index": torch.tensor([0, 1, 2, 3])} + weights, _ = weighter.compute_batch_weights(batch) + + # Weights should be normalized to sum approximately to batch_size + batch_size = 4 + assert abs(weights.sum().item() - batch_size) < 0.1