Mirror of https://github.com/huggingface/lerobot.git, synced 2026-05-11 14:49:43 +00:00
Reward models refactor (#3142)
* feat(rewards): add RewardModelConfig and PreTrainedRewardModel base classes
* refactor(rewards): migrate Classifier from policies/sac/reward_model/ to rewards/classifier/
* refactor(rewards): migrate SARM from policies/sarm/ to rewards/sarm/
* refactor(rewards): add rewards/factory.py and remove reward model code from policies/factory.py
* refactor(rewards): update imports and delete old reward model locations
* test(rewards): add reward model tests and update existing test imports
* fix(rewards): restore full Classifier and SARM implementations
* test(rewards): restore missing CUDA and mixed precision classifier processor tests
* refactor(lerobot_train.py): remove RA-BC-specific configuration and replace it with a generic sample-weighting class in lerobot_train
* refactor(lerobot_train.py): add missing sampling weight script
* linter + missing files
* add testing for sample weighter
* revert some useless changes, improve typing
* update docs
* add automatic detection of the progress path
* remove type exp
* improve comment
* fix: move rabc.py to rewards/sarm/ and update import paths
* refactor(imports): update reward model imports to new module structure
* refactor(imports): update reward model imports to reflect new module structure
* refactor(imports): conditionally import pandas based on availability
* feat(configs): add reward_model field to TrainPipelineConfig and Hub fields to RewardModelConfig
* refactor(policies): remove reward model branches from policy factory and __init__
* refactor(rewards): expand __init__ facade and fix SARMConfig __post_init__ crash
* feat(train): route reward model training through rewards/factory instead of policies/factory
* refactor(train): streamline reward model training logic
* fix(rewards): ensure FileNotFoundError is raised for missing config_file
* refactor(train): update __get_path_fields__ to include reward_model for config loading
* refactor(classifier): remove redundant input normalization in predict_reward method
* fix(train): raise ValueError for non-trainable reward models in train function
* refactor(pretrained_rm): add model card template
* refactor(tests): reward models
* refactor(sarm): update reset method and remove unused action prediction methods
* refactor(wandb): differentiate tags for reward model and policy training in cfg_to_group function
* fix(train): raise ValueError for PEFT usage in reward model training
* refactor(rewards): enhance RewardModelConfig with device handling and delta indices properties

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
+29 -28
@@ -46,7 +46,7 @@ This ensures identical task states map to consistent progress values, even acros
 
 ## Inputs and Targets (What the new code expects)
 
-SARM is trained through its processor (`src/lerobot/policies/sarm/processor_sarm.py`), which:
+SARM is trained through its processor (`src/lerobot/rewards/sarm/processor_sarm.py`), which:
 
 - **Encodes** images and task text with CLIP (ViT-B/32) into `video_features` and `text_features`
 - **Pads/truncates** robot state into `state_features` (up to `max_state_dim`)
@@ -347,7 +347,7 @@ Use `compute_rabc_weights.py` with `--visualize-only` to visualize model predict
 <hfoption id="single_stage">
 
 ```bash
-python src/lerobot/policies/sarm/compute_rabc_weights.py \
+python -m lerobot.rewards.sarm.compute_rabc_weights \
     --dataset-repo-id your-username/your-dataset \
     --reward-model-path your-username/sarm-model \
     --visualize-only \
@@ -360,7 +360,7 @@ python src/lerobot/policies/sarm/compute_rabc_weights.py \
 <hfoption id="dense_only">
 
 ```bash
-python src/lerobot/policies/sarm/compute_rabc_weights.py \
+python -m lerobot.rewards.sarm.compute_rabc_weights \
    --dataset-repo-id your-username/your-dataset \
    --reward-model-path your-username/sarm-model \
    --visualize-only \
@@ -373,7 +373,7 @@ python src/lerobot/policies/sarm/compute_rabc_weights.py \
 <hfoption id="dual">
 
 ```bash
-python src/lerobot/policies/sarm/compute_rabc_weights.py \
+python -m lerobot.rewards.sarm.compute_rabc_weights \
    --dataset-repo-id your-username/your-dataset \
    --reward-model-path your-username/sarm-model \
    --visualize-only \
@@ -429,7 +429,7 @@ The weighting follows **Equations 8-9** from the paper:
 First, run the SARM model on all frames in your dataset to compute progress values:
 
 ```bash
-python src/lerobot/policies/sarm/compute_rabc_weights.py \
+python -m lerobot.rewards.sarm.compute_rabc_weights \
    --dataset-repo-id your-username/your-dataset \
    --reward-model-path your-username/sarm-model \
    --head-mode sparse \
@@ -465,15 +465,15 @@ This script:
 
 ### Step 5b: Train Policy with RA-BC
 
-Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
+Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`) if not explicitly provided. Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
 
 ```bash
 lerobot-train \
     --dataset.repo_id=your-username/your-dataset \
     --policy.type=pi0 \
-    --use_rabc=true \
-    --rabc_head_mode=sparse \
-    --rabc_kappa=0.01 \
+    --sample_weighting.type=rabc \
+    --sample_weighting.head_mode=sparse \
+    --sample_weighting.kappa=0.01 \
     --output_dir=outputs/train/policy_rabc \
     --batch_size=32 \
     --steps=40000
@@ -488,12 +488,13 @@ The training script automatically:
 
 **RA-BC Arguments:**
 
 | Argument | Description | Default |
-| ---------------------- | ---------------------------------------------------------- | ---------------------------------- |
-| `--use_rabc` | Enable RA-BC sample weighting | `false` |
-| `--rabc_progress_path` | Path to progress parquet file (auto-detected from dataset) | `sarm_progress.parquet` in dataset |
-| `--rabc_head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
-| `--rabc_kappa` | Threshold κ for high-quality samples | `0.01` |
+| ---------------------------------- | ------------------------------------------------------ | ----------------------- |
+| `--sample_weighting.type` | Weighting strategy type (`rabc` or `uniform`) | `rabc` |
+| `--sample_weighting.progress_path` | Path to progress parquet file | `sarm_progress.parquet` |
+| `--sample_weighting.head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
+| `--sample_weighting.kappa` | Threshold κ for high-quality samples | `0.01` |
+| `--sample_weighting.epsilon` | Small constant for numerical stability | `1e-6` |
 
 ### Tuning RA-BC Kappa
 
@@ -511,30 +512,30 @@ The `kappa` parameter is the threshold that determines which samples get full we
 
 Monitor these WandB metrics during training:
 
 | Metric | Healthy Range | Problem Indicator |
-| ------------------ | ------------- | ------------------------- |
-| `rabc_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
-| `rabc_delta_mean` | > 0 | Should be positive |
-| `rabc_delta_std` | > 0 | Variance in data quality |
+| ----------------------------- | ------------- | ------------------------- |
+| `sample_weight_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
+| `sample_weighting/delta_mean` | > 0 | Should be positive |
+| `sample_weighting/delta_std` | > 0 | Variance in data quality |
 
-**If `rabc_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
+**If `sample_weight_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
 
 **Setting kappa based on your data:**
 
-The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `rabc_delta_mean` and `rabc_delta_std`:
+The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `sample_weighting/delta_mean` and `sample_weighting/delta_std`:
 
 ```
 # If delta_mean ≈ 0.03 and delta_std ≈ 0.02:
 # Most deltas fall in range [0.01, 0.05]
 
 # Option 1: Set kappa = delta_mean (medium selectivity)
---rabc_kappa=0.03
+--sample_weighting.kappa=0.03
 
 # Option 2: Set kappa = delta_mean + delta_std (high selectivity)
---rabc_kappa=0.05
+--sample_weighting.kappa=0.05
 
 # Option 3: Set kappa = delta_mean + 2*delta_std (very selective)
---rabc_kappa=0.07
+--sample_weighting.kappa=0.07
 ```
 
 **When RA-BC may not help:**
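To make the thresholding concrete, here is a minimal sketch of the weighting rule these docs describe (Equations 8-9): full weight above the kappa threshold, soft weight below it. The exact formula lives in the new sample-weighting module, so treat the function name and details as illustrative:

```python
import torch


def rabc_weights(progress: torch.Tensor, kappa: float = 0.01, epsilon: float = 1e-6) -> torch.Tensor:
    """Illustrative RA-BC weighting over a sequence of SARM progress values."""
    delta = progress[1:] - progress[:-1]  # per-step progress gain predicted by SARM
    soft = delta.clamp(min=0.0) / (kappa + epsilon)  # soft weight for low-quality steps
    return torch.where(delta > kappa, torch.ones_like(delta), soft)
```

Under this rule, samples with `delta > kappa` bypass the soft weighting entirely, which is why a mean weight near 1.0 signals that kappa is set too low.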
@@ -550,8 +551,8 @@ accelerate launch \
   src/lerobot/scripts/lerobot_train.py \
     --dataset.repo_id=your-username/your-dataset \
    --policy.type=pi0 \
-    --use_rabc=true \
-    --rabc_kappa=0.01 \
+    --sample_weighting.type=rabc \
+    --sample_weighting.kappa=0.01 \
    --output_dir=outputs/train/policy_rabc \
    --batch_size=32 \
    --steps=40000
@@ -576,7 +577,7 @@ accelerate launch \
 ### RA-BC
 
 1. **Train SARM first**: RA-BC quality depends entirely on SARM quality
-2. **Monitor `rabc_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
+2. **Monitor `sample_weight_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
 
 ---
 
@@ -69,7 +69,7 @@ class ComputeProgressShards(PipelineStep):
 import torch
 from tqdm import tqdm
 
-from lerobot.policies.sarm.compute_rabc_weights import (
+from lerobot.rewards.sarm.compute_rabc_weights import (
     generate_all_frame_indices,
     interpolate_progress,
     load_sarm_resources,
@@ -10,7 +10,7 @@ from lerobot.datasets import LeRobotDataset
 from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
 from lerobot.policies import SACConfig
 from lerobot.policies.sac.modeling_sac import SACPolicy
-from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
+from lerobot.rewards.classifier.modeling_classifier import Classifier
 from lerobot.rl.buffer import ReplayBuffer
 from lerobot.rl.gym_manipulator import make_robot_env
 from lerobot.robots.so_follower import SO100FollowerConfig
@@ -1,7 +1,7 @@
 import torch
 
 from lerobot.datasets import LeRobotDataset
-from lerobot.policies import RewardClassifierConfig, make_policy, make_pre_post_processors
+from lerobot.rewards import RewardClassifierConfig, make_reward_model, make_reward_pre_post_processors
 
 
 def main():
@@ -22,10 +22,10 @@ def main():
         model_name="microsoft/resnet-18",
     )
 
-    # Make policy, preprocessor, and optimizer
-    policy = make_policy(config, ds_meta=dataset.meta)
-    optimizer = config.get_optimizer_preset().build(policy.parameters())
-    preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats)
+    # Make reward model, preprocessor, and optimizer
+    reward_model = make_reward_model(config, dataset_stats=dataset.meta.stats)
+    optimizer = config.get_optimizer_preset().build(reward_model.parameters())
+    preprocessor, _ = make_reward_pre_post_processors(config, dataset_stats=dataset.meta.stats)
 
     classifier_id = "<user>/reward_classifier_hil_serl_example"
 
@@ -42,7 +42,7 @@ def main():
         batch = preprocessor(batch)
 
         # Forward pass
-        loss, output_dict = policy.forward(batch)
+        loss, output_dict = reward_model.forward(batch)
 
         # Backward pass and optimization
         optimizer.zero_grad()
@@ -58,8 +58,8 @@ def main():
 
     print("Training finished!")
 
-    # You can now save the trained policy.
-    policy.push_to_hub(classifier_id)
+    # You can now save the trained reward model.
+    reward_model.push_to_hub(classifier_id)
 
 
 if __name__ == "__main__":
@@ -41,8 +41,12 @@ def cfg_to_group(
             return tag
         return tag[:max_tag_length]
 
+    if cfg.is_reward_model_training:
+        trainable_tag = f"reward_model:{cfg.reward_model.type}"
+    else:
+        trainable_tag = f"policy:{cfg.policy.type}"
     lst = [
-        f"policy:{cfg.policy.type}",
+        trainable_tag,
         f"seed:{cfg.seed}",
    ]
    if cfg.dataset is not None:
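With this change, WandB runs are grouped by whichever model is actually being trained: a reward model run gets a tag like `reward_model:sarm`, where a policy run keeps the `policy:pi0` style tag (example values).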
@@ -0,0 +1,163 @@
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+import builtins
+import json
+import logging
+import os
+import tempfile
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, TypeVar
+
+import draccus
+from huggingface_hub import hf_hub_download
+from huggingface_hub.constants import CONFIG_NAME
+from huggingface_hub.errors import HfHubHTTPError
+
+from lerobot.configs.types import PolicyFeature
+from lerobot.optim.optimizers import OptimizerConfig
+from lerobot.optim.schedulers import LRSchedulerConfig
+from lerobot.utils.device_utils import auto_select_torch_device, is_torch_device_available
+from lerobot.utils.hub import HubMixin
+
+T = TypeVar("T", bound="RewardModelConfig")
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class RewardModelConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
+    """Base configuration for reward models.
+
+    Args:
+        input_features: A dictionary defining the PolicyFeature of the input data for the reward. The key represents
+            the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
+        output_features: A dictionary defining the PolicyFeature of the output data for the reward. The key represents
+            the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
+    """
+
+    # Reuses PolicyFeature
+    input_features: dict[str, PolicyFeature] = field(default_factory=dict)
+    output_features: dict[str, PolicyFeature] = field(default_factory=dict)
+
+    device: str | None = None
+
+    pretrained_path: str | None = None
+
+    push_to_hub: bool = False
+    repo_id: str | None = None
+
+    # Hub metadata
+    license: str | None = None
+    tags: list[str] | None = None
+    private: bool | None = None
+
+    def __post_init__(self) -> None:
+        if not self.device or not is_torch_device_available(self.device):
+            auto_device = auto_select_torch_device()
+            logger.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
+            self.device = auto_device.type
+
+    @property
+    def type(self) -> str:
+        choice_name = self.get_choice_name(self.__class__)
+        if not isinstance(choice_name, str):
+            raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
+        return choice_name
+
+    @property
+    def observation_delta_indices(self) -> list | None:  # type: ignore[type-arg]
+        return None
+
+    @property
+    def action_delta_indices(self) -> list | None:  # type: ignore[type-arg]
+        return None
+
+    @property
+    def reward_delta_indices(self) -> list | None:  # type: ignore[type-arg]
+        return None
+
+    @abc.abstractmethod
+    def get_optimizer_preset(self) -> OptimizerConfig:
+        raise NotImplementedError
+
+    def get_scheduler_preset(self) -> LRSchedulerConfig | None:
+        return None
+
+    def validate_features(self) -> None:
+        pass
+
+    def _save_pretrained(self, save_directory: Path) -> None:
+        with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
+            draccus.dump(self, f, indent=4)
+
+    @classmethod
+    def from_pretrained(
+        cls: builtins.type[T],
+        pretrained_name_or_path: str | Path,
+        *,
+        force_download: bool = False,
+        resume_download: bool | None = None,
+        proxies: dict[Any, Any] | None = None,
+        token: str | bool | None = None,
+        cache_dir: str | Path | None = None,
+        local_files_only: bool = False,
+        revision: str | None = None,
+        **reward_kwargs: Any,
+    ) -> T:
+        model_id = str(pretrained_name_or_path)
+        config_file: str | None = None
+        if Path(model_id).is_dir():
+            if CONFIG_NAME in os.listdir(model_id):
+                config_file = os.path.join(model_id, CONFIG_NAME)
+            else:
+                logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
+        else:
+            try:
+                config_file = hf_hub_download(
+                    repo_id=model_id,
+                    filename=CONFIG_NAME,
+                    revision=revision,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    token=token,
+                    local_files_only=local_files_only,
+                )
+            except HfHubHTTPError as e:
+                raise FileNotFoundError(
+                    f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
+                ) from e
+
+        if config_file is None:
+            raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}")
+
+        # HACK: Parse the original config to get the config subclass, so that we can
+        # apply cli overrides.
+        with draccus.config_type("json"):
+            orig_config = draccus.parse(cls, config_file, args=[])
+
+        with open(config_file) as f:
+            config = json.load(f)
+
+        config.pop("type", None)
+        with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
+            json.dump(config, f)
+            config_file = f.name
+
+        cli_overrides = reward_kwargs.pop("cli_overrides", [])
+        with draccus.config_type("json"):
+            return draccus.parse(orig_config.__class__, config_file, args=cli_overrides)
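For illustration, registering a new reward model config against this base class could look like the following sketch. `MyRewardConfig` and the choice name `my_reward` are hypothetical; compare the real `RewardClassifierConfig` registration later in this diff:

```python
# Hypothetical example; names are illustrative, not part of this PR.
from dataclasses import dataclass

from lerobot.configs.rewards import RewardModelConfig
from lerobot.optim import AdamWConfig, OptimizerConfig


@RewardModelConfig.register_subclass(name="my_reward")  # hypothetical choice name
@dataclass
class MyRewardConfig(RewardModelConfig):
    learning_rate: float = 1e-4

    def get_optimizer_preset(self) -> OptimizerConfig:
        # The only abstract method; scheduler and feature validation have defaults.
        return AdamWConfig(lr=self.learning_rate)
```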
@@ -26,9 +26,11 @@ from lerobot import envs
 from lerobot.configs import parser
 from lerobot.optim import LRSchedulerConfig, OptimizerConfig
 from lerobot.utils.hub import HubMixin
+from lerobot.utils.sample_weighting import SampleWeightingConfig
 
 from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
 from .policies import PreTrainedConfig
+from .rewards import RewardModelConfig
 
 TRAIN_CONFIG_NAME = "train_config.json"
 
@@ -38,6 +40,7 @@ class TrainPipelineConfig(HubMixin):
     dataset: DatasetConfig
     env: envs.EnvConfig | None = None
     policy: PreTrainedConfig | None = None
+    reward_model: RewardModelConfig | None = None
     # Set `dir` to where you would like to save all of the run outputs. If you run another training session
     # with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
     output_dir: Path | None = None
@@ -72,27 +75,41 @@ class TrainPipelineConfig(HubMixin):
     wandb: WandBConfig = field(default_factory=WandBConfig)
     peft: PeftConfig | None = None
 
-    # RA-BC (Reward-Aligned Behavior Cloning) parameters
-    use_rabc: bool = False  # Enable reward-weighted training
-    rabc_progress_path: str | None = None  # Path to precomputed SARM progress parquet file
-    rabc_kappa: float = 0.01  # Hard threshold for high-quality samples
-    rabc_epsilon: float = 1e-6  # Small constant for numerical stability
-    rabc_head_mode: str | None = "sparse"  # For dual-head models: "sparse" or "dense"
+    # Sample weighting configuration (e.g., for RA-BC training)
+    sample_weighting: SampleWeightingConfig | None = None
 
     # Rename map for the observation to override the image and state keys
     rename_map: dict[str, str] = field(default_factory=dict)
     checkpoint_path: Path | None = field(init=False, default=None)
 
+    @property
+    def is_reward_model_training(self) -> bool:
+        """True when the config targets a reward model rather than a policy."""
+        return self.reward_model is not None
+
+    @property
+    def trainable_config(self) -> PreTrainedConfig | RewardModelConfig:
+        """Return whichever config (policy or reward_model) is active."""
+        if self.is_reward_model_training:
+            return self.reward_model  # type: ignore[return-value]
+        return self.policy  # type: ignore[return-value]
+
     def validate(self) -> None:
         # HACK: We parse again the cli args here to get the pretrained paths if there was some.
         policy_path = parser.get_path_arg("policy")
-        if policy_path:
-            # Only load the policy config
+        reward_model_path = parser.get_path_arg("reward_model")
+
+        if reward_model_path:
+            cli_overrides = parser.get_cli_overrides("reward_model")
+            self.reward_model = RewardModelConfig.from_pretrained(
+                reward_model_path, cli_overrides=cli_overrides
+            )
+            self.reward_model.pretrained_path = str(Path(reward_model_path))
+        elif policy_path:
             cli_overrides = parser.get_cli_overrides("policy")
             self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
             self.policy.pretrained_path = Path(policy_path)
         elif self.resume:
-            # The entire train config is already loaded, we just need to get the checkpoint dir
             config_path = parser.parse_arg("config_path")
             if not config_path:
                 raise ValueError(
@@ -108,18 +125,22 @@ class TrainPipelineConfig(HubMixin):
             policy_dir = Path(config_path).parent
             if self.policy is not None:
                 self.policy.pretrained_path = policy_dir
+            if self.reward_model is not None:
+                self.reward_model.pretrained_path = str(policy_dir)
             self.checkpoint_path = policy_dir.parent
 
-        if self.policy is None:
+        if self.policy is None and self.reward_model is None:
             raise ValueError(
-                "Policy is not configured. Please specify a pretrained policy with `--policy.path`."
+                "Neither policy nor reward_model is configured. "
+                "Please specify one with `--policy.path` or `--reward_model.path`."
             )
 
+        active_cfg = self.trainable_config
         if not self.job_name:
             if self.env is None:
-                self.job_name = f"{self.policy.type}"
+                self.job_name = f"{active_cfg.type}"
             else:
-                self.job_name = f"{self.env.type}_{self.policy.type}"
+                self.job_name = f"{self.env.type}_{active_cfg.type}"
 
         if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
             raise FileExistsError(
@@ -137,26 +158,16 @@ class TrainPipelineConfig(HubMixin):
         if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
             raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
         elif self.use_policy_training_preset and not self.resume:
-            self.optimizer = self.policy.get_optimizer_preset()
-            self.scheduler = self.policy.get_scheduler_preset()
+            self.optimizer = active_cfg.get_optimizer_preset()
+            self.scheduler = active_cfg.get_scheduler_preset()
 
-        if self.policy.push_to_hub and not self.policy.repo_id:
-            raise ValueError(
-                "'policy.repo_id' argument missing. Please specify it to push the model to the hub."
-            )
-
-        if self.use_rabc and not self.rabc_progress_path:
-            # Auto-detect from dataset path
-            repo_id = self.dataset.repo_id
-            if self.dataset.root:
-                self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet")
-            else:
-                self.rabc_progress_path = f"hf://datasets/{repo_id}/sarm_progress.parquet"
+        if hasattr(active_cfg, "push_to_hub") and active_cfg.push_to_hub and not active_cfg.repo_id:
+            raise ValueError("'repo_id' argument missing. Please specify it to push the model to the hub.")
 
     @classmethod
     def __get_path_fields__(cls) -> list[str]:
-        """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
-        return ["policy"]
+        """Keys for draccus pretrained-path loading."""
+        return ["policy", "reward_model"]
 
     def to_dict(self) -> dict[str, Any]:
         return draccus.encode(self)  # type: ignore[no-any-return] # because of the third-party library draccus uses Any as the return type
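A quick sketch of how the new properties route between the two trainable configs; field values here are illustrative:

```python
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.rewards import RewardClassifierConfig

cfg = TrainPipelineConfig(
    dataset=DatasetConfig(repo_id="your-username/your-dataset"),
    reward_model=RewardClassifierConfig(model_name="microsoft/resnet-18"),
)
assert cfg.is_reward_model_training  # reward_model is set, so this run trains a reward model
preset = cfg.trainable_config.get_optimizer_preset()  # resolves to the classifier's preset
```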
@@ -19,6 +19,7 @@ from pprint import pformat
 import torch
 
 from lerobot.configs import PreTrainedConfig
+from lerobot.configs.rewards import RewardModelConfig
 from lerobot.configs.train import TrainPipelineConfig
 from lerobot.transforms import ImageTransforms
 from lerobot.utils.constants import ACTION, IMAGENET_STATS, OBS_PREFIX, REWARD
@@ -30,12 +31,14 @@ from .streaming_dataset import StreamingLeRobotDataset
 
 
 def resolve_delta_timestamps(
-    cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata
+    cfg: PreTrainedConfig | RewardModelConfig, ds_meta: LeRobotDatasetMetadata
 ) -> dict[str, list] | None:
-    """Resolves delta_timestamps by reading from the 'delta_indices' properties of the PreTrainedConfig.
+    """Resolves delta_timestamps by reading from the 'delta_indices' properties of the config.
 
     Args:
-        cfg (PreTrainedConfig): The PreTrainedConfig to read delta_indices from.
+        cfg (PreTrainedConfig | RewardModelConfig): The config to read delta_indices from. Both
+            ``PreTrainedConfig`` and concrete ``RewardModelConfig`` subclasses expose the
+            ``{observation,action,reward}_delta_indices`` properties used below.
         ds_meta (LeRobotDatasetMetadata): The dataset from which features and fps are used to build
             delta_timestamps against.
 
@@ -82,7 +85,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
     ds_meta = LeRobotDatasetMetadata(
         cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
     )
-    delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
+    delta_timestamps = resolve_delta_timestamps(cfg.trainable_config, ds_meta)
     if not cfg.dataset.streaming:
         dataset = LeRobotDataset(
             cfg.dataset.repo_id,
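The conversion `resolve_delta_timestamps` performs is essentially index-to-seconds, sketched below (simplified; the real function also matches indices against dataset features):

```python
def to_delta_timestamps(delta_indices: list[int], fps: float) -> list[float]:
    # Each relative frame index becomes a relative timestamp in seconds.
    return [idx / fps for idx in delta_indices]


assert to_delta_timestamps([-1, 0, 1], fps=30) == [-1 / 30, 0.0, 1 / 30]
```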
@@ -24,8 +24,6 @@ from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
 from .pi05.configuration_pi05 import PI05Config as PI05Config
 from .pretrained import PreTrainedPolicy as PreTrainedPolicy
 from .sac.configuration_sac import SACConfig as SACConfig
-from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
-from .sarm.configuration_sarm import SARMConfig as SARMConfig
 from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
 from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
 from .utils import make_robot_action, prepare_observation_for_inference
@@ -46,9 +44,7 @@ __all__ = [
     "PI0Config",
     "PI0FastConfig",
     "PI05Config",
-    "RewardClassifierConfig",
     "SACConfig",
-    "SARMConfig",
     "SmolVLAConfig",
     "TDMPCConfig",
     "VQBeTConfig",
@@ -52,8 +52,6 @@ from .pi0.configuration_pi0 import PI0Config
 from .pi05.configuration_pi05 import PI05Config
 from .pretrained import PreTrainedPolicy
 from .sac.configuration_sac import SACConfig
-from .sac.reward_model.configuration_classifier import RewardClassifierConfig
-from .sarm.configuration_sarm import SARMConfig
 from .smolvla.configuration_smolvla import SmolVLAConfig
 from .tdmpc.configuration_tdmpc import TDMPCConfig
 from .utils import validate_visual_features_consistency
@@ -89,7 +87,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
 
     Args:
         name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
-            "multi_task_dit", "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
+            "multi_task_dit", "vqbet", "pi0", "pi05", "sac", "smolvla", "wall_x".
     Returns:
         The policy class corresponding to the given name.
 
@@ -132,18 +130,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
         from .sac.modeling_sac import SACPolicy
 
         return SACPolicy
-    elif name == "reward_classifier":
-        from .sac.reward_model.modeling_classifier import Classifier
-
-        return Classifier
     elif name == "smolvla":
         from .smolvla.modeling_smolvla import SmolVLAPolicy
 
         return SmolVLAPolicy
-    elif name == "sarm":
-        from .sarm.modeling_sarm import SARMRewardModel
-
-        return SARMRewardModel
     elif name == "groot":
         from .groot.modeling_groot import GrootPolicy
 
@@ -173,7 +163,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
     Args:
         policy_type: The type of the policy. Supported types include "tdmpc",
             "multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
-            "smolvla", "reward_classifier", "wall_x".
+            "smolvla", "wall_x".
         **kwargs: Keyword arguments to be passed to the configuration class constructor.
 
     Returns:
@@ -200,8 +190,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
         return SACConfig(**kwargs)
     elif policy_type == "smolvla":
         return SmolVLAConfig(**kwargs)
-    elif policy_type == "reward_classifier":
-        return RewardClassifierConfig(**kwargs)
     elif policy_type == "groot":
         return GrootConfig(**kwargs)
     elif policy_type == "xvla":
@@ -378,14 +366,6 @@ def make_pre_post_processors(
             dataset_stats=kwargs.get("dataset_stats"),
         )
 
-    elif isinstance(policy_cfg, RewardClassifierConfig):
-        from .sac.reward_model.processor_classifier import make_classifier_processor
-
-        processors = make_classifier_processor(
-            config=policy_cfg,
-            dataset_stats=kwargs.get("dataset_stats"),
-        )
-
     elif isinstance(policy_cfg, SmolVLAConfig):
         from .smolvla.processor_smolvla import make_smolvla_pre_post_processors
 
@@ -394,14 +374,6 @@ def make_pre_post_processors(
             dataset_stats=kwargs.get("dataset_stats"),
         )
 
-    elif isinstance(policy_cfg, SARMConfig):
-        from .sarm.processor_sarm import make_sarm_pre_post_processors
-
-        processors = make_sarm_pre_post_processors(
-            config=policy_cfg,
-            dataset_stats=kwargs.get("dataset_stats"),
-            dataset_meta=kwargs.get("dataset_meta"),
-        )
     elif isinstance(policy_cfg, GrootConfig):
         from .groot.processor_groot import make_groot_pre_post_processors
 
@@ -1 +0,0 @@
-../../../../docs/source/policy_sarm_README.md
@@ -1,18 +0,0 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from .configuration_sarm import SARMConfig
-from .modeling_sarm import SARMRewardModel
-
-__all__ = ["SARMConfig", "SARMRewardModel"]
@@ -557,7 +557,7 @@ class RewardClassifierProcessorStep(ProcessorStep):
     def __post_init__(self):
         """Initializes the reward classifier model after the dataclass is created."""
         if self.pretrained_path is not None:
-            from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
+            from lerobot.rewards.classifier.modeling_classifier import Classifier
 
             self.reward_classifier = Classifier.from_pretrained(self.pretrained_path)
             self.reward_classifier.to(self.device)
@@ -0,0 +1,36 @@
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .classifier.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
+from .factory import (
+    get_reward_model_class as get_reward_model_class,
+    make_reward_model as make_reward_model,
+    make_reward_model_config as make_reward_model_config,
+    make_reward_pre_post_processors as make_reward_pre_post_processors,
+)
+from .pretrained import PreTrainedRewardModel as PreTrainedRewardModel
+from .sarm.configuration_sarm import SARMConfig as SARMConfig
+
+__all__ = [
+    # Configuration classes
+    "RewardClassifierConfig",
+    "SARMConfig",
+    # Base class
+    "PreTrainedRewardModel",
+    # Factory functions
+    "get_reward_model_class",
+    "make_reward_model",
+    "make_reward_model_config",
+    "make_reward_pre_post_processors",
+]
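Taken together, the facade lets downstream code train a reward classifier without touching the policies package. A condensed sketch, based on the example script updated earlier in this diff:

```python
from lerobot.datasets import LeRobotDataset
from lerobot.rewards import RewardClassifierConfig, make_reward_model, make_reward_pre_post_processors

dataset = LeRobotDataset("your-username/your-dataset")
config = RewardClassifierConfig(model_name="microsoft/resnet-18")

reward_model = make_reward_model(config, dataset_stats=dataset.meta.stats)
optimizer = config.get_optimizer_preset().build(reward_model.parameters())
preprocessor, _ = make_reward_pre_post_processors(config, dataset_stats=dataset.meta.stats)
```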
+4 -5
@@ -1,5 +1,3 @@
-# !/usr/bin/env python
-
 # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,14 +13,15 @@
 # limitations under the License.
 from dataclasses import dataclass, field
 
-from lerobot.configs import NormalizationMode, PreTrainedConfig
+from lerobot.configs import NormalizationMode
+from lerobot.configs.rewards import RewardModelConfig
 from lerobot.optim import AdamWConfig, LRSchedulerConfig, OptimizerConfig
 from lerobot.utils.constants import OBS_IMAGE
 
 
-@PreTrainedConfig.register_subclass(name="reward_classifier")
+@RewardModelConfig.register_subclass(name="reward_classifier")
 @dataclass
-class RewardClassifierConfig(PreTrainedConfig):
+class RewardClassifierConfig(RewardModelConfig):
     """Configuration for the Reward Classifier model."""
 
     name: str = "reward_classifier"
+13 -35
@@ -1,5 +1,3 @@
-# !/usr/bin/env python
-
 # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,11 +17,10 @@ import logging
 import torch
 from torch import Tensor, nn
 
+from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
+from lerobot.rewards.pretrained import PreTrainedRewardModel
 from lerobot.utils.constants import OBS_IMAGE, REWARD
 
-from ...pretrained import PreTrainedPolicy
-from .configuration_classifier import RewardClassifierConfig
-
 
 class ClassifierOutput:
     """Wrapper for classifier outputs with additional metadata."""
@@ -99,7 +96,7 @@ class SpatialLearnedEmbeddings(nn.Module):
         return output
 
 
-class Classifier(PreTrainedPolicy):
+class Classifier(PreTrainedRewardModel):
     """Image classifier built on top of a pre-trained encoder."""
 
     name = "reward_classifier"
@@ -235,6 +232,16 @@ class Classifier(PreTrainedPolicy):
 
         return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
 
+    def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
+        """Returns 1.0 for success, 0.0 for failure based on image observations."""
+        images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
+        output = self.predict(images)
+
+        if self.config.num_classes == 2:
+            return (output.probabilities > 0.5).float()
+        else:
+            return torch.argmax(output.probabilities, dim=1).float()
+
     def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
         """Standard forward pass for training compatible with train.py."""
         # Extract images and labels
@@ -269,10 +276,6 @@ class Classifier(PreTrainedPolicy):
 
     def predict_reward(self, batch, threshold=0.5):
         """Eval method. Returns predicted reward with the decision threshold as argument."""
-        # Check for both OBS_IMAGE and OBS_IMAGES prefixes
-        batch = self.normalize_inputs(batch)
-        batch = self.normalize_targets(batch)
-
         # Extract images from batch dict
         images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
 
@@ -282,28 +285,3 @@ class Classifier(PreTrainedPolicy):
             return (probs > threshold).float()
         else:
             return torch.argmax(self.predict(images).probabilities, dim=1)
-
-    def get_optim_params(self):
-        """Return optimizer parameters for the policy."""
-        return self.parameters()
-
-    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
-        """
-        This method is required by PreTrainedPolicy but not used for reward classifiers.
-        The reward classifier is not an actor and does not select actions.
-        """
-        raise NotImplementedError("Reward classifiers do not select actions")
-
-    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
-        """
-        This method is required by PreTrainedPolicy but not used for reward classifiers.
-        The reward classifier is not an actor and does not produce action chunks.
-        """
-        raise NotImplementedError("Reward classifiers do not predict action chunks")
-
-    def reset(self):
-        """
-        This method is required by PreTrainedPolicy but not used for reward classifiers.
-        The reward classifier is not an actor and does not select actions.
-        """
-        pass
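A sketch of how the new `compute_reward` entry point might be called at rollout time; the observation key and tensor shape below are assumptions for illustration:

```python
import torch

from lerobot.rewards.classifier.modeling_classifier import Classifier

classifier = Classifier.from_pretrained("<user>/reward_classifier_hil_serl_example")
batch = {"observation.image": torch.rand(1, 3, 224, 224)}  # hypothetical key and shape
reward = classifier.compute_reward(batch)  # 1.0 for predicted success, 0.0 otherwise
```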
+1 -6
@@ -1,5 +1,3 @@
-# !/usr/bin/env python
-
 # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -27,8 +25,7 @@ from lerobot.processor import (
     policy_action_to_transition,
     transition_to_policy_action,
 )
-
-from .configuration_classifier import RewardClassifierConfig
+from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
 
 
 def make_classifier_processor(
@@ -52,8 +49,6 @@ def make_classifier_processor(
     Args:
         config: The configuration object for the RewardClassifier.
         dataset_stats: A dictionary of statistics for normalization.
-        preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
-        postprocessor_kwargs: Additional arguments for the post-processor pipeline.
 
     Returns:
         A tuple containing the configured pre-processor and post-processor pipelines.
src/lerobot/rewards/factory.py (new file)
@@ -0,0 +1,238 @@
#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import logging
from typing import Any

import torch

from lerobot.configs.rewards import RewardModelConfig
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
from lerobot.rewards.pretrained import PreTrainedRewardModel
from lerobot.rewards.sarm.configuration_sarm import SARMConfig


def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
    """
    Retrieves a reward model class by its registered name.

    This function uses dynamic imports to avoid loading all reward model classes into
    memory at once, improving startup time and reducing dependencies.

    Args:
        name: The name of the reward model. Supported names are "reward_classifier",
            "sarm".

    Returns:
        The reward model class corresponding to the given name.

    Raises:
        ValueError: If the reward model name is not recognized.
    """
    if name == "reward_classifier":
        from lerobot.rewards.classifier.modeling_classifier import Classifier

        return Classifier
    elif name == "sarm":
        from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel

        return SARMRewardModel
    else:
        try:
            return _get_reward_model_cls_from_name(name=name)
        except Exception as e:
            raise ValueError(f"Reward model type '{name}' is not available.") from e


def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
    """
    Instantiates a reward model configuration object based on the reward type.

    This factory function simplifies the creation of reward model configuration objects
    by mapping a string identifier to the corresponding config class.

    Args:
        reward_type: The type of the reward model. Supported types include
            "reward_classifier", "sarm".
        **kwargs: Keyword arguments to be passed to the configuration class constructor.

    Returns:
        An instance of a `RewardModelConfig` subclass.

    Raises:
        ValueError: If the `reward_type` is not recognized.
    """
    if reward_type == "reward_classifier":
        return RewardClassifierConfig(**kwargs)
    elif reward_type == "sarm":
        return SARMConfig(**kwargs)
    else:
        try:
            config_cls = RewardModelConfig.get_choice_class(reward_type)
            return config_cls(**kwargs)
        except Exception as e:
            raise ValueError(f"Reward model type '{reward_type}' is not available.") from e


def make_reward_model(cfg: RewardModelConfig, **kwargs) -> PreTrainedRewardModel:
    """
    Instantiate a reward model from its configuration.

    Args:
        cfg: The configuration for the reward model to be created. If
            `cfg.pretrained_path` is set, the model will be loaded with weights
            from that path.
        **kwargs: Additional keyword arguments forwarded to the model constructor
            (e.g., ``dataset_stats``, ``dataset_meta``).

    Returns:
        An instantiated and device-placed reward model.
    """
    reward_cls = get_reward_model_class(cfg.type)

    kwargs["config"] = cfg

    if cfg.pretrained_path:
        kwargs["pretrained_name_or_path"] = cfg.pretrained_path
        reward_model = reward_cls.from_pretrained(**kwargs)
    else:
        reward_model = reward_cls(**kwargs)

    reward_model.to(cfg.device)
    assert isinstance(reward_model, torch.nn.Module)

    return reward_model


def make_reward_pre_post_processors(
    reward_cfg: RewardModelConfig,
    **kwargs,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """
    Create pre- and post-processor pipelines for a given reward model.

    Each reward model type has a dedicated factory function for its processors.

    Args:
        reward_cfg: The configuration of the reward model for which to create processors.
        **kwargs: Additional keyword arguments passed to the processor factory
            (e.g., ``dataset_stats``, ``dataset_meta``).

    Returns:
        A tuple containing the input (pre-processor) and output (post-processor) pipelines.

    Raises:
        ValueError: If a processor factory is not implemented for the given reward
            model configuration type.
    """
    # Create a new processor based on reward model type
    if isinstance(reward_cfg, RewardClassifierConfig):
        from lerobot.rewards.classifier.processor_classifier import make_classifier_processor

        return make_classifier_processor(
            config=reward_cfg,
            dataset_stats=kwargs.get("dataset_stats"),
        )

    elif isinstance(reward_cfg, SARMConfig):
        from lerobot.rewards.sarm.processor_sarm import make_sarm_pre_post_processors

        return make_sarm_pre_post_processors(
            config=reward_cfg,
            dataset_stats=kwargs.get("dataset_stats"),
            dataset_meta=kwargs.get("dataset_meta"),
        )

    else:
        try:
            processors = _make_processors_from_reward_model_config(
                config=reward_cfg,
                dataset_stats=kwargs.get("dataset_stats"),
            )
        except Exception as e:
            raise ValueError(
                f"Processor for reward model type '{reward_cfg.type}' is not implemented."
            ) from e
        return processors


def _get_reward_model_cls_from_name(name: str) -> type[PreTrainedRewardModel]:
    """Get reward model class from its registered name using dynamic imports.

    This is used as a helper function to import reward models from 3rd party lerobot
    plugins.

    Args:
        name: The name of the reward model.

    Returns:
        The reward model class corresponding to the given name.
    """
    if name not in RewardModelConfig.get_known_choices():
        raise ValueError(
            f"Unknown reward model name '{name}'. "
            f"Available reward models: {RewardModelConfig.get_known_choices()}"
        )

    config_cls = RewardModelConfig.get_choice_class(name)
    config_cls_name = config_cls.__name__

    model_name = config_cls_name.removesuffix("Config")
    if model_name == config_cls_name:
        raise ValueError(
            f"The config class name '{config_cls_name}' does not follow the expected naming convention. "
            f"Make sure it ends with 'Config'!"
        )

    cls_name = model_name + "RewardModel"
    module_path = config_cls.__module__.replace("configuration_", "modeling_")

    module = importlib.import_module(module_path)
    reward_cls = getattr(module, cls_name)
    return reward_cls


def _make_processors_from_reward_model_config(
    config: RewardModelConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[Any, Any]:
    """Create pre- and post-processors from a reward model configuration using dynamic imports.

    This is used as a helper function to import processor factories from 3rd party
    lerobot reward model plugins.

    Args:
        config: The reward model configuration object.
        dataset_stats: Dataset statistics for normalization.

    Returns:
        A tuple containing the input (pre-processor) and output (post-processor) pipelines.
    """
    reward_type = config.type
    function_name = f"make_{reward_type}_pre_post_processors"
    module_path = config.__class__.__module__.replace("configuration_", "processor_")
    logging.debug(
        f"Instantiating reward pre/post processors using function '{function_name}' "
        f"from module '{module_path}'"
    )
    module = importlib.import_module(module_path)
    function = getattr(module, function_name)
    return function(config, dataset_stats=dataset_stats)
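Taken together, these factories give a two-step construction path: build a config by its registered name, then build the model and its processors from that config. A minimal usage sketch, assuming the `device` and `pretrained_path` fields that the factory code above reads from `RewardModelConfig` (the Hub repo id is a placeholder):

```python
from lerobot.rewards.factory import (
    make_reward_model,
    make_reward_model_config,
    make_reward_pre_post_processors,
)

# Build the config for a registered reward model type.
cfg = make_reward_model_config("sarm", device="cuda")
# Point it at pretrained weights so make_reward_model calls from_pretrained.
cfg.pretrained_path = "user/sarm_single_uni4"  # placeholder repo id

reward_model = make_reward_model(cfg)  # loaded, device-placed nn.Module
preprocessor, postprocessor = make_reward_pre_post_processors(cfg)
```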
src/lerobot/rewards/pretrained.py (new file)
@@ -0,0 +1,244 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import builtins
import logging
import os
from importlib.resources import files
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, TypeVar

import packaging
import safetensors
from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
from torch import Tensor, nn

from lerobot.configs.rewards import RewardModelConfig
from lerobot.utils.hub import HubMixin

if TYPE_CHECKING:
    from lerobot.configs.train import TrainPipelineConfig

T = TypeVar("T", bound="PreTrainedRewardModel")


class PreTrainedRewardModel(nn.Module, HubMixin, abc.ABC):
    """Base class for reward models."""

    config_class: None
    name: None

    def __init__(self, config: RewardModelConfig, *inputs, **kwargs):
        super().__init__()
        if not isinstance(config, RewardModelConfig):
            raise ValueError(
                f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
                "`RewardModelConfig`. To create a model from a pretrained model use "
                f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )
        self.config = config

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if not getattr(cls, "config_class", None):
            raise TypeError(f"Class {cls.__name__} must define 'config_class'")
        if not getattr(cls, "name", None):
            raise TypeError(f"Class {cls.__name__} must define 'name'")

    def _save_pretrained(self, save_directory: Path) -> None:
        self.config._save_pretrained(save_directory)
        model_to_save = self.module if hasattr(self, "module") else self
        save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))

    @classmethod
    def from_pretrained(
        cls: builtins.type[T],
        pretrained_name_or_path: str | Path,
        *,
        config: RewardModelConfig | None = None,
        force_download: bool = False,
        resume_download: bool | None = None,
        proxies: dict | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        strict: bool = False,
        **kwargs,
    ) -> T:
        """
        The reward model is set in evaluation mode by default using `reward.eval()` (dropout modules are
        deactivated). To train it, you should first set it back in training mode with `reward.train()`.
        """
        if config is None:
            config = RewardModelConfig.from_pretrained(
                pretrained_name_or_path=pretrained_name_or_path,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                token=token,
                cache_dir=cache_dir,
                local_files_only=local_files_only,
                revision=revision,
                **kwargs,
            )
        model_id = str(pretrained_name_or_path)
        instance = cls(config, **kwargs)
        if os.path.isdir(model_id):
            print("Loading weights from local directory")
            model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
            reward = cls._load_as_safetensor(instance, model_file, config.device or "cpu", strict)
        else:
            try:
                model_file = hf_hub_download(
                    repo_id=model_id,
                    filename=SAFETENSORS_SINGLE_FILE,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
                reward = cls._load_as_safetensor(instance, model_file, config.device or "cpu", strict)
            except HfHubHTTPError as e:
                raise FileNotFoundError(
                    f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
                ) from e

        reward.to(config.device)
        reward.eval()
        return reward

    @classmethod
    def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
        # Create base kwargs
        kwargs = {"strict": strict}

        # Add device parameter for newer versions that support it
        if packaging.version.parse(safetensors.__version__) >= packaging.version.parse("0.4.3"):
            kwargs["device"] = map_location

        # Load the model with appropriate kwargs
        missing_keys, unexpected_keys = load_model_as_safetensor(model, model_file, **kwargs)
        if missing_keys:
            logging.warning(f"Missing key(s) when loading model: {missing_keys}")
        if unexpected_keys:
            logging.warning(f"Unexpected key(s) when loading model: {unexpected_keys}")

        # For older versions, manually move to device if needed
        if "device" not in kwargs and map_location != "cpu":
            logging.warning(
                "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
                " This means that the model is loaded on 'cpu' first and then copied to the device."
                " This leads to a slower loading time."
                " Please update safetensors to version 0.4.3 or above for improved performance."
            )
            model.to(map_location)
        return model

    def get_optim_params(self):
        """
        Returns the reward-model-specific parameters dict to be passed on to the optimizer.
        """
        return self.parameters()

    def reset(self) -> None:
        """Reset any internal state."""
        pass

    @abc.abstractmethod
    def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
        """Compute a scalar reward signal for a batch of observations.

        Args:
            batch: Dictionary containing at minimum observation tensors.
                May also contain "action", "next_observation.*", etc.

        Returns:
            Tensor of shape ``(batch_size,)`` with reward values.
        """
        ...

    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Any]]:
        """Training forward pass — override for trainable reward models."""
        raise NotImplementedError(
            f"{self.__class__.__name__} is not trainable. Only use compute_reward() for inference."
        )

    @property
    def is_trainable(self) -> bool:
        """Whether this reward model can be trained via ``lerobot-train``.

        Trainable reward models override :meth:`forward`; zero-shot models
        inherit the base implementation that raises ``NotImplementedError``.
        """
        return type(self).forward is not PreTrainedRewardModel.forward

    def push_model_to_hub(self, cfg: "TrainPipelineConfig"):
        api = HfApi()
        repo_id = api.create_repo(
            repo_id=self.config.repo_id, private=self.config.private, exist_ok=True
        ).repo_id

        # Push the files to the repo in a single commit
        with TemporaryDirectory(ignore_cleanup_errors=True) as tmp:
            saved_path = Path(tmp) / repo_id

            self.save_pretrained(saved_path)  # Calls _save_pretrained and stores model tensors

            card = self.generate_model_card(
                cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags
            )
            card.save(str(saved_path / "README.md"))

            cfg.save_pretrained(saved_path)  # Calls _save_pretrained and stores train config

            commit_info = api.upload_folder(
                repo_id=repo_id,
                repo_type="model",
                folder_path=saved_path,
                commit_message="Upload reward model weights, train config and readme",
                allow_patterns=["*.safetensors", "*.json", "*.yaml", "*.md"],
                ignore_patterns=["*.tmp", "*.log"],
            )

        logging.info(f"Model pushed to {commit_info.repo_url.url}")

    def generate_model_card(
        self, dataset_repo_id: str, model_type: str, license: str | None, tags: list[str] | None
    ) -> ModelCard:
        card_data = ModelCardData(
            license=license or "apache-2.0",
            library_name="lerobot",
            pipeline_tag="robotics",
            tags=list(set(tags or []).union({"robotics", "lerobot", "reward-model", model_type})),
            model_name=model_type,
            datasets=dataset_repo_id,
        )

        template_card = (
            files("lerobot.templates")
            .joinpath("lerobot_rewardmodel_modelcard_template.md")
            .read_text(encoding="utf-8")
        )
        card = ModelCard.from_template(card_data, template_str=template_card)
        card.validate()
        return card
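The `__init_subclass__` hook makes the subclass contract explicit: set `config_class` and `name`, implement `compute_reward`, and optionally override `forward` to become trainable. A toy zero-shot subclass, with invented names, sketching only what the base class checks (a real implementation would pair it with its own registered config):

```python
import torch
from torch import Tensor

from lerobot.rewards.pretrained import PreTrainedRewardModel
from lerobot.rewards.sarm.configuration_sarm import SARMConfig


class ConstantReward(PreTrainedRewardModel):
    """Toy zero-shot reward model that scores every sample with zero."""

    # Both attributes are mandatory; __init_subclass__ raises TypeError otherwise.
    config_class = SARMConfig  # stand-in config type for this sketch
    name = "constant"

    def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
        # One scalar per sample, shape (batch_size,), as the ABC requires.
        batch_size = next(iter(batch.values())).shape[0]
        return torch.zeros(batch_size)


# ConstantReward inherits the base forward(), which raises NotImplementedError,
# so is_trainable is False and lerobot-train will refuse to train it.
```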
src/lerobot/rewards/sarm/compute_rabc_weights.py (+8 -9)
@@ -25,18 +25,18 @@ need ~num_frames/30 queries instead of one per frame (~30x speedup).

 Usage:
     # Full RA-BC computation with visualizations
-    python src/lerobot/policies/sarm/compute_rabc_weights.py \\
+    python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
         --dataset-repo-id lerobot/aloha_sim_insertion_human \\
         --reward-model-path <USER>/sarm_single_uni4

     # Faster computation with stride (compute every 5 frames, interpolate the rest)
-    python src/lerobot/policies/sarm/compute_rabc_weights.py \\
+    python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
         --dataset-repo-id lerobot/aloha_sim_insertion_human \\
         --reward-model-path <USER>/sarm_single_uni4 \\
         --stride 5

     # Visualize predictions only (no RA-BC computation)
-    python src/lerobot/policies/sarm/compute_rabc_weights.py \\
+    python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
         --dataset-repo-id lerobot/aloha_sim_insertion_human \\
         --reward-model-path <USER>/sarm_single_uni4 \\
         --visualize-only \\
@@ -58,10 +58,9 @@ import torch
 from tqdm import tqdm

 from lerobot.datasets import LeRobotDataset
-
-from .modeling_sarm import SARMRewardModel
-from .processor_sarm import make_sarm_pre_post_processors
-from .sarm_utils import normalize_stage_tau
+from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel
+from lerobot.rewards.sarm.processor_sarm import make_sarm_pre_post_processors
+from lerobot.rewards.sarm.sarm_utils import normalize_stage_tau


 def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None:
@@ -713,12 +712,12 @@ def main():
         epilog="""
 Examples:
     # Full RA-BC computation with visualizations
-    python src/lerobot/policies/sarm/compute_rabc_weights.py \\
+    python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
         --dataset-repo-id lerobot/aloha_sim_insertion_human \\
         --reward-model-path <USER>/sarm_single_uni4

     # Visualize predictions only (no RA-BC computation)
-    python src/lerobot/policies/sarm/compute_rabc_weights.py \\
+    python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
         --dataset-repo-id lerobot/aloha_sim_insertion_human \\
         --reward-model-path <USER>/sarm_single_uni4 \\
         --visualize-only \\
src/lerobot/rewards/sarm/configuration_sarm.py (+4 -6)
@@ -1,5 +1,3 @@
-#!/usr/bin/env python
-
 # Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu
 # and The HuggingFace Inc. team. All rights reserved.
 #
@@ -22,14 +20,15 @@ Paper: https://arxiv.org/abs/2509.25358

 from dataclasses import dataclass, field

-from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
+from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature
+from lerobot.configs.rewards import RewardModelConfig
 from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
 from lerobot.utils.constants import OBS_IMAGES, OBS_STATE


-@PreTrainedConfig.register_subclass("sarm")
+@RewardModelConfig.register_subclass("sarm")
 @dataclass
-class SARMConfig(PreTrainedConfig):
+class SARMConfig(RewardModelConfig):
     """Configuration class for SARM (Stage-Aware Reward Modeling).

     Supports three annotation modes:
@@ -110,7 +109,6 @@ class SARMConfig(PreTrainedConfig):

     def __post_init__(self):
         super().__post_init__()
-
         if self.annotation_mode not in ["single_stage", "dense_only", "dual"]:
             raise ValueError(
                 f"annotation_mode must be 'single_stage', 'dense_only', or 'dual', got {self.annotation_mode}"
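The `register_subclass` call above is also what the dynamic lookup in `rewards/factory.py` keys on: given a registered name, `_get_reward_model_cls_from_name` derives the model class from the config class name (`FooConfig` becomes `FooRewardModel`) and the module path (`configuration_foo` becomes `modeling_foo`). A hypothetical third-party registration following that convention, with all names invented; a real config would also implement whatever abstract members `RewardModelConfig` declares:

```python
# my_plugin/configuration_myreward.py
from dataclasses import dataclass

from lerobot.configs.rewards import RewardModelConfig


@RewardModelConfig.register_subclass("myreward")
@dataclass
class MyrewardConfig(RewardModelConfig):
    hidden_dim: int = 256  # hypothetical field


# The factory will then import my_plugin.modeling_myreward and expect a class
# named MyrewardRewardModel ("Myreward" + "RewardModel") to live there.
```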
src/lerobot/rewards/sarm/modeling_sarm.py (+23 -17)
@@ -1,5 +1,3 @@
-#!/usr/bin/env python
-
 # Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu
 # and The HuggingFace Inc. team. All rights reserved.
 #
@@ -34,14 +32,13 @@ import torch.nn as nn
 import torch.nn.functional as F  # noqa: N812
 from torch import Tensor

-from lerobot.utils.constants import OBS_STR
-
-from ..pretrained import PreTrainedPolicy
-from .configuration_sarm import SARMConfig
-from .sarm_utils import (
+from lerobot.rewards.pretrained import PreTrainedRewardModel
+from lerobot.rewards.sarm.configuration_sarm import SARMConfig
+from lerobot.rewards.sarm.sarm_utils import (
     normalize_stage_tau,
     pad_state_to_max_dim,
 )
+from lerobot.utils.constants import OBS_STR


 class StageTransformer(nn.Module):
@@ -353,7 +350,7 @@ def gen_stage_emb(num_classes: int, targets: torch.Tensor) -> torch.Tensor:
     return stage_onehot


-class SARMRewardModel(PreTrainedPolicy):
+class SARMRewardModel(PreTrainedRewardModel):
     """
     SARM Reward Model for stage-aware task completion rewards.
@@ -471,6 +468,23 @@ class SARMRewardModel(PreTrainedPolicy):
         self.subtask_model.to(device)
         return self

+    def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
+        """Compute dense progress reward in [0, 1] from batch.
+
+        Expects batch to contain:
+        - "observation_features" or video embeddings: (B, T, 512)
+        - "language_embedding" or text embeddings: (B, 512)
+        - optionally "observation.state": (B, T, state_dim)
+        """
+        text_emb = batch.get("language_embedding", batch.get("text_features"))
+        video_emb = batch.get("observation_features", batch.get("video_features"))
+        state = batch.get("observation.state", batch.get("state_features"))
+
+        rewards = self.calculate_rewards(text_emb, video_emb, state)
+        if isinstance(rewards, np.ndarray):
+            rewards = torch.from_numpy(rewards).float()
+        return rewards
+
     @torch.no_grad()
     def calculate_rewards(
         self,
@@ -631,17 +645,9 @@ class SARMRewardModel(PreTrainedPolicy):
         return self.parameters()

     def reset(self):
-        """Required by PreTrainedPolicy but not used for reward models."""
+        """SARM has no episode-level state to reset."""
         pass

-    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
-        """Required by PreTrainedPolicy but not used for reward models."""
-        raise NotImplementedError("SARM model does not predict action chunks")
-
-    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
-        """Required by PreTrainedPolicy but not used for SARM."""
-        raise NotImplementedError("SARM model does not select actions")
-
     def _train_step(
         self,
         img_emb: torch.Tensor,  # (B, N, T, D)
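The new `compute_reward` accepts either the processor's feature keys or their aliases, so a preprocessed batch can be scored directly. A minimal call sketch with dummy CLIP-sized embeddings (the repo id is a placeholder; shapes follow the docstring, batch of 2 over 10 timesteps):

```python
import torch

from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel

reward_model = SARMRewardModel.from_pretrained("user/sarm_single_uni4")  # placeholder

# CLIP ViT-B/32 embeddings are 512-dimensional for both frames and task text.
batch = {
    "video_features": torch.randn(2, 10, 512),  # (B, T, 512)
    "text_features": torch.randn(2, 512),       # (B, 512)
    "state_features": torch.randn(2, 10, 32),   # optional (B, T, state_dim)
}

rewards = reward_model.compute_reward(batch)  # dense progress values in [0, 1]
```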
src/lerobot/rewards/sarm/processor_sarm.py (+4 -7)
@@ -1,5 +1,3 @@
-#!/usr/bin/env python
-
 # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -60,16 +58,15 @@ from lerobot.processor import (
     policy_action_to_transition,
     transition_to_policy_action,
 )
-from lerobot.types import EnvTransition, PolicyAction, TransitionKey
-from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
-
-from .configuration_sarm import SARMConfig
-from .sarm_utils import (
+from lerobot.rewards.sarm.configuration_sarm import SARMConfig
+from lerobot.rewards.sarm.sarm_utils import (
     apply_rewind_augmentation,
     compute_absolute_indices,
     find_stage_and_tau,
     pad_state_to_max_dim,
 )
+from lerobot.types import EnvTransition, PolicyAction, TransitionKey
+from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME


 class SARMEncodingProcessorStep(ProcessorStep):
src/lerobot/rewards/sarm/rabc.py
@@ -1,5 +1,3 @@
-#!/usr/bin/env python
-
 # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,14 +12,38 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+"""
+RA-BC (Reward-Aligned Behavior Cloning) sample weighting implementation.
+
+This module implements the SampleWeighter protocol for RA-BC training,
+which weights training samples based on their task progress as measured
+by the SARM reward model.
+
+The weights are computed based on progress deltas:
+    delta = progress[t + chunk_size] - progress[t]
+
+High-quality samples (positive progress) get higher weights, while
+samples with negative progress (going backwards) get zero weight.
+
+See: https://arxiv.org/abs/2509.25358 for the SARM paper.
+"""
+
 import logging
 from pathlib import Path
+from typing import TYPE_CHECKING

 import numpy as np
-import pandas as pd
 import torch
 from huggingface_hub import hf_hub_download

+from lerobot.utils.import_utils import _pandas_available
+from lerobot.utils.sample_weighting import SampleWeighter
+
+if TYPE_CHECKING or _pandas_available:
+    import pandas as pd
+else:
+    pd = None  # type: ignore[assignment]
+

 def resolve_hf_path(path: str | Path) -> Path:
     """Resolve a path that may be a HuggingFace URL (hf://datasets/...) to a local path."""
@@ -34,23 +56,27 @@ def resolve_hf_path(path: str | Path) -> Path:
     return Path(path)


-class RABCWeights:
+class RABCWeights(SampleWeighter):
     """
     Load precomputed SARM progress values and compute RA-BC weights during training.

+    This class implements the SampleWeighter ABC for use with the generic
+    sample weighting infrastructure in lerobot.
+
     Progress values are loaded from a parquet file (generated by compute_rabc_weights.py).
     During training, computes:
     - progress_delta = progress[t + chunk_size] - progress[t]
     - rabc_weight based on the delta (paper Eq. 8-9)

     Args:
-        progress_path: Path to parquet file with precomputed progress values
-        chunk_size: Number of frames ahead for computing progress delta
-        head_mode: Which SARM head to use ("sparse" or "dense")
-        kappa: Hard threshold for high-quality samples (default: 0.01)
-        epsilon: Small constant for numerical stability (default: 1e-6)
-        fallback_weight: Weight to use for frames without valid delta (default: 1.0)
-        device: Device to return tensors on
+        progress_path: Path to parquet file with precomputed progress values.
+            Supports HuggingFace URLs (hf://datasets/...).
+        chunk_size: Number of frames ahead for computing progress delta.
+        head_mode: Which SARM head to use ("sparse" or "dense").
+        kappa: Hard threshold for high-quality samples (default: 0.01).
+        epsilon: Small constant for numerical stability (default: 1e-6).
+        fallback_weight: Weight to use for frames without valid delta (default: 1.0).
+        device: Device to return tensors on.
     """

     def __init__(
@@ -61,7 +87,7 @@ class RABCWeights:
         kappa: float = 0.01,
         epsilon: float = 1e-6,
         fallback_weight: float = 1.0,
-        device: torch.device = None,
+        device: torch.device | None = None,
     ):
         self.progress_path = resolve_hf_path(progress_path)
         self.chunk_size = chunk_size
@@ -87,8 +113,8 @@ class RABCWeights:

         logging.info(f"Using progress column: {self.progress_column}")

-        self.progress_lookup = {}
-        self.episode_lookup = {}
+        self.progress_lookup: dict[int, float] = {}
+        self.episode_lookup: dict[int, int] = {}

         for _, row in self.df.iterrows():
             global_idx = int(row["index"])
@@ -100,7 +126,7 @@ class RABCWeights:
             self.episode_lookup[global_idx] = episode_idx

         # Build episode boundaries for delta computation
-        self.episode_boundaries = {}
+        self.episode_boundaries: dict[int, dict[str, int]] = {}
         for episode_idx in self.df["episode_index"].unique():
             ep_df = self.df[self.df["episode_index"] == episode_idx]
             self.episode_boundaries[int(episode_idx)] = {
@@ -114,7 +140,7 @@ class RABCWeights:
         # Compute global statistics for weight computation
         self._compute_global_stats()

-    def _compute_global_stats(self):
+    def _compute_global_stats(self) -> None:
         """Compute global mean and std of progress deltas for weight calculation."""
         all_deltas = []
@@ -138,8 +164,8 @@ class RABCWeights:
                 all_deltas.append(delta)

         if all_deltas:
-            self.delta_mean = max(np.mean(all_deltas), 0.0)
-            self.delta_std = max(np.std(all_deltas), self.epsilon)
+            self.delta_mean = max(float(np.mean(all_deltas)), 0.0)
+            self.delta_std = max(float(np.std(all_deltas)), self.epsilon)
             logging.info(f"Progress delta stats: mean={self.delta_mean:.4f}, std={self.delta_std:.4f}")
         else:
             self.delta_mean = 0.0
@@ -157,18 +183,19 @@ class RABCWeights:
         4. Compute weight using paper Eq. 8-9

         Args:
-            batch: Training batch containing "index" key with global frame indices
+            batch: Training batch containing "index" key with global frame indices.

         Returns:
             Tuple of:
-            - Weights tensor (batch_size,) normalized to sum to batch_size
-            - Stats dict with raw_mean_weight, num_zero_weight, num_full_weight
+            - Weights tensor (batch_size,) normalized to sum to batch_size.
+            - Stats dict with weighting statistics for logging.
         """
         indices = batch.get("index")
         if indices is None:
             logging.warning("RA-BC: Batch missing 'index' key, using uniform weights")
             batch_size = self._get_batch_size(batch)
-            return torch.ones(batch_size, device=self.device), {"raw_mean_weight": 1.0}
+            stats = {"mean_weight": 1.0, "num_zero_weight": 0, "num_full_weight": batch_size}
+            return torch.ones(batch_size, device=self.device), stats

         # Convert to list of ints
         if isinstance(indices, torch.Tensor):
@@ -183,29 +210,29 @@ class RABCWeights:
             delta = self._compute_delta(idx)
             deltas.append(delta)

-        deltas = np.array(deltas, dtype=np.float32)
+        deltas_array = np.array(deltas, dtype=np.float32)

         # Compute weights from deltas
-        weights = self._compute_weights(deltas)
+        weights = self._compute_weights(deltas_array)

         # Compute stats before normalization for logging
         raw_mean_weight = float(np.nanmean(weights))
         num_zero_weight = int(np.sum(weights == 0))
         num_full_weight = int(np.sum(weights == 1.0))
         batch_stats = {
-            "raw_mean_weight": raw_mean_weight,
+            "mean_weight": raw_mean_weight,
             "num_zero_weight": num_zero_weight,
             "num_full_weight": num_full_weight,
         }

-        weights = torch.tensor(weights, device=self.device, dtype=torch.float32)
+        weights_tensor = torch.tensor(weights, device=self.device, dtype=torch.float32)

         # Normalize to sum to batch_size
-        batch_size = len(weights)
-        weight_sum = weights.sum() + self.epsilon
-        weights = weights * batch_size / weight_sum
+        batch_size = len(weights_tensor)
+        weight_sum = weights_tensor.sum() + self.epsilon
+        weights_tensor = weights_tensor * batch_size / weight_sum

-        return weights, batch_stats
+        return weights_tensor, batch_stats

     def _compute_delta(self, global_idx: int) -> float:
         """Compute progress delta for a single frame."""
@@ -241,7 +268,7 @@ class RABCWeights:
         - Final weight: wi = 1{ri > κ} + 1{0 ≤ ri ≤ κ}˜wi

         Returns:
-            Array of weights
+            Array of weights.
         """
         valid_mask = ~np.isnan(deltas)
@@ -273,12 +300,13 @@ class RABCWeights:
             if key in batch:
                 val = batch[key]
                 if isinstance(val, (torch.Tensor, np.ndarray)):
-                    return val.shape[0]
+                    return int(val.shape[0])
         return 1

     def get_stats(self) -> dict:
-        """Get statistics."""
+        """Get global statistics about the RA-BC weighting."""
         return {
+            "type": "rabc",
             "num_frames": len(self.progress_lookup),
             "chunk_size": self.chunk_size,
             "head_mode": self.head_mode,
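To make the weighting rule concrete, here is a standalone numeric sketch of the thresholding described above. The soft weight w̃ for deltas inside [0, κ] is simplified to a linear ramp here; the real `_compute_weights` uses the paper's Eq. 8-9 with the global delta statistics:

```python
import numpy as np

kappa = 0.01  # hard threshold, matching the RABCWeights default

# Progress deltas for a batch of 5 frames: two clear improvements,
# one marginal gain, one stall, one regression.
deltas = np.array([0.05, 0.02, 0.005, 0.0, -0.03], dtype=np.float32)

weights = np.zeros_like(deltas)
weights[deltas > kappa] = 1.0               # 1{r_i > kappa}: full weight
in_band = (deltas >= 0) & (deltas <= kappa)
weights[in_band] = deltas[in_band] / kappa  # simplified stand-in for the soft weight
# Negative deltas keep weight 0: the sample moved the task backwards.

# Normalize to sum to batch_size, as compute_batch_weights does before returning.
weights = weights * len(weights) / (weights.sum() + 1e-6)
print(weights)  # approximately [2.0, 2.0, 1.0, 0.0, 0.0]
```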
lerobot_train.py
@@ -1,5 +1,3 @@
-#!/usr/bin/env python
-
 # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -47,6 +47,7 @@ from lerobot.datasets import EpisodeAwareSampler, make_dataset
 from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
 from lerobot.optim.factory import make_optimizer_and_scheduler
 from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
+from lerobot.rewards import make_reward_pre_post_processors
 from lerobot.utils.import_utils import register_third_party_plugins
 from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
 from lerobot.utils.random_utils import set_seed
@@ -70,8 +71,8 @@ def update_policy(
     accelerator: "Accelerator",
     lr_scheduler=None,
     lock=None,
-    rabc_weights_provider=None,
-) -> tuple[MetricsTracker, dict]:
+    sample_weighter=None,
+) -> tuple[MetricsTracker, dict | None]:
     """
     Performs a single training step to update the policy's weights.
@@ -87,7 +88,7 @@
         accelerator: The Accelerator instance for distributed training and mixed precision.
         lr_scheduler: An optional learning rate scheduler.
         lock: An optional lock for thread-safe optimizer updates.
-        rabc_weights_provider: Optional RABCWeights instance for sample weighting.
+        sample_weighter: Optional SampleWeighter instance for per-sample loss weighting.

     Returns:
         A tuple containing:
@@ -97,27 +98,31 @@
     start_time = time.perf_counter()
     policy.train()

-    # Get RA-BC weights if enabled
-    rabc_batch_weights = None
-    rabc_batch_stats = None
-    if rabc_weights_provider is not None:
-        rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch)
+    # Compute sample weights if a weighter is provided
+    sample_weights = None
+    weight_stats = None
+    if sample_weighter is not None:
+        sample_weights, weight_stats = sample_weighter.compute_batch_weights(batch)

     # Let accelerator handle mixed precision
     with accelerator.autocast():
-        # Use per-sample loss when RA-BC is enabled for proper weighting
-        if rabc_batch_weights is not None:
-            # Get per-sample losses
+        if sample_weights is not None:
+            # Use per-sample loss for weighted training
+            # Note: Policies supporting sample weighting must implement forward(batch, reduction="none")
             per_sample_loss, output_dict = policy.forward(batch, reduction="none")

-            # Apply RA-BC weights: L_RA-BC = Σ(w_i * l_i) / (Σw_i + ε)
-            # rabc_batch_weights is already normalized to sum to batch_size
+            # Weighted loss: each sample's contribution is scaled by its weight.
+            # We divide by weight sum (not batch size) so that if some weights are zero,
+            # the remaining samples contribute proportionally more, preserving gradient scale.
+            # Weights are pre-normalized to sum to batch_size for stable training dynamics.
             epsilon = 1e-6
-            loss = (per_sample_loss * rabc_batch_weights).sum() / (rabc_batch_weights.sum() + epsilon)
-            # Log raw mean weight (before normalization) - this is the meaningful metric
-            output_dict["rabc_mean_weight"] = rabc_batch_stats["raw_mean_weight"]
-            output_dict["rabc_num_zero_weight"] = rabc_batch_stats["num_zero_weight"]
-            output_dict["rabc_num_full_weight"] = rabc_batch_stats["num_full_weight"]
+            loss = (per_sample_loss * sample_weights).sum() / (sample_weights.sum() + epsilon)
+
+            # Log weighting statistics
+            if output_dict is None:
+                output_dict = {}
+            for key, value in weight_stats.items():
+                output_dict[f"sample_weight_{key}"] = value
         else:
             loss, output_dict = policy.forward(batch)
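The weighted-loss arithmetic in the hunk above is easy to verify by hand. A self-contained sketch with made-up per-sample losses and weights already normalized to sum to the batch size:

```python
import torch

per_sample_loss = torch.tensor([0.4, 1.0, 0.2, 0.8])
# As compute_batch_weights returns them: the weights sum to batch_size (4 here).
sample_weights = torch.tensor([2.0, 2.0, 0.0, 0.0])

epsilon = 1e-6
loss = (per_sample_loss * sample_weights).sum() / (sample_weights.sum() + epsilon)
print(loss)  # tensor(0.7000): (0.4*2 + 1.0*2) / 4; zero-weight samples drop out
```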
@@ -188,8 +193,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
|
|
||||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||||
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
|
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
|
||||||
# Force the device to be CPU when policy.device is set to CPU.
|
# Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training).
|
||||||
force_cpu = cfg.policy.device == "cpu"
|
force_cpu = cfg.trainable_config.device == "cpu"
|
||||||
accelerator = Accelerator(
|
accelerator = Accelerator(
|
||||||
step_scheduler_with_optimizer=False,
|
step_scheduler_with_optimizer=False,
|
||||||
kwargs_handlers=[ddp_kwargs],
|
kwargs_handlers=[ddp_kwargs],
|
||||||
@@ -245,26 +250,44 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
logging.info("Creating env")
|
logging.info("Creating env")
|
||||||
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||||
|
|
||||||
if is_main_process:
|
if cfg.is_reward_model_training:
|
||||||
logging.info("Creating policy")
|
if is_main_process:
|
||||||
policy = make_policy(
|
logging.info("Creating reward model")
|
||||||
cfg=cfg.policy,
|
from lerobot.rewards import make_reward_model
|
||||||
ds_meta=dataset.meta,
|
|
||||||
rename_map=cfg.rename_map,
|
policy = make_reward_model(
|
||||||
)
|
cfg=cfg.reward_model,
|
||||||
|
dataset_stats=dataset.meta.stats,
|
||||||
|
dataset_meta=dataset.meta,
|
||||||
|
)
|
||||||
|
if not policy.is_trainable:
|
||||||
|
raise ValueError(
|
||||||
|
f"Reward model '{policy.name}' is zero-shot and cannot be trained via lerobot-train. "
|
||||||
|
"Use it directly for inference via compute_reward() (e.g. offline precompute)."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if is_main_process:
|
||||||
|
logging.info("Creating policy")
|
||||||
|
policy = make_policy(
|
||||||
|
cfg=cfg.policy,
|
||||||
|
ds_meta=dataset.meta,
|
||||||
|
rename_map=cfg.rename_map,
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.peft is not None:
|
if cfg.peft is not None:
|
||||||
|
if cfg.is_reward_model_training:
|
||||||
|
raise ValueError("PEFT is only supported for policy training. ")
|
||||||
logging.info("Using PEFT! Wrapping model.")
|
logging.info("Using PEFT! Wrapping model.")
|
||||||
# Convert CLI peft config to dict for overrides
|
|
||||||
peft_cli_overrides = dataclasses.asdict(cfg.peft)
|
peft_cli_overrides = dataclasses.asdict(cfg.peft)
|
||||||
policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides)
|
policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides)
|
||||||
|
|
||||||
# Wait for all processes to finish policy creation before continuing
|
# Wait for all processes to finish model creation before continuing
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
processor_pretrained_path = cfg.policy.pretrained_path
|
active_cfg = cfg.trainable_config
|
||||||
|
processor_pretrained_path = active_cfg.pretrained_path
|
||||||
if (
|
if (
|
||||||
getattr(cfg.policy, "use_relative_actions", False)
|
getattr(active_cfg, "use_relative_actions", False)
|
||||||
and processor_pretrained_path is not None
|
and processor_pretrained_path is not None
|
||||||
and not cfg.resume
|
and not cfg.resume
|
||||||
):
|
):
|
||||||
@@ -274,18 +297,15 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
)
|
)
|
||||||
processor_pretrained_path = None
|
processor_pretrained_path = None
|
||||||
|
|
||||||
# Create processors - only provide dataset_stats if not resuming from saved processors
|
|
||||||
processor_kwargs = {}
|
processor_kwargs = {}
|
||||||
postprocessor_kwargs = {}
|
postprocessor_kwargs = {}
|
||||||
if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
|
if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
|
||||||
# Only provide dataset_stats when not resuming from saved processor state
|
|
||||||
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
||||||
|
|
||||||
# For SARM, always provide dataset_meta for progress normalization
|
if cfg.is_reward_model_training:
|
||||||
if cfg.policy.type == "sarm":
|
|
||||||
processor_kwargs["dataset_meta"] = dataset.meta
|
processor_kwargs["dataset_meta"] = dataset.meta
|
||||||
|
|
||||||
if processor_pretrained_path is not None:
|
if not cfg.is_reward_model_training and processor_pretrained_path is not None:
|
||||||
processor_kwargs["preprocessor_overrides"] = {
|
processor_kwargs["preprocessor_overrides"] = {
|
||||||
"device_processor": {"device": device.type},
|
"device_processor": {"device": device.type},
|
||||||
"normalizer_processor": {
|
"normalizer_processor": {
|
||||||
@@ -305,38 +325,36 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
preprocessor, postprocessor = make_pre_post_processors(
|
if cfg.is_reward_model_training:
|
||||||
policy_cfg=cfg.policy,
|
preprocessor, postprocessor = make_reward_pre_post_processors(
|
||||||
pretrained_path=processor_pretrained_path,
|
cfg.reward_model,
|
||||||
**processor_kwargs,
|
**processor_kwargs,
|
||||||
**postprocessor_kwargs,
|
)
|
||||||
)
|
else:
|
||||||
|
preprocessor, postprocessor = make_pre_post_processors(
|
||||||
|
policy_cfg=cfg.policy,
|
||||||
|
pretrained_path=processor_pretrained_path,
|
||||||
|
**processor_kwargs,
|
||||||
|
**postprocessor_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
logging.info("Creating optimizer and scheduler")
|
logging.info("Creating optimizer and scheduler")
|
||||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||||
|
|
||||||
# Load precomputed SARM progress for RA-BC if enabled
|
# Create sample weighter if configured (e.g., for RA-BC training)
|
||||||
# Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py
|
sample_weighter = None
|
||||||
rabc_weights = None
|
if cfg.sample_weighting is not None:
|
||||||
if cfg.use_rabc:
|
from lerobot.utils.sample_weighting import make_sample_weighter
|
||||||
from lerobot.utils.rabc import RABCWeights
|
|
||||||
|
|
||||||
# Get chunk_size from policy config
|
if is_main_process:
|
||||||
chunk_size = getattr(policy.config, "chunk_size", None)
|
logging.info(f"Creating sample weighter: {cfg.sample_weighting.type}")
|
||||||
if chunk_size is None:
|
sample_weighter = make_sample_weighter(
|
||||||
raise ValueError("Chunk size is not found in policy config")
|
cfg.sample_weighting,
|
||||||
|
policy,
|
||||||
head_mode = getattr(cfg, "rabc_head_mode", "sparse")
|
device,
|
||||||
logging.info(f"Loading SARM progress for RA-BC from {cfg.rabc_progress_path}")
|
dataset_root=cfg.dataset.root,
|
||||||
logging.info(f"Using chunk_size={chunk_size} from policy config, head_mode={head_mode}")
|
dataset_repo_id=cfg.dataset.repo_id,
|
||||||
rabc_weights = RABCWeights(
|
|
||||||
progress_path=cfg.rabc_progress_path,
|
|
||||||
chunk_size=chunk_size,
|
|
||||||
head_mode=head_mode,
|
|
||||||
kappa=getattr(cfg, "rabc_kappa", 0.01),
|
|
||||||
epsilon=getattr(cfg, "rabc_epsilon", 1e-6),
|
|
||||||
device=device,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
step = 0 # number of policy updates (forward + backward + optim)
|
step = 0 # number of policy updates (forward + backward + optim)
|
||||||
@@ -365,13 +383,13 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||||
|
|
||||||
# create dataloader for offline training
|
# create dataloader for offline training
|
||||||
if hasattr(cfg.policy, "drop_n_last_frames"):
|
if hasattr(active_cfg, "drop_n_last_frames"):
|
||||||
shuffle = False
|
shuffle = False
|
||||||
sampler = EpisodeAwareSampler(
|
sampler = EpisodeAwareSampler(
|
||||||
dataset.meta.episodes["dataset_from_index"],
|
dataset.meta.episodes["dataset_from_index"],
|
||||||
dataset.meta.episodes["dataset_to_index"],
|
dataset.meta.episodes["dataset_to_index"],
|
||||||
episode_indices_to_use=dataset.episodes,
|
episode_indices_to_use=dataset.episodes,
|
||||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
drop_n_last_frames=active_cfg.drop_n_last_frames,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -448,7 +466,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
cfg.optimizer.grad_clip_norm,
|
cfg.optimizer.grad_clip_norm,
|
||||||
accelerator=accelerator,
|
accelerator=accelerator,
|
||||||
lr_scheduler=lr_scheduler,
|
lr_scheduler=lr_scheduler,
|
||||||
rabc_weights_provider=rabc_weights,
|
sample_weighter=sample_weighter,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||||
@@ -467,16 +485,10 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
                 wandb_log_dict = train_tracker.to_dict()
                 if output_dict:
                     wandb_log_dict.update(output_dict)
-                # Log RA-BC statistics if enabled
-                if rabc_weights is not None:
-                    rabc_stats = rabc_weights.get_stats()
-                    wandb_log_dict.update(
-                        {
-                            "rabc_delta_mean": rabc_stats["delta_mean"],
-                            "rabc_delta_std": rabc_stats["delta_std"],
-                            "rabc_num_frames": rabc_stats["num_frames"],
-                        }
-                    )
+                # Log sample weighting statistics if enabled
+                if sample_weighter is not None:
+                    weighter_stats = sample_weighter.get_stats()
+                    wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
                 wandb_logger.log_dict(wandb_log_dict, step)
                 train_tracker.reset_averages()

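With a `UniformWeighter` configured, for example, the comprehension above adds `{"sample_weighting/type": "uniform"}` to `wandb_log_dict`; whatever the strategy, its statistics land under a common `sample_weighting/` prefix instead of strategy-specific top-level keys.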
@@ -558,14 +570,15 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
     if is_main_process:
         logging.info("End of training")

-        if cfg.policy.push_to_hub:
-            unwrapped_policy = accelerator.unwrap_model(policy)
-            if cfg.policy.use_peft:
-                unwrapped_policy.push_model_to_hub(cfg, peft_model=unwrapped_policy)
+        if getattr(active_cfg, "push_to_hub", False):
+            unwrapped_model = accelerator.unwrap_model(policy)
+            # PEFT only applies when training a policy — reward models use the plain path.
+            if not cfg.is_reward_model_training and cfg.policy.use_peft:
+                unwrapped_model.push_model_to_hub(cfg, peft_model=unwrapped_model)
             else:
-                unwrapped_policy.push_model_to_hub(cfg)
-            preprocessor.push_to_hub(cfg.policy.repo_id)
-            postprocessor.push_to_hub(cfg.policy.repo_id)
+                unwrapped_model.push_model_to_hub(cfg)
+            preprocessor.push_to_hub(active_cfg.repo_id)
+            postprocessor.push_to_hub(active_cfg.repo_id)

     # Properly clean up the distributed process group
     accelerator.wait_for_everyone()
@@ -0,0 +1,55 @@
+---
+# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
+# Doc / guide: https://huggingface.co/docs/hub/model-cards
+# prettier-ignore
+{{card_data}}
+---
+
+# Reward Model Card for {{ model_name | default("Reward Model ID", true) }}
+
+<!-- Provide a quick summary of what the reward model is/does. -->
+
+{% if model_name == "reward_classifier" %}
+A reward classifier is a lightweight neural network that scores observations or trajectories for task success, providing a learned reward signal or offline evaluation when explicit rewards are unavailable.
+{% elif model_name == "sarm" %}
+A Success-Aware Reward Model (SARM) predicts a dense reward signal from observations, typically used downstream for reinforcement learning or human-in-the-loop fine-tuning when task success is not directly observable.
+{% else %}
+_Reward model type not recognized — please update this template._
+{% endif %}
+
+This reward model has been trained and pushed to the Hub using [LeRobot](https://github.com/huggingface/lerobot).
+See the full documentation at [LeRobot Docs](https://huggingface.co/docs/lerobot/index).
+
+---
+
+## How to Get Started with the Reward Model
+
+### Train from scratch
+
+```bash
+lerobot-train \
+  --dataset.repo_id=${HF_USER}/<dataset> \
+  --reward_model.type={{ model_name | default("reward_classifier", true) }} \
+  --output_dir=outputs/train/<desired_reward_model_repo_id> \
+  --job_name=lerobot_reward_training \
+  --reward_model.device=cuda \
+  --reward_model.repo_id=${HF_USER}/<desired_reward_model_repo_id> \
+  --wandb.enable=true
+```
+
+_Writes checkpoints to `outputs/train/<desired_reward_model_repo_id>/checkpoints/`._
+
+### Load the reward model in Python
+
+```python
+from lerobot.rewards import make_reward_model
+
+reward_model = make_reward_model(pretrained_path="<hf_user>/<reward_model_repo_id>")
+reward = reward_model.compute_reward(batch)
+```
+
+---
+
+## Model Details
+
+- **License:** {{ license | default("\[More Information Needed]", true) }}
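As a quick offline check, the Jinja template above can be rendered with `huggingface_hub`'s `ModelCard.from_template`. A minimal sketch; the template path is hypothetical, and `model_name` stands in for the value `generate_model_card` normally supplies:

```python
from huggingface_hub import ModelCard, ModelCardData

card_data = ModelCardData(license="apache-2.0", tags=["lerobot", "reward-model", "sarm"])
card = ModelCard.from_template(
    card_data,  # substituted for {{card_data}} in the YAML header
    template_path="reward_model_template.md",  # hypothetical local path to the template above
    model_name="sarm",  # selects the SARM branch of the {% if %} block
)
print(card.text[:200])  # first lines of the rendered markdown body
```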
@@ -0,0 +1,239 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Sample weighting abstraction for training.
+
+This module provides an abstract base class for sample weighting strategies (e.g., RA-BC)
+that can be used during training without polluting the training script with
+policy-specific code.
+
+Example usage:
+    # In training config
+    sample_weighting:
+        type: rabc
+        progress_path: hf://datasets/my-dataset/sarm_progress.parquet
+        head_mode: sparse
+        kappa: 0.01
+
+    # In training script
+    sample_weighter = make_sample_weighter(cfg.sample_weighting, policy, device, dataset_root=cfg.dataset.root, dataset_repo_id=cfg.dataset.repo_id)
+    ...
+    weights, stats = sample_weighter.compute_batch_weights(batch)
+"""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import torch
+
+if TYPE_CHECKING:
+    from lerobot.policies.pretrained import PreTrainedPolicy
+
+
+class SampleWeighter(ABC):
+    """
+    Implementations compute per-sample weights that can be used to weight
+    the loss during training. This enables techniques like:
+    - RA-BC (Reward-Aligned Behavior Cloning)
+    - Importance sampling
+    - Curriculum learning
+    - Quality-based filtering
+    """
+
+    @abstractmethod
+    def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
+        """
+        Compute per-sample weights for a training batch.
+
+        Args:
+            batch: Training batch dictionary containing at minimum an "index" key
+                with global frame indices.
+        """
+
+    @abstractmethod
+    def get_stats(self) -> dict:
+        """
+        Get global statistics about the weighting strategy.
+        """
+
+
+@dataclass
+class SampleWeightingConfig:
+    """
+    Configuration for sample weighting during training.
+
+    This is a generic config that supports multiple weighting strategies.
+    The `type` field determines which implementation to use, and `extra_params`
+    contains additional type-specific parameters.
+
+    Attributes:
+        type: Weighting strategy type ("rabc", "uniform", etc.)
+        progress_path: Path to precomputed progress values (for RABC)
+        head_mode: Which model head to use for progress ("sparse" or "dense")
+        kappa: Hard threshold for high-quality samples (RABC-specific)
+        epsilon: Small constant for numerical stability
+        extra_params: Additional type-specific parameters passed to the weighter
+    """
+
+    type: str = "rabc"
+    progress_path: str | None = None
+    head_mode: str = "sparse"
+    kappa: float = 0.01
+    epsilon: float = 1e-6
+    # Additional type-specific params can be added here or passed via extra_params
+    extra_params: dict = field(default_factory=dict)
+
+
+def make_sample_weighter(
+    config: SampleWeightingConfig | None,
+    policy: PreTrainedPolicy,
+    device: torch.device,
+    dataset_root: str | None = None,
+    dataset_repo_id: str | None = None,
+) -> SampleWeighter | None:
+    """
+    Factory function to create a SampleWeighter from config.
+
+    This keeps policy-specific initialization logic out of the training script.
+
+    Args:
+        config: Sample weighting configuration, or None to disable weighting.
+        policy: The policy being trained (used to extract chunk_size, etc.)
+        device: Device to place weight tensors on.
+        dataset_root: Local path to dataset root (for auto-detecting progress_path).
+        dataset_repo_id: HuggingFace repo ID (for auto-detecting progress_path).
+    """
+    if config is None:
+        return None
+
+    if config.type == "rabc":
+        return _make_rabc_weighter(config, policy, device, dataset_root, dataset_repo_id)
+
+    if config.type == "uniform":
+        # No-op weighter that returns uniform weights
+        return UniformWeighter(device=device)
+
+    raise ValueError(f"Unknown sample weighting type: '{config.type}'. Supported types: 'rabc', 'uniform'")
+
+
+def _make_rabc_weighter(
+    config: SampleWeightingConfig,
+    policy: PreTrainedPolicy,
+    device: torch.device,
+    dataset_root: str | None = None,
+    dataset_repo_id: str | None = None,
+) -> SampleWeighter:
+    """Create RABC weighter with policy-specific initialization.
+
+    Args:
+        config: Sample weighting configuration.
+        policy: The policy being trained (used to extract chunk_size).
+        device: Device to place weight tensors on.
+        dataset_root: Local path to dataset root (for auto-detecting progress_path).
+        dataset_repo_id: HuggingFace repo ID (for auto-detecting progress_path).
+    """
+    # Import here to avoid circular imports and keep RABC code in SARM module
+    from lerobot.rewards.sarm.rabc import RABCWeights
+
+    # Extract chunk_size from policy config
+    chunk_size = getattr(policy.config, "chunk_size", None)
+    if chunk_size is None:
+        raise ValueError(
+            "RABC sample weighting requires a policy with 'chunk_size' in its config. "
+            "This is typically set for action-chunking policies like ACT, Diffusion, PI0, etc."
+        )
+
+    # Determine progress_path: use explicit config or auto-detect from dataset
+    progress_path = config.progress_path
+    if progress_path is None:
+        if dataset_root:
+            progress_path = str(Path(dataset_root) / "sarm_progress.parquet")
+        elif dataset_repo_id:
+            progress_path = f"hf://datasets/{dataset_repo_id}/sarm_progress.parquet"
+        else:
+            raise ValueError(
+                "RABC sample weighting requires 'progress_path' to be set, "
+                "or dataset_root/dataset_repo_id for auto-detection. "
+                "Generate progress values using: "
+                "python -m lerobot.rewards.sarm.compute_rabc_weights --help"
+            )
+
+    return RABCWeights(
+        progress_path=progress_path,
+        chunk_size=chunk_size,
+        head_mode=config.head_mode,
+        kappa=config.kappa,
+        epsilon=config.epsilon,
+        device=device,
+        **config.extra_params,
+    )
+
+
+class UniformWeighter(SampleWeighter):
+    """
+    No-op sample weighter that returns uniform weights.
+
+    Useful as a baseline or when you want to disable weighting without
+    changing the training code structure.
+
+    Note:
+        Batch size is determined by looking for tensor values in the batch
+        dictionary. The method checks common keys like "action", "index",
+        and "observation.state" first, then falls back to scanning all values.
+    """
+
+    def __init__(self, device: torch.device):
+        self.device = device
+
+    def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
+        """Return uniform weights (all ones)."""
+        batch_size = self._determine_batch_size(batch)
+
+        weights = torch.ones(batch_size, device=self.device)
+        stats = {"mean_weight": 1.0, "type": "uniform"}
+        return weights, stats
+
+    def _determine_batch_size(self, batch: dict) -> int:
+        """
+        Determine batch size from the batch dictionary.
+
+        Checks common keys first, then scans all values for tensors.
+
+        Args:
+            batch: Training batch dictionary.
+        """
+        if not batch:
+            raise ValueError("Cannot determine batch size from empty batch")
+
+        # Check common keys first
+        for key in ["action", "index", "observation.state"]:
+            if key in batch and isinstance(batch[key], torch.Tensor):
+                return batch[key].shape[0]
+
+        # Scan all values for any tensor
+        for value in batch.values():
+            if isinstance(value, torch.Tensor) and value.ndim >= 1:
+                return value.shape[0]
+
+        # Last resort: return 1 (this handles non-tensor batches)
+        return 1
+
+    def get_stats(self) -> dict:
+        """Return empty stats for uniform weighting."""
+        return {"type": "uniform"}
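To extend the abstraction with a new strategy, a subclass only has to implement the two abstract methods. A minimal sketch (not part of this diff) that up-weights later frames of each episode, assuming the batch carries a `"frame_index"` tensor:

```python
import torch

from lerobot.utils.sample_weighting import SampleWeighter


class LateFrameWeighter(SampleWeighter):
    """Illustrative weighter: weight grows linearly with frame_index."""

    def __init__(self, device: torch.device, horizon: int = 100):
        self.device = device
        self.horizon = horizon

    def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
        # Assumes the dataloader provides per-sample frame indices.
        frame_index = batch["frame_index"].to(self.device).float()
        weights = 1.0 + frame_index / self.horizon  # later frames weigh more
        return weights, {"mean_weight": weights.mean().item()}

    def get_stats(self) -> dict:
        return {"type": "late_frame"}
```

Wiring it in would mean adding a new `config.type` branch to `make_sample_weighter`, mirroring the `"uniform"` case above.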
+10 -34
@@ -1,5 +1,3 @@
-#!/usr/bin/env python
-
 # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,8 +19,6 @@ import pytest
 import torch

 from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
-from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
-from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor
 from lerobot.processor import (
     DataProcessorPipeline,
     DeviceProcessorStep,
@@ -31,6 +27,8 @@ from lerobot.processor import (
     TransitionKey,
 )
 from lerobot.processor.converters import create_transition, transition_to_batch
+from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
+from lerobot.rewards.classifier.processor_classifier import make_classifier_processor
 from lerobot.utils.constants import OBS_IMAGE, OBS_STATE


@@ -42,7 +40,7 @@ def create_default_config():
         OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
     }
     config.output_features = {
-        "reward": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),  # Classifier output
+        "reward": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
     }
     config.normalization_mapping = {
         FeatureType.STATE: NormalizationMode.MEAN_STD,
@@ -90,17 +88,14 @@ def test_classifier_processor_normalization():
     config = create_default_config()
     stats = create_default_stats()

-    preprocessor, postprocessor = make_classifier_processor(
-        config,
-        stats,
-    )
+    preprocessor, postprocessor = make_classifier_processor(config, stats)

     # Create test data
     observation = {
         OBS_STATE: torch.randn(10),
         OBS_IMAGE: torch.randn(3, 224, 224),
     }
-    action = torch.randn(1)  # Dummy action/reward
+    action = torch.randn(1)
     transition = create_transition(observation, action)
     batch = transition_to_batch(transition)

@@ -120,10 +115,7 @@ def test_classifier_processor_cuda():
     config.device = "cuda"
     stats = create_default_stats()

-    preprocessor, postprocessor = make_classifier_processor(
-        config,
-        stats,
-    )
+    preprocessor, postprocessor = make_classifier_processor(config, stats)

     # Create CPU data
     observation = {
@@ -132,7 +124,6 @@ def test_classifier_processor_cuda():
     }
     action = torch.randn(1)
     transition = create_transition(observation, action)
-
     batch = transition_to_batch(transition)

     # Process through preprocessor
@@ -158,10 +149,7 @@ def test_classifier_processor_accelerate_scenario():
     config.device = "cuda:0"
     stats = create_default_stats()

-    preprocessor, postprocessor = make_classifier_processor(
-        config,
-        stats,
-    )
+    preprocessor, postprocessor = make_classifier_processor(config, stats)

     # Simulate Accelerate: data already on GPU
     device = torch.device("cuda:0")
@@ -171,7 +159,6 @@ def test_classifier_processor_accelerate_scenario():
     }
     action = torch.randn(1).to(device)
     transition = create_transition(observation, action)
-
     batch = transition_to_batch(transition)

     # Process through preprocessor
@@ -201,7 +188,6 @@ def test_classifier_processor_multi_gpu():
     }
     action = torch.randn(1).to(device)
     transition = create_transition(observation, action)
-
     batch = transition_to_batch(transition)

     # Process through preprocessor
@@ -231,7 +217,6 @@ def test_classifier_processor_without_stats():
     }
     action = torch.randn(1)
     transition = create_transition(observation, action)
-
     batch = transition_to_batch(transition)

     processed = preprocessor(batch)
@@ -294,7 +279,6 @@ def test_classifier_processor_mixed_precision():
     }
     action = torch.randn(1, dtype=torch.float32)
     transition = create_transition(observation, action)
-
     batch = transition_to_batch(transition)

     # Process through preprocessor
@@ -312,10 +296,7 @@ def test_classifier_processor_batch_data():
     config = create_default_config()
     stats = create_default_stats()

-    preprocessor, postprocessor = make_classifier_processor(
-        config,
-        stats,
-    )
+    preprocessor, postprocessor = make_classifier_processor(config, stats)

     # Test with batched data
     batch_size = 16
@@ -325,7 +306,6 @@ def test_classifier_processor_batch_data():
     }
     action = torch.randn(batch_size, 1)
     transition = create_transition(observation, action)
-
     batch = transition_to_batch(transition)

     # Process through preprocessor
@@ -343,15 +323,11 @@ def test_classifier_processor_postprocessor_identity():
     config = create_default_config()
     stats = create_default_stats()

-    preprocessor, postprocessor = make_classifier_processor(
-        config,
-        stats,
-    )
+    preprocessor, postprocessor = make_classifier_processor(config, stats)

     # Create test data for postprocessor
-    reward = torch.tensor([[0.8], [0.3], [0.9]])  # Batch of rewards/predictions
+    reward = torch.tensor([[0.8], [0.3], [0.9]])
     transition = create_transition(action=reward)

     _ = transition_to_batch(transition)

     # Process through postprocessor
+15 -9
@@ -1,5 +1,3 @@
-# !/usr/bin/env python
-
 # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,8 +16,8 @@ import pytest
 import torch

 from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
-from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
-from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput
+from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
+from lerobot.rewards.classifier.modeling_classifier import ClassifierOutput
 from lerobot.utils.constants import OBS_IMAGE, REWARD
 from tests.utils import skip_if_package_missing

@@ -42,7 +40,7 @@ def test_classifier_output():
     reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
 )
 def test_binary_classifier_with_default_params():
-    from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
+    from lerobot.rewards.classifier.modeling_classifier import Classifier

     config = RewardClassifierConfig()
     config.input_features = {
@@ -86,7 +84,7 @@ def test_binary_classifier_with_default_params():
     reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
 )
 def test_multiclass_classifier():
-    from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
+    from lerobot.rewards.classifier.modeling_classifier import Classifier

     num_classes = 5
     config = RewardClassifierConfig()
@@ -128,11 +126,15 @@ def test_multiclass_classifier():
     reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
 )
 def test_default_device():
-    from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
+    from lerobot.rewards.classifier.modeling_classifier import Classifier

     config = RewardClassifierConfig()
-    assert config.device == "cpu"
+    assert config.device is None or config.device == "cpu"

+    config.input_features = {
+        OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
+    }
+    config.num_cameras = 1
     classifier = Classifier(config)
     for p in classifier.parameters():
         assert p.device == torch.device("cpu")
@@ -143,11 +145,15 @@ def test_default_device():
     reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
 )
 def test_explicit_device_setup():
-    from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
+    from lerobot.rewards.classifier.modeling_classifier import Classifier

     config = RewardClassifierConfig(device="cpu")
     assert config.device == "cpu"

+    config.input_features = {
+        OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
+    }
+    config.num_cameras = 1
     classifier = Classifier(config)
     for p in classifier.parameters():
         assert p.device == torch.device("cpu")
@@ -0,0 +1,373 @@
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for the reward model base classes and registry."""
+
+from dataclasses import dataclass
+from pathlib import Path
+from types import SimpleNamespace
+
+import pytest
+import torch
+
+from lerobot.configs.rewards import RewardModelConfig
+from lerobot.optim.optimizers import AdamWConfig
+from lerobot.rewards.pretrained import PreTrainedRewardModel
+
+
+@RewardModelConfig.register_subclass(name="_dummy_hub_reward")
+@dataclass
+class _DummyHubRewardConfig(RewardModelConfig):
+    def get_optimizer_preset(self):
+        return AdamWConfig(lr=1e-4)
+
+
+class _DummyHubReward(PreTrainedRewardModel):
+    config_class = _DummyHubRewardConfig
+    name = "_dummy_hub_reward"
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.bias = torch.nn.Parameter(torch.zeros(1))
+
+    def compute_reward(self, batch):
+        return self.bias.expand(1)
+
+
+def test_reward_model_config_registry():
+    """Verify that classifier and sarm are registered."""
+    known = RewardModelConfig.get_known_choices()
+    assert "reward_classifier" in known
+    assert "sarm" in known
+
+
+def test_reward_model_config_lookup():
+    """Verify that we can look up configs by name."""
+    cls = RewardModelConfig.get_choice_class("reward_classifier")
+    from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
+
+    assert cls is RewardClassifierConfig
+
+
+def test_factory_get_reward_model_class():
+    """Test the get_reward_model_class factory."""
+    from lerobot.rewards.factory import get_reward_model_class
+
+    cls = get_reward_model_class("sarm")
+    from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel
+
+    assert cls is SARMRewardModel
+
+
+def test_factory_unknown_raises():
+    """Unknown name should raise ValueError."""
+    from lerobot.rewards.factory import get_reward_model_class
+
+    with pytest.raises(ValueError, match="not available"):
+        get_reward_model_class("nonexistent_reward_model")
+
+
+def test_pretrained_reward_model_requires_config_class():
+    """Subclass without config_class should fail."""
+    with pytest.raises(TypeError, match="must define 'config_class'"):
+
+        class BadModel(PreTrainedRewardModel):
+            name = "bad"
+
+            def compute_reward(self, batch):
+                pass
+
+
+def test_pretrained_reward_model_requires_name():
+    """Subclass without name should fail."""
+    with pytest.raises(TypeError, match="must define 'name'"):
+
+        class BadModel(PreTrainedRewardModel):
+            config_class = RewardModelConfig
+
+            def compute_reward(self, batch):
+                pass
+
+
+def test_non_trainable_forward_raises():
+    """Non-trainable model should raise on forward()."""
+    from dataclasses import dataclass
+
+    from lerobot.optim.optimizers import AdamWConfig
+
+    @dataclass
+    class DummyConfig(RewardModelConfig):
+        def get_optimizer_preset(self):
+            return AdamWConfig(lr=1e-4)
+
+    class DummyReward(PreTrainedRewardModel):
+        config_class = DummyConfig
+        name = "dummy_test"
+
+        def compute_reward(self, batch):
+            return torch.zeros(1)
+
+    config = DummyConfig()
+    model = DummyReward(config)
+
+    with pytest.raises(NotImplementedError, match="not trainable"):
+        model.forward({"x": torch.zeros(1)})
+
+
+# ---------------------------------------------------------------------------
+# Trainable vs zero-shot (general-purpose) reward models.
+# The proposal explicitly supports models like TOPReward that wrap a pretrained
+# VLM and produce a reward signal without any training step. These tests pin
+# the contract that lets such models coexist with trainable ones.
+# ---------------------------------------------------------------------------
+
+
+def test_is_trainable_false_when_forward_not_overridden():
+    """A reward model that only implements ``compute_reward`` is zero-shot."""
+    model, _ = _make_dummy_reward_model()
+    assert model.is_trainable is False
+
+
+def test_is_trainable_true_when_forward_overridden():
+    """Overriding ``forward`` flips ``is_trainable`` to True."""
+
+    class _TrainableReward(_DummyHubReward):
+        name = "_trainable_dummy_reward"
+
+        def forward(self, batch):
+            loss = (self.bias**2).sum()
+            return loss, {}
+
+    # Register a fresh config subclass so the subclass check passes.
+    @RewardModelConfig.register_subclass(name="_trainable_dummy_reward")
+    @dataclass
+    class _TrainableConfig(_DummyHubRewardConfig):
+        pass
+
+    _TrainableReward.config_class = _TrainableConfig
+    model = _TrainableReward(_TrainableConfig())
+    assert model.is_trainable is True
+
+
+# ---------------------------------------------------------------------------
+# RewardModelConfig.from_pretrained
+# ---------------------------------------------------------------------------
+
+
+def test_reward_model_config_from_pretrained_raises_when_config_missing(tmp_path):
+    """``from_pretrained`` must surface a clear ``FileNotFoundError`` when the
+    target directory exists but does not contain ``config.json``, instead of
+    crashing later inside ``draccus.parse``.
+    """
+    # tmp_path exists but has no config.json
+    with pytest.raises(FileNotFoundError, match="config.json not found"):
+        RewardModelConfig.from_pretrained(tmp_path)
+
+
+def test_reward_model_config_from_pretrained_roundtrip(tmp_path):
+    """Round-trip: save a RewardClassifierConfig, reload it, fields must match."""
+    from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
+
+    original = RewardClassifierConfig(
+        num_classes=3,
+        hidden_dim=128,
+        latent_dim=64,
+        num_cameras=1,
+        learning_rate=5e-4,
+    )
+    original._save_pretrained(tmp_path)
+
+    loaded = RewardModelConfig.from_pretrained(tmp_path)
+
+    assert isinstance(loaded, RewardClassifierConfig)
+    assert loaded.num_classes == 3
+    assert loaded.hidden_dim == 128
+    assert loaded.latent_dim == 64
+    assert loaded.num_cameras == 1
+    assert loaded.learning_rate == 5e-4
+
+
+# ---------------------------------------------------------------------------
+# TrainPipelineConfig — reward model training path
+# ---------------------------------------------------------------------------
+
+
+def test_train_pipeline_config_path_fields_includes_reward_model():
+    """``--reward_model.path=local/dir`` requires ``reward_model`` to be listed
+    as a draccus path-field on ``TrainPipelineConfig``."""
+    from lerobot.configs.train import TrainPipelineConfig
+
+    fields = TrainPipelineConfig.__get_path_fields__()
+    assert "policy" in fields
+    assert "reward_model" in fields
+
+
+def test_train_pipeline_config_trainable_config_returns_reward_model_when_set():
+    """When only ``reward_model`` is set, ``trainable_config`` (used by the
+    trainer for e.g. ``.device``) must return it — not ``None`` from ``policy``."""
+    from lerobot.configs.default import DatasetConfig
+    from lerobot.configs.train import TrainPipelineConfig
+    from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
+
+    reward_cfg = RewardClassifierConfig(device="cpu")
+    cfg = TrainPipelineConfig(
+        dataset=DatasetConfig(repo_id="user/repo"),
+        reward_model=reward_cfg,
+    )
+
+    assert cfg.is_reward_model_training is True
+    assert cfg.trainable_config is reward_cfg
+    # This is what lerobot_train.py uses to decide force_cpu; ``cfg.policy.device``
+    # would AttributeError here because policy is None.
+    assert cfg.trainable_config.device == "cpu"
+
+
+def test_train_pipeline_config_trainable_config_returns_policy_when_set():
+    """Mirror of the reward-model case: when only ``policy`` is set,
+    ``trainable_config`` must return it."""
+    from lerobot.configs.default import DatasetConfig
+    from lerobot.configs.train import TrainPipelineConfig
+    from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
+
+    policy_cfg = DiffusionConfig(device="cpu")
+    cfg = TrainPipelineConfig(
+        dataset=DatasetConfig(repo_id="user/repo"),
+        policy=policy_cfg,
+    )
+
+    assert cfg.is_reward_model_training is False
+    assert cfg.trainable_config is policy_cfg
+    assert cfg.trainable_config.device == "cpu"
+
+
+# ---------------------------------------------------------------------------
+# PreTrainedRewardModel hub upload: push_model_to_hub + generate_model_card.
+# We test the generation side (offline) fully, and the upload side with HfApi
+# mocked so nothing actually hits the network.
+# ---------------------------------------------------------------------------
+
+
+def _make_dummy_reward_model(**config_kwargs):
+    return _DummyHubReward(_DummyHubRewardConfig(**config_kwargs)), _DummyHubRewardConfig
+
+
+@pytest.fixture
+def _offline_model_card(monkeypatch):
+    """``ModelCard.validate`` does a live ``POST`` to huggingface.co — bypass it
+    so tests can run offline."""
+    from huggingface_hub import ModelCard
+
+    monkeypatch.setattr(ModelCard, "validate", lambda self, *a, **kw: None)
+
+
+def test_reward_model_generate_model_card_renders_expected_fields(_offline_model_card):
+    """``generate_model_card`` must produce a card with the right metadata and
+    body, using the dedicated reward-model template."""
+    model, _ = _make_dummy_reward_model(
+        license="mit",
+        tags=["robot", "sim"],
+    )
+
+    card = model.generate_model_card(
+        dataset_repo_id="user/my_dataset",
+        model_type=model.config.type,
+        license=model.config.license,
+        tags=model.config.tags,
+    )
+
+    # Metadata (YAML header) — ModelCardData fields.
+    assert card.data.license == "mit"
+    assert card.data.library_name == "lerobot"
+    assert card.data.pipeline_tag == "robotics"
+    assert "reward-model" in card.data.tags
+    assert model.config.type in card.data.tags
+    assert card.data.model_name == model.config.type
+    assert card.data.datasets == "user/my_dataset"
+
+    # Body — specific to the reward-model template, NOT the policy one.
+    body = str(card)
+    assert "Reward Model Card" in body
+    assert "This reward model has been trained" in body
+    assert "--reward_model.type=" in body  # reward-model-specific usage block
+
+
+def test_reward_model_generate_model_card_uses_default_license(_offline_model_card):
+    """When config.license is None the card falls back to apache-2.0."""
+    model, _ = _make_dummy_reward_model()
+
+    card = model.generate_model_card(
+        dataset_repo_id="user/my_dataset",
+        model_type=model.config.type,
+        license=model.config.license,
+        tags=None,
+    )
+
+    assert card.data.license == "apache-2.0"
+
+
+def test_reward_model_push_model_to_hub_uploads_expected_files(monkeypatch, _offline_model_card):
+    """``push_model_to_hub`` must:
+    1. create the repo,
+    2. assemble a temp folder with weights + config.json + train_config.json + README.md,
+    3. call ``api.upload_folder`` on that folder.
+    All network calls are mocked.
+    """
+    from huggingface_hub.constants import CONFIG_NAME
+
+    from lerobot.configs.default import DatasetConfig
+    from lerobot.configs.train import TRAIN_CONFIG_NAME, TrainPipelineConfig
+
+    model, _ = _make_dummy_reward_model(
+        repo_id="user/my_reward",
+        license="apache-2.0",
+    )
+    # Point the reward model's train config at a dummy dataset repo.
+    train_cfg = TrainPipelineConfig(
+        dataset=DatasetConfig(repo_id="user/my_dataset"),
+        reward_model=model.config,
+    )
+
+    uploaded: dict = {}
+    fake_commit_info = SimpleNamespace(repo_url=SimpleNamespace(url="https://huggingface.co/user/my_reward"))
+
+    class _FakeHfApi:
+        def create_repo(self, repo_id, private=None, exist_ok=False):
+            uploaded["create_repo_id"] = repo_id
+            uploaded["create_private"] = private
+            return SimpleNamespace(repo_id=repo_id)
+
+        def upload_folder(self, *, repo_id, repo_type, folder_path, commit_message, **_kwargs):
+            uploaded["upload_repo_id"] = repo_id
+            uploaded["upload_repo_type"] = repo_type
+            uploaded["commit_message"] = commit_message
+            # Snapshot files assembled in the temp folder — this is the real
+            # contract we care about.
+            uploaded["files"] = sorted(p.name for p in Path(folder_path).iterdir())
+            return fake_commit_info
+
+    from lerobot.rewards import pretrained as reward_pretrained
+
+    monkeypatch.setattr(reward_pretrained, "HfApi", lambda *a, **kw: _FakeHfApi())
+
+    model.push_model_to_hub(train_cfg)
+
+    assert uploaded["create_repo_id"] == "user/my_reward"
+    assert uploaded["upload_repo_id"] == "user/my_reward"
+    assert uploaded["upload_repo_type"] == "model"
+    assert uploaded["commit_message"] == "Upload reward model weights, train config and readme"
+    # Minimum required files that must be uploaded with a reward model.
+    assert CONFIG_NAME in uploaded["files"]  # config.json
+    assert TRAIN_CONFIG_NAME in uploaded["files"]  # train_config.json
+    assert "README.md" in uploaded["files"]
+    assert any(name.endswith(".safetensors") for name in uploaded["files"])
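The trainable vs zero-shot contract these tests pin down can be summarized in one sketch; the class and registry name below are illustrative, mirroring the dummy models above rather than anything shipped in this diff:

```python
from dataclasses import dataclass

import torch

from lerobot.configs.rewards import RewardModelConfig
from lerobot.optim.optimizers import AdamWConfig
from lerobot.rewards.pretrained import PreTrainedRewardModel


@RewardModelConfig.register_subclass(name="my_trainable_reward")  # illustrative name
@dataclass
class MyRewardConfig(RewardModelConfig):
    def get_optimizer_preset(self):
        return AdamWConfig(lr=1e-4)


class MyReward(PreTrainedRewardModel):
    config_class = MyRewardConfig
    name = "my_trainable_reward"

    def __init__(self, config):
        super().__init__(config)
        self.head = torch.nn.Linear(8, 1)

    def compute_reward(self, batch):
        # Inference-time scoring; on its own this would make a zero-shot model.
        return self.head(batch["features"]).squeeze(-1)

    def forward(self, batch):
        # Overriding forward() flips is_trainable to True: return (loss, metrics).
        loss = torch.nn.functional.mse_loss(self.compute_reward(batch), batch["reward"])
        return loss, {"loss": loss.item()}
```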
@@ -104,8 +104,8 @@ class TestSARMEncodingProcessorStepEndToEnd:
     def mock_clip_model(self):
         """Mock CLIP model to avoid loading real weights."""
         with (
-            patch("lerobot.policies.sarm.processor_sarm.CLIPModel") as mock_model_cls,
-            patch("lerobot.policies.sarm.processor_sarm.CLIPProcessor") as mock_processor_cls,
+            patch("lerobot.rewards.sarm.processor_sarm.CLIPModel") as mock_model_cls,
+            patch("lerobot.rewards.sarm.processor_sarm.CLIPProcessor") as mock_processor_cls,
         ):
             # Mock the CLIP model - return embeddings based on input batch size
             mock_model = MagicMock()
@@ -142,7 +142,7 @@ class TestSARMEncodingProcessorStepEndToEnd:
     @pytest.fixture
     def processor_with_mocks(self, mock_clip_model):
         """Create a processor with mocked CLIP and dataset metadata for dual mode."""
-        from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
+        from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep

         # Dual mode config with both sparse and dense annotations
         config = MockConfig(
@@ -256,7 +256,7 @@ class TestSARMEncodingProcessorStepEndToEnd:

     def test_call_with_batched_input(self, mock_clip_model):
         """Test processor __call__ with a batched input (multiple frames) in dual mode."""
-        from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
+        from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep

         config = MockConfig(
             n_obs_steps=8,
@@ -332,7 +332,7 @@ class TestSARMEncodingProcessorStepEndToEnd:

     def test_targets_increase_with_progress(self, mock_clip_model):
         """Test that both sparse and dense targets increase as frame index progresses."""
-        from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
+        from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep

         config = MockConfig(
             n_obs_steps=8,
@@ -404,7 +404,7 @@ class TestSARMEncodingProcessorStepEndToEnd:

     def test_progress_labels_exact_values(self, mock_clip_model):
         """Test that progress labels (stage.tau) are computed correctly for known positions."""
-        from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
+        from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep

         # Simple setup: 2 sparse stages, 4 dense stages, 100 frame episode
         config = MockConfig(
@@ -495,7 +495,7 @@ class TestSARMEncodingProcessorStepEndToEnd:
         """Test that rewind augmentation correctly extends sequence and generates targets."""
         import random

-        from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
+        from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep

         config = MockConfig(
             n_obs_steps=8,
@@ -587,8 +587,8 @@ class TestSARMEncodingProcessorStepEndToEnd:

     def test_full_sequence_target_consistency(self, mock_clip_model):
         """Test that the full sequence of targets is consistent with frame positions."""
-        from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
-        from lerobot.policies.sarm.sarm_utils import find_stage_and_tau
+        from lerobot.rewards.sarm.processor_sarm import SARMEncodingProcessorStep
+        from lerobot.rewards.sarm.sarm_utils import find_stage_and_tau

         config = MockConfig(
             n_obs_steps=8,
@@ -18,7 +18,7 @@ import numpy as np
 import pytest
 import torch

-from lerobot.policies.sarm.sarm_utils import (
+from lerobot.rewards.sarm.sarm_utils import (
     apply_rewind_augmentation,
     compute_absolute_indices,
     compute_tau,
@@ -0,0 +1,401 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Tests for the sample weighting infrastructure."""
|
||||||
|
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.utils.sample_weighting import (
|
||||||
|
SampleWeighter,
|
||||||
|
SampleWeightingConfig,
|
||||||
|
UniformWeighter,
|
||||||
|
make_sample_weighter,
|
||||||
|
)
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Fixtures
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_progress_parquet(tmp_path):
|
||||||
|
"""Create a sample progress parquet file for testing."""
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
# Create sample progress data for 2 episodes with 10 frames each
|
||||||
|
data = {
|
||||||
|
"index": list(range(20)),
|
||||||
|
"episode_index": [0] * 10 + [1] * 10,
|
||||||
|
"frame_index": list(range(10)) * 2,
|
||||||
|
"progress_sparse": [i / 10.0 for i in range(10)] * 2,
|
||||||
|
}
|
||||||
|
df = pd.DataFrame(data)
|
||||||
|
parquet_path = tmp_path / "sarm_progress.parquet"
|
||||||
|
df.to_parquet(parquet_path)
|
||||||
|
return parquet_path
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# SampleWeightingConfig Tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_default_values():
|
||||||
|
"""Test default configuration values."""
|
||||||
|
config = SampleWeightingConfig()
|
||||||
|
assert config.type == "rabc"
|
||||||
|
assert config.progress_path is None
|
||||||
|
assert config.head_mode == "sparse"
|
||||||
|
assert config.kappa == 0.01
|
||||||
|
assert config.epsilon == 1e-6
|
||||||
|
assert config.extra_params == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_custom_values():
|
||||||
|
"""Test configuration with custom values."""
|
||||||
|
config = SampleWeightingConfig(
|
||||||
|
type="rabc",
|
||||||
|
progress_path="/path/to/progress.parquet",
|
||||||
|
head_mode="dense",
|
||||||
|
kappa=0.05,
|
||||||
|
epsilon=1e-8,
|
||||||
|
extra_params={"fallback_weight": 0.5},
|
||||||
|
)
|
||||||
|
assert config.type == "rabc"
|
||||||
|
assert config.progress_path == "/path/to/progress.parquet"
|
||||||
|
assert config.head_mode == "dense"
|
||||||
|
assert config.kappa == 0.05
|
||||||
|
assert config.epsilon == 1e-8
|
||||||
|
assert config.extra_params == {"fallback_weight": 0.5}
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_uniform_type():
|
||||||
|
"""Test configuration for uniform weighting."""
|
||||||
|
config = SampleWeightingConfig(type="uniform")
|
||||||
|
assert config.type == "uniform"
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# UniformWeighter Tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_uniform_weighter_inherits_from_sample_weighter():
|
||||||
|
"""Test that UniformWeighter is a SampleWeighter."""
|
||||||
|
weighter = UniformWeighter(device=torch.device("cpu"))
|
||||||
|
assert isinstance(weighter, SampleWeighter)
|
||||||
|
|
||||||
|
|
||||||
|
def test_uniform_weighter_compute_batch_weights_with_action_key():
|
||||||
|
"""Test weight computation with 'action' key in batch."""
|
||||||
|
weighter = UniformWeighter(device=torch.device("cpu"))
|
||||||
|
batch = {"action": torch.randn(8, 10)}
|
||||||
|
|
||||||
|
weights, stats = weighter.compute_batch_weights(batch)
|
||||||
|
|
||||||
|
assert weights.shape == (8,)
|
||||||
|
assert torch.allclose(weights, torch.ones(8))
|
||||||
|
assert stats["mean_weight"] == 1.0
|
||||||
|
assert stats["type"] == "uniform"
|
||||||
|
|
||||||
|
|
||||||
|
def test_uniform_weighter_compute_batch_weights_with_index_key():
|
||||||
|
"""Test weight computation with 'index' key in batch."""
|
||||||
|
weighter = UniformWeighter(device=torch.device("cpu"))
|
||||||
|
batch = {"index": torch.arange(16)}
|
||||||
|
|
||||||
|
weights, stats = weighter.compute_batch_weights(batch)
|
||||||
|
|
||||||
|
assert weights.shape == (16,)
|
||||||
|
assert torch.allclose(weights, torch.ones(16))
|
||||||
|
|
||||||
|
|
||||||
|
def test_uniform_weighter_compute_batch_weights_no_tensor_keys():
|
||||||
|
"""Test weight computation with no tensor keys (fallback to size 1)."""
|
||||||
|
weighter = UniformWeighter(device=torch.device("cpu"))
|
||||||
|
batch = {"other_key": "some_value"}
|
||||||
|
|
||||||
|
weights, stats = weighter.compute_batch_weights(batch)
|
||||||
|
|
||||||
|
assert weights.shape == (1,)
|
||||||
|
assert torch.allclose(weights, torch.ones(1))
|
||||||
|
|
||||||
|
|
||||||
|
def test_uniform_weighter_compute_batch_weights_empty_batch_raises():
|
||||||
|
"""Test that empty batch raises ValueError."""
|
||||||
|
weighter = UniformWeighter(device=torch.device("cpu"))
|
||||||
|
batch = {}
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="empty batch"):
|
||||||
|
weighter.compute_batch_weights(batch)
|
||||||
|
|
||||||
|
|
||||||
|
def test_uniform_weighter_compute_batch_weights_scans_all_keys():
|
||||||
|
"""Test that batch size is determined by scanning all tensor values."""
|
||||||
|
weighter = UniformWeighter(device=torch.device("cpu"))
|
||||||
|
# Batch with non-standard key containing a tensor
|
||||||
|
batch = {"custom_tensor": torch.randn(7, 3)}
|
||||||
|
|
||||||
|
weights, stats = weighter.compute_batch_weights(batch)
|
||||||
|
|
||||||
|
assert weights.shape == (7,)
|
||||||
|
assert torch.allclose(weights, torch.ones(7))
|
||||||
|
|
||||||
|
|
||||||
|


def test_uniform_weighter_compute_batch_weights_on_cuda():
    """Test that weights are placed on the correct device."""
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    weighter = UniformWeighter(device=torch.device("cuda"))
    batch = {"action": torch.randn(4, 10)}

    weights, _ = weighter.compute_batch_weights(batch)

    assert weights.device.type == "cuda"


def test_uniform_weighter_get_stats():
    """Test get_stats returns expected structure."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    stats = weighter.get_stats()

    assert stats == {"type": "uniform"}
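

# A minimal sketch of a custom weighter built on the same interface the
# tests above exercise. Hedged: this assumes SampleWeighter only requires
# compute_batch_weights() returning (weights, stats) and get_stats();
# `_ConstantWeighter` is a hypothetical example, not part of the library.
class _ConstantWeighter(SampleWeighter):
    """Assigns the same fixed weight to every sample in a batch."""

    def __init__(self, device: torch.device, value: float = 1.0):
        self.device = device
        self.value = value

    def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
        batch_size = _sketch_infer_batch_size(batch)
        weights = torch.full((batch_size,), self.value, device=self.device)
        return weights, {"type": "constant", "mean_weight": self.value}

    def get_stats(self) -> dict:
        return {"type": "constant"}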


# =============================================================================
# make_sample_weighter Factory Tests
# =============================================================================


def test_factory_returns_none_for_none_config():
    """Test that None config returns None weighter."""
    policy = Mock()
    device = torch.device("cpu")

    result = make_sample_weighter(None, policy, device)

    assert result is None


def test_factory_creates_uniform_weighter():
    """Test creation of UniformWeighter."""
    config = SampleWeightingConfig(type="uniform")
    policy = Mock()
    device = torch.device("cpu")

    weighter = make_sample_weighter(config, policy, device)

    assert isinstance(weighter, UniformWeighter)
    assert isinstance(weighter, SampleWeighter)


def test_factory_raises_for_unknown_type():
    """Test that unknown type raises ValueError."""
    config = SampleWeightingConfig(type="unknown_type")
    policy = Mock()
    device = torch.device("cpu")

    with pytest.raises(ValueError, match="Unknown sample weighting type"):
        make_sample_weighter(config, policy, device)
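

# The dispatch contract these factory tests assert, as a hedged sketch.
# The real make_sample_weighter also handles chunk_size validation and
# progress_path auto-detection (tested below); `_sketch_make_sample_weighter`
# is a hypothetical simplification, not the library implementation.
def _sketch_make_sample_weighter(config, policy, device):
    if config is None:
        # No weighting configured: the caller falls back to unweighted training
        return None
    if config.type == "uniform":
        return UniformWeighter(device=device)
    if config.type == "rabc":
        from lerobot.rewards.sarm.rabc import RABCWeights

        return RABCWeights(
            progress_path=config.progress_path,
            chunk_size=policy.config.chunk_size,
            head_mode=config.head_mode,
            device=device,
        )
    raise ValueError(f"Unknown sample weighting type: {config.type}")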


def test_factory_rabc_requires_chunk_size():
    """Test that RABC weighter requires chunk_size in policy config."""
    config = SampleWeightingConfig(
        type="rabc",
        progress_path="/path/to/progress.parquet",
    )
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = None  # No chunk_size
    device = torch.device("cpu")

    with pytest.raises(ValueError, match="chunk_size"):
        make_sample_weighter(config, policy, device)


def test_factory_rabc_requires_progress_path_or_dataset_info():
    """Test that RABC weighter requires progress_path or dataset info for auto-detection."""
    config = SampleWeightingConfig(
        type="rabc",
        progress_path=None,  # No progress path
    )
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = 50
    device = torch.device("cpu")

    # Should fail when no progress_path AND no dataset info
    with pytest.raises(ValueError, match="progress_path"):
        make_sample_weighter(config, policy, device)


def test_factory_rabc_auto_detects_from_dataset_root(sample_progress_parquet):
    """Test that RABC weighter auto-detects progress_path from dataset_root."""
    config = SampleWeightingConfig(
        type="rabc",
        progress_path=None,  # Not provided, should auto-detect
    )
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = 5
    device = torch.device("cpu")

    # The parquet file is at sample_progress_parquet, get its parent directory
    dataset_root = sample_progress_parquet.parent
    weighter = make_sample_weighter(
        config,
        policy,
        device,
        dataset_root=str(dataset_root),
    )

    assert weighter is not None
    from lerobot.rewards.sarm.rabc import RABCWeights

    assert isinstance(weighter, RABCWeights)


def test_factory_rabc_auto_detects_from_repo_id():
    """Test that RABC weighter constructs HF path from repo_id."""
    config = SampleWeightingConfig(
        type="rabc",
        progress_path=None,  # Not provided, should auto-detect
    )
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = 50
    device = torch.device("cpu")

    # This will construct the path but fail when trying to load (file doesn't exist).
    # We just verify it doesn't raise the "progress_path required" error.
    with pytest.raises(Exception) as exc_info:
        make_sample_weighter(
            config,
            policy,
            device,
            dataset_repo_id="test-user/test-dataset",
        )
    # Should NOT be the "progress_path required" error - it should try to load the file
    assert (
        "progress_path" not in str(exc_info.value).lower() or "auto-detection" in str(exc_info.value).lower()
    )
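

# How a factory-created weighter would plug into a training step, as a
# hedged sketch (hypothetical helper; the real wiring lives in the training
# script): per-sample losses are scaled by the batch weights before
# reduction, so weighting changes each sample's contribution to the gradient.
def _sketch_weighted_loss(per_sample_loss: torch.Tensor, weighter, batch: dict) -> torch.Tensor:
    if weighter is None:
        # make_sample_weighter(None, ...) returns None: plain mean reduction
        return per_sample_loss.mean()
    weights, _stats = weighter.compute_batch_weights(batch)
    return (weights * per_sample_loss).mean()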


# =============================================================================
# Integration Tests with RABCWeights
# =============================================================================


def test_rabc_weights_is_sample_weighter(sample_progress_parquet):
    """Test that RABCWeights inherits from SampleWeighter."""
    from lerobot.rewards.sarm.rabc import RABCWeights

    weighter = RABCWeights(
        progress_path=sample_progress_parquet,
        chunk_size=5,
        head_mode="sparse",
    )
    assert isinstance(weighter, SampleWeighter)


def test_rabc_compute_batch_weights(sample_progress_parquet):
    """Test RABCWeights.compute_batch_weights returns correct structure."""
    from lerobot.rewards.sarm.rabc import RABCWeights

    weighter = RABCWeights(
        progress_path=sample_progress_parquet,
        chunk_size=5,
        head_mode="sparse",
        device=torch.device("cpu"),
    )

    batch = {"index": torch.tensor([0, 1, 2, 3])}
    weights, stats = weighter.compute_batch_weights(batch)

    assert isinstance(weights, torch.Tensor)
    assert weights.shape == (4,)
    assert isinstance(stats, dict)
    assert "mean_weight" in stats


def test_rabc_get_stats(sample_progress_parquet):
    """Test RABCWeights.get_stats returns expected structure."""
    from lerobot.rewards.sarm.rabc import RABCWeights

    weighter = RABCWeights(
        progress_path=sample_progress_parquet,
        chunk_size=5,
        head_mode="sparse",
    )

    stats = weighter.get_stats()

    assert stats["type"] == "rabc"
    assert "num_frames" in stats
    assert "chunk_size" in stats
    assert stats["chunk_size"] == 5
    assert "head_mode" in stats
    assert stats["head_mode"] == "sparse"
    assert "delta_mean" in stats
    assert "delta_std" in stats


def test_factory_creates_rabc_weighter(sample_progress_parquet):
    """Test factory creates RABCWeights with valid config."""
    from lerobot.rewards.sarm.rabc import RABCWeights

    config = SampleWeightingConfig(
        type="rabc",
        progress_path=str(sample_progress_parquet),
        head_mode="sparse",
        kappa=0.01,
    )
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = 5
    device = torch.device("cpu")

    weighter = make_sample_weighter(config, policy, device)

    assert isinstance(weighter, RABCWeights)
    assert isinstance(weighter, SampleWeighter)


def test_rabc_weights_normalization(sample_progress_parquet):
    """Test that RABCWeights normalizes weights to sum to batch_size."""
    from lerobot.rewards.sarm.rabc import RABCWeights

    weighter = RABCWeights(
        progress_path=sample_progress_parquet,
        chunk_size=5,
        head_mode="sparse",
        device=torch.device("cpu"),
    )

    batch = {"index": torch.tensor([0, 1, 2, 3])}
    weights, _ = weighter.compute_batch_weights(batch)

    # Weights should be normalized to sum approximately to batch_size
    batch_size = 4
    assert abs(weights.sum().item() - batch_size) < 0.1
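

# The normalization property asserted above, as a hedged sketch: raw
# per-sample scores are rescaled so the weights sum to the batch size,
# keeping the overall loss magnitude comparable to unweighted training.
# `_sketch_normalize_weights` is hypothetical; RABCWeights' actual scheme
# may differ in detail.
def _sketch_normalize_weights(raw: torch.Tensor) -> torch.Tensor:
    eps = 1e-8  # guard against an all-zero weight vector
    return raw * (raw.numel() / (raw.sum() + eps))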