Compare commits

...

3 Commits

Author SHA1 Message Date
Khalil Meftah b5f65e5332 Expose sarm package API and ship reward model card template (#3477)
* chore: List lerobot_rewardmodel_modelcard_template.md in MANIFEST.in

* chore: export SARMConfig, SARMRewardModel, and make_sarm_pre_post_processors from rewards.sarm.
2026-04-29 16:17:16 +02:00
Khalil Meftah cd6b43ea7a fix(train): migrate legacy RA-BC fields in train config loading (#3480) 2026-04-29 16:17:00 +02:00
Steven Palma 2236bbe7a3 fix(rollout): propagate policy-specific CLI config paramaters (#3483)
Co-authored-by: Maxime Ellerbach <maxime.ellerbach@huggingface.co>
2026-04-29 16:13:10 +02:00
5 changed files with 152 additions and 16 deletions
+1
View File
@@ -1,3 +1,4 @@
include src/lerobot/templates/lerobot_modelcard_template.md include src/lerobot/templates/lerobot_modelcard_template.md
include src/lerobot/templates/lerobot_rewardmodel_modelcard_template.md
include src/lerobot/datasets/card_template.md include src/lerobot/datasets/card_template.md
include src/lerobot/envs/metaworld_config.json include src/lerobot/envs/metaworld_config.json
+47
View File
@@ -13,7 +13,9 @@
# limitations under the License. # limitations under the License.
import builtins import builtins
import datetime as dt import datetime as dt
import json
import os import os
import tempfile
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@@ -35,6 +37,42 @@ from .rewards import RewardModelConfig
TRAIN_CONFIG_NAME = "train_config.json" TRAIN_CONFIG_NAME = "train_config.json"
def _migrate_legacy_rabc_fields(config: dict[str, Any]) -> dict[str, Any] | None:
"""Return migrated payload for legacy RA-BC fields, or None when no migration is needed."""
legacy_fields = (
"use_rabc",
"rabc_progress_path",
"rabc_kappa",
"rabc_epsilon",
"rabc_head_mode",
)
if not any(key in config for key in legacy_fields):
return None
migrated_config = dict(config)
use_rabc = bool(migrated_config.pop("use_rabc", False))
rabc_progress_path = migrated_config.pop("rabc_progress_path", None)
rabc_kappa = migrated_config.pop("rabc_kappa", None)
rabc_epsilon = migrated_config.pop("rabc_epsilon", None)
rabc_head_mode = migrated_config.pop("rabc_head_mode", None)
# New configs may already define sample_weighting explicitly. In that case,
# legacy fields are ignored after being stripped from the payload.
if migrated_config.get("sample_weighting") is None and use_rabc:
sample_weighting: dict[str, Any] = {"type": "rabc"}
if rabc_progress_path is not None:
sample_weighting["progress_path"] = rabc_progress_path
if rabc_kappa is not None:
sample_weighting["kappa"] = rabc_kappa
if rabc_epsilon is not None:
sample_weighting["epsilon"] = rabc_epsilon
if rabc_head_mode is not None:
sample_weighting["head_mode"] = rabc_head_mode
migrated_config["sample_weighting"] = sample_weighting
return migrated_config
@dataclass @dataclass
class TrainPipelineConfig(HubMixin): class TrainPipelineConfig(HubMixin):
dataset: DatasetConfig dataset: DatasetConfig
@@ -218,6 +256,15 @@ class TrainPipelineConfig(HubMixin):
) from e ) from e
cli_args = kwargs.pop("cli_args", []) cli_args = kwargs.pop("cli_args", [])
if config_file is not None:
with open(config_file) as f:
config = json.load(f)
migrated_config = _migrate_legacy_rabc_fields(config)
if migrated_config is not None:
with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
json.dump(migrated_config, f)
config_file = f.name
with draccus.config_type("json"): with draccus.config_type("json"):
return draccus.parse(cls, config_file, args=cli_args) return draccus.parse(cls, config_file, args=cli_args)
+19
View File
@@ -0,0 +1,19 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_sarm import SARMConfig
from .modeling_sarm import SARMRewardModel
from .processor_sarm import make_sarm_pre_post_processors
__all__ = ["SARMConfig", "SARMRewardModel", "make_sarm_pre_post_processors"]
+11 -16
View File
@@ -27,7 +27,7 @@ from threading import Event
import torch import torch
from lerobot.configs import FeatureType, PreTrainedConfig from lerobot.configs import FeatureType
from lerobot.datasets import ( from lerobot.datasets import (
LeRobotDataset, LeRobotDataset,
aggregate_pipeline_dataset_features, aggregate_pipeline_dataset_features,
@@ -178,33 +178,26 @@ def build_rollout_context(
policy_config = cfg.policy policy_config = cfg.policy
policy_class = get_policy_class(policy_config.type) policy_class = get_policy_class(policy_config.type)
full_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path) if hasattr(policy_config, "compile_model"):
for attr in ("device", "use_amp"): policy_config.compile_model = cfg.use_torch_compile
if hasattr(cfg.policy, attr) and hasattr(full_config, attr):
cli_val = getattr(cfg.policy, attr)
if cli_val is not None:
setattr(full_config, attr, cli_val)
if hasattr(full_config, "compile_model"): if policy_config.type == "vqbet" and cfg.device == "mps":
full_config.compile_model = cfg.use_torch_compile
if full_config.type == "vqbet" and cfg.device == "mps":
raise NotImplementedError( raise NotImplementedError(
"Current implementation of VQBeT does not support `mps` backend. " "Current implementation of VQBeT does not support `mps` backend. "
"Please use `cpu` or `cuda` backend." "Please use `cpu` or `cuda` backend."
) )
if full_config.use_peft: if policy_config.use_peft:
from peft import PeftConfig, PeftModel from peft import PeftConfig, PeftModel
peft_path = cfg.policy.pretrained_path peft_path = policy_config.pretrained_path
peft_config = PeftConfig.from_pretrained(peft_path) peft_config = PeftConfig.from_pretrained(peft_path)
policy = policy_class.from_pretrained( policy = policy_class.from_pretrained(
pretrained_name_or_path=peft_config.base_model_name_or_path, config=full_config pretrained_name_or_path=peft_config.base_model_name_or_path, config=policy_config
) )
policy = PeftModel.from_pretrained(policy, peft_path, config=peft_config) policy = PeftModel.from_pretrained(policy, peft_path, config=peft_config)
else: else:
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=full_config) policy = policy_class.from_pretrained(policy_config.pretrained_path, config=policy_config)
if is_rtc: if is_rtc:
policy.config.rtc_config = cfg.inference.rtc policy.config.rtc_config = cfg.inference.rtc
@@ -315,7 +308,9 @@ def build_rollout_context(
# Validate visual features if no rename_map is active # Validate visual features if no rename_map is active
rename_map = cfg.rename_map rename_map = cfg.rename_map
if not rename_map: if not rename_map:
expected_visuals = {k for k, v in full_config.input_features.items() if v.type == FeatureType.VISUAL} expected_visuals = {
k for k, v in policy_config.input_features.items() if v.type == FeatureType.VISUAL
}
provided_visuals = { provided_visuals = {
f"observation.images.{k}" for k, v in robot.observation_features.items() if isinstance(v, tuple) f"observation.images.{k}" for k, v in robot.observation_features.items() if isinstance(v, tuple)
} }
+74
View File
@@ -14,6 +14,7 @@
"""Tests for the reward model base classes and registry.""" """Tests for the reward model base classes and registry."""
import json
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from types import SimpleNamespace from types import SimpleNamespace
@@ -251,6 +252,79 @@ def test_train_pipeline_config_trainable_config_returns_policy_when_set():
assert cfg.trainable_config.device == "cpu" assert cfg.trainable_config.device == "cpu"
def test_train_pipeline_config_from_pretrained_migrates_legacy_rabc_fields(tmp_path):
"""Legacy top-level RA-BC fields should be migrated into ``sample_weighting``."""
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TRAIN_CONFIG_NAME, TrainPipelineConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
cfg = TrainPipelineConfig(
dataset=DatasetConfig(repo_id="user/repo"),
policy=DiffusionConfig(device="cpu"),
)
cfg._save_pretrained(tmp_path)
config_path = tmp_path / TRAIN_CONFIG_NAME
with open(config_path) as f:
payload = json.load(f)
payload.pop("sample_weighting", None)
payload.update(
{
"use_rabc": True,
"rabc_progress_path": "hf://datasets/user/repo/sarm_progress.parquet",
"rabc_kappa": 0.05,
"rabc_epsilon": 1e-5,
"rabc_head_mode": "dense",
}
)
with open(config_path, "w") as f:
json.dump(payload, f)
loaded = TrainPipelineConfig.from_pretrained(tmp_path)
assert loaded.sample_weighting is not None
assert loaded.sample_weighting.type == "rabc"
assert loaded.sample_weighting.progress_path == "hf://datasets/user/repo/sarm_progress.parquet"
assert loaded.sample_weighting.kappa == 0.05
assert loaded.sample_weighting.epsilon == 1e-5
assert loaded.sample_weighting.head_mode == "dense"
def test_train_pipeline_config_from_pretrained_strips_legacy_rabc_when_disabled(tmp_path):
"""Legacy RA-BC fields should be ignored when ``use_rabc`` was false."""
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TRAIN_CONFIG_NAME, TrainPipelineConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
cfg = TrainPipelineConfig(
dataset=DatasetConfig(repo_id="user/repo"),
policy=DiffusionConfig(device="cpu"),
)
cfg._save_pretrained(tmp_path)
config_path = tmp_path / TRAIN_CONFIG_NAME
with open(config_path) as f:
payload = json.load(f)
payload.pop("sample_weighting", None)
payload.update(
{
"use_rabc": False,
"rabc_progress_path": "hf://datasets/user/repo/sarm_progress.parquet",
"rabc_kappa": 0.05,
"rabc_epsilon": 1e-5,
"rabc_head_mode": "dense",
}
)
with open(config_path, "w") as f:
json.dump(payload, f)
loaded = TrainPipelineConfig.from_pretrained(tmp_path)
assert loaded.sample_weighting is None
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# PreTrainedRewardModel hub upload: push_model_to_hub + generate_model_card. # PreTrainedRewardModel hub upload: push_model_to_hub + generate_model_card.
# We test the generation side (offline) fully, and the upload side with HfApi # We test the generation side (offline) fully, and the upload side with HfApi