mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
Compare commits
3 Commits
cb0a944941
...
b5f65e5332
| Author | SHA1 | Date | |
|---|---|---|---|
| b5f65e5332 | |||
| cd6b43ea7a | |||
| 2236bbe7a3 |
@@ -1,3 +1,4 @@
|
||||
include src/lerobot/templates/lerobot_modelcard_template.md
|
||||
include src/lerobot/templates/lerobot_rewardmodel_modelcard_template.md
|
||||
include src/lerobot/datasets/card_template.md
|
||||
include src/lerobot/envs/metaworld_config.json
|
||||
|
||||
@@ -13,7 +13,9 @@
|
||||
# limitations under the License.
|
||||
import builtins
|
||||
import datetime as dt
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -35,6 +37,42 @@ from .rewards import RewardModelConfig
|
||||
TRAIN_CONFIG_NAME = "train_config.json"
|
||||
|
||||
|
||||
def _migrate_legacy_rabc_fields(config: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Return migrated payload for legacy RA-BC fields, or None when no migration is needed."""
|
||||
legacy_fields = (
|
||||
"use_rabc",
|
||||
"rabc_progress_path",
|
||||
"rabc_kappa",
|
||||
"rabc_epsilon",
|
||||
"rabc_head_mode",
|
||||
)
|
||||
if not any(key in config for key in legacy_fields):
|
||||
return None
|
||||
|
||||
migrated_config = dict(config)
|
||||
use_rabc = bool(migrated_config.pop("use_rabc", False))
|
||||
rabc_progress_path = migrated_config.pop("rabc_progress_path", None)
|
||||
rabc_kappa = migrated_config.pop("rabc_kappa", None)
|
||||
rabc_epsilon = migrated_config.pop("rabc_epsilon", None)
|
||||
rabc_head_mode = migrated_config.pop("rabc_head_mode", None)
|
||||
|
||||
# New configs may already define sample_weighting explicitly. In that case,
|
||||
# legacy fields are ignored after being stripped from the payload.
|
||||
if migrated_config.get("sample_weighting") is None and use_rabc:
|
||||
sample_weighting: dict[str, Any] = {"type": "rabc"}
|
||||
if rabc_progress_path is not None:
|
||||
sample_weighting["progress_path"] = rabc_progress_path
|
||||
if rabc_kappa is not None:
|
||||
sample_weighting["kappa"] = rabc_kappa
|
||||
if rabc_epsilon is not None:
|
||||
sample_weighting["epsilon"] = rabc_epsilon
|
||||
if rabc_head_mode is not None:
|
||||
sample_weighting["head_mode"] = rabc_head_mode
|
||||
migrated_config["sample_weighting"] = sample_weighting
|
||||
|
||||
return migrated_config
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainPipelineConfig(HubMixin):
|
||||
dataset: DatasetConfig
|
||||
@@ -218,6 +256,15 @@ class TrainPipelineConfig(HubMixin):
|
||||
) from e
|
||||
|
||||
cli_args = kwargs.pop("cli_args", [])
|
||||
if config_file is not None:
|
||||
with open(config_file) as f:
|
||||
config = json.load(f)
|
||||
migrated_config = _migrate_legacy_rabc_fields(config)
|
||||
if migrated_config is not None:
|
||||
with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
|
||||
json.dump(migrated_config, f)
|
||||
config_file = f.name
|
||||
|
||||
with draccus.config_type("json"):
|
||||
return draccus.parse(cls, config_file, args=cli_args)
|
||||
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_sarm import SARMConfig
|
||||
from .modeling_sarm import SARMRewardModel
|
||||
from .processor_sarm import make_sarm_pre_post_processors
|
||||
|
||||
__all__ = ["SARMConfig", "SARMRewardModel", "make_sarm_pre_post_processors"]
|
||||
|
||||
@@ -27,7 +27,7 @@ from threading import Event
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs import FeatureType, PreTrainedConfig
|
||||
from lerobot.configs import FeatureType
|
||||
from lerobot.datasets import (
|
||||
LeRobotDataset,
|
||||
aggregate_pipeline_dataset_features,
|
||||
@@ -178,33 +178,26 @@ def build_rollout_context(
|
||||
policy_config = cfg.policy
|
||||
policy_class = get_policy_class(policy_config.type)
|
||||
|
||||
full_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
|
||||
for attr in ("device", "use_amp"):
|
||||
if hasattr(cfg.policy, attr) and hasattr(full_config, attr):
|
||||
cli_val = getattr(cfg.policy, attr)
|
||||
if cli_val is not None:
|
||||
setattr(full_config, attr, cli_val)
|
||||
if hasattr(policy_config, "compile_model"):
|
||||
policy_config.compile_model = cfg.use_torch_compile
|
||||
|
||||
if hasattr(full_config, "compile_model"):
|
||||
full_config.compile_model = cfg.use_torch_compile
|
||||
|
||||
if full_config.type == "vqbet" and cfg.device == "mps":
|
||||
if policy_config.type == "vqbet" and cfg.device == "mps":
|
||||
raise NotImplementedError(
|
||||
"Current implementation of VQBeT does not support `mps` backend. "
|
||||
"Please use `cpu` or `cuda` backend."
|
||||
)
|
||||
|
||||
if full_config.use_peft:
|
||||
if policy_config.use_peft:
|
||||
from peft import PeftConfig, PeftModel
|
||||
|
||||
peft_path = cfg.policy.pretrained_path
|
||||
peft_path = policy_config.pretrained_path
|
||||
peft_config = PeftConfig.from_pretrained(peft_path)
|
||||
policy = policy_class.from_pretrained(
|
||||
pretrained_name_or_path=peft_config.base_model_name_or_path, config=full_config
|
||||
pretrained_name_or_path=peft_config.base_model_name_or_path, config=policy_config
|
||||
)
|
||||
policy = PeftModel.from_pretrained(policy, peft_path, config=peft_config)
|
||||
else:
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=full_config)
|
||||
policy = policy_class.from_pretrained(policy_config.pretrained_path, config=policy_config)
|
||||
|
||||
if is_rtc:
|
||||
policy.config.rtc_config = cfg.inference.rtc
|
||||
@@ -315,7 +308,9 @@ def build_rollout_context(
|
||||
# Validate visual features if no rename_map is active
|
||||
rename_map = cfg.rename_map
|
||||
if not rename_map:
|
||||
expected_visuals = {k for k, v in full_config.input_features.items() if v.type == FeatureType.VISUAL}
|
||||
expected_visuals = {
|
||||
k for k, v in policy_config.input_features.items() if v.type == FeatureType.VISUAL
|
||||
}
|
||||
provided_visuals = {
|
||||
f"observation.images.{k}" for k, v in robot.observation_features.items() if isinstance(v, tuple)
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
"""Tests for the reward model base classes and registry."""
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
@@ -251,6 +252,79 @@ def test_train_pipeline_config_trainable_config_returns_policy_when_set():
|
||||
assert cfg.trainable_config.device == "cpu"
|
||||
|
||||
|
||||
def test_train_pipeline_config_from_pretrained_migrates_legacy_rabc_fields(tmp_path):
|
||||
"""Legacy top-level RA-BC fields should be migrated into ``sample_weighting``."""
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TRAIN_CONFIG_NAME, TrainPipelineConfig
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
|
||||
cfg = TrainPipelineConfig(
|
||||
dataset=DatasetConfig(repo_id="user/repo"),
|
||||
policy=DiffusionConfig(device="cpu"),
|
||||
)
|
||||
cfg._save_pretrained(tmp_path)
|
||||
|
||||
config_path = tmp_path / TRAIN_CONFIG_NAME
|
||||
with open(config_path) as f:
|
||||
payload = json.load(f)
|
||||
|
||||
payload.pop("sample_weighting", None)
|
||||
payload.update(
|
||||
{
|
||||
"use_rabc": True,
|
||||
"rabc_progress_path": "hf://datasets/user/repo/sarm_progress.parquet",
|
||||
"rabc_kappa": 0.05,
|
||||
"rabc_epsilon": 1e-5,
|
||||
"rabc_head_mode": "dense",
|
||||
}
|
||||
)
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(payload, f)
|
||||
|
||||
loaded = TrainPipelineConfig.from_pretrained(tmp_path)
|
||||
|
||||
assert loaded.sample_weighting is not None
|
||||
assert loaded.sample_weighting.type == "rabc"
|
||||
assert loaded.sample_weighting.progress_path == "hf://datasets/user/repo/sarm_progress.parquet"
|
||||
assert loaded.sample_weighting.kappa == 0.05
|
||||
assert loaded.sample_weighting.epsilon == 1e-5
|
||||
assert loaded.sample_weighting.head_mode == "dense"
|
||||
|
||||
|
||||
def test_train_pipeline_config_from_pretrained_strips_legacy_rabc_when_disabled(tmp_path):
|
||||
"""Legacy RA-BC fields should be ignored when ``use_rabc`` was false."""
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TRAIN_CONFIG_NAME, TrainPipelineConfig
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
|
||||
cfg = TrainPipelineConfig(
|
||||
dataset=DatasetConfig(repo_id="user/repo"),
|
||||
policy=DiffusionConfig(device="cpu"),
|
||||
)
|
||||
cfg._save_pretrained(tmp_path)
|
||||
|
||||
config_path = tmp_path / TRAIN_CONFIG_NAME
|
||||
with open(config_path) as f:
|
||||
payload = json.load(f)
|
||||
|
||||
payload.pop("sample_weighting", None)
|
||||
payload.update(
|
||||
{
|
||||
"use_rabc": False,
|
||||
"rabc_progress_path": "hf://datasets/user/repo/sarm_progress.parquet",
|
||||
"rabc_kappa": 0.05,
|
||||
"rabc_epsilon": 1e-5,
|
||||
"rabc_head_mode": "dense",
|
||||
}
|
||||
)
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(payload, f)
|
||||
|
||||
loaded = TrainPipelineConfig.from_pretrained(tmp_path)
|
||||
|
||||
assert loaded.sample_weighting is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PreTrainedRewardModel hub upload: push_model_to_hub + generate_model_card.
|
||||
# We test the generation side (offline) fully, and the upload side with HfApi
|
||||
|
||||
Reference in New Issue
Block a user