mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 06:29:47 +00:00
Compare commits
3 Commits
cb0a944941
...
b5f65e5332
| Author | SHA1 | Date | |
|---|---|---|---|
| b5f65e5332 | |||
| cd6b43ea7a | |||
| 2236bbe7a3 |
@@ -1,3 +1,4 @@
|
|||||||
include src/lerobot/templates/lerobot_modelcard_template.md
|
include src/lerobot/templates/lerobot_modelcard_template.md
|
||||||
|
include src/lerobot/templates/lerobot_rewardmodel_modelcard_template.md
|
||||||
include src/lerobot/datasets/card_template.md
|
include src/lerobot/datasets/card_template.md
|
||||||
include src/lerobot/envs/metaworld_config.json
|
include src/lerobot/envs/metaworld_config.json
|
||||||
|
|||||||
@@ -13,7 +13,9 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import builtins
|
import builtins
|
||||||
import datetime as dt
|
import datetime as dt
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -35,6 +37,42 @@ from .rewards import RewardModelConfig
|
|||||||
TRAIN_CONFIG_NAME = "train_config.json"
|
TRAIN_CONFIG_NAME = "train_config.json"
|
||||||
|
|
||||||
|
|
||||||
|
def _migrate_legacy_rabc_fields(config: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
"""Return migrated payload for legacy RA-BC fields, or None when no migration is needed."""
|
||||||
|
legacy_fields = (
|
||||||
|
"use_rabc",
|
||||||
|
"rabc_progress_path",
|
||||||
|
"rabc_kappa",
|
||||||
|
"rabc_epsilon",
|
||||||
|
"rabc_head_mode",
|
||||||
|
)
|
||||||
|
if not any(key in config for key in legacy_fields):
|
||||||
|
return None
|
||||||
|
|
||||||
|
migrated_config = dict(config)
|
||||||
|
use_rabc = bool(migrated_config.pop("use_rabc", False))
|
||||||
|
rabc_progress_path = migrated_config.pop("rabc_progress_path", None)
|
||||||
|
rabc_kappa = migrated_config.pop("rabc_kappa", None)
|
||||||
|
rabc_epsilon = migrated_config.pop("rabc_epsilon", None)
|
||||||
|
rabc_head_mode = migrated_config.pop("rabc_head_mode", None)
|
||||||
|
|
||||||
|
# New configs may already define sample_weighting explicitly. In that case,
|
||||||
|
# legacy fields are ignored after being stripped from the payload.
|
||||||
|
if migrated_config.get("sample_weighting") is None and use_rabc:
|
||||||
|
sample_weighting: dict[str, Any] = {"type": "rabc"}
|
||||||
|
if rabc_progress_path is not None:
|
||||||
|
sample_weighting["progress_path"] = rabc_progress_path
|
||||||
|
if rabc_kappa is not None:
|
||||||
|
sample_weighting["kappa"] = rabc_kappa
|
||||||
|
if rabc_epsilon is not None:
|
||||||
|
sample_weighting["epsilon"] = rabc_epsilon
|
||||||
|
if rabc_head_mode is not None:
|
||||||
|
sample_weighting["head_mode"] = rabc_head_mode
|
||||||
|
migrated_config["sample_weighting"] = sample_weighting
|
||||||
|
|
||||||
|
return migrated_config
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainPipelineConfig(HubMixin):
|
class TrainPipelineConfig(HubMixin):
|
||||||
dataset: DatasetConfig
|
dataset: DatasetConfig
|
||||||
@@ -218,6 +256,15 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
) from e
|
) from e
|
||||||
|
|
||||||
cli_args = kwargs.pop("cli_args", [])
|
cli_args = kwargs.pop("cli_args", [])
|
||||||
|
if config_file is not None:
|
||||||
|
with open(config_file) as f:
|
||||||
|
config = json.load(f)
|
||||||
|
migrated_config = _migrate_legacy_rabc_fields(config)
|
||||||
|
if migrated_config is not None:
|
||||||
|
with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
|
||||||
|
json.dump(migrated_config, f)
|
||||||
|
config_file = f.name
|
||||||
|
|
||||||
with draccus.config_type("json"):
|
with draccus.config_type("json"):
|
||||||
return draccus.parse(cls, config_file, args=cli_args)
|
return draccus.parse(cls, config_file, args=cli_args)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,19 @@
|
|||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .configuration_sarm import SARMConfig
|
||||||
|
from .modeling_sarm import SARMRewardModel
|
||||||
|
from .processor_sarm import make_sarm_pre_post_processors
|
||||||
|
|
||||||
|
__all__ = ["SARMConfig", "SARMRewardModel", "make_sarm_pre_post_processors"]
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ from threading import Event
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.configs import FeatureType, PreTrainedConfig
|
from lerobot.configs import FeatureType
|
||||||
from lerobot.datasets import (
|
from lerobot.datasets import (
|
||||||
LeRobotDataset,
|
LeRobotDataset,
|
||||||
aggregate_pipeline_dataset_features,
|
aggregate_pipeline_dataset_features,
|
||||||
@@ -178,33 +178,26 @@ def build_rollout_context(
|
|||||||
policy_config = cfg.policy
|
policy_config = cfg.policy
|
||||||
policy_class = get_policy_class(policy_config.type)
|
policy_class = get_policy_class(policy_config.type)
|
||||||
|
|
||||||
full_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
|
if hasattr(policy_config, "compile_model"):
|
||||||
for attr in ("device", "use_amp"):
|
policy_config.compile_model = cfg.use_torch_compile
|
||||||
if hasattr(cfg.policy, attr) and hasattr(full_config, attr):
|
|
||||||
cli_val = getattr(cfg.policy, attr)
|
|
||||||
if cli_val is not None:
|
|
||||||
setattr(full_config, attr, cli_val)
|
|
||||||
|
|
||||||
if hasattr(full_config, "compile_model"):
|
if policy_config.type == "vqbet" and cfg.device == "mps":
|
||||||
full_config.compile_model = cfg.use_torch_compile
|
|
||||||
|
|
||||||
if full_config.type == "vqbet" and cfg.device == "mps":
|
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Current implementation of VQBeT does not support `mps` backend. "
|
"Current implementation of VQBeT does not support `mps` backend. "
|
||||||
"Please use `cpu` or `cuda` backend."
|
"Please use `cpu` or `cuda` backend."
|
||||||
)
|
)
|
||||||
|
|
||||||
if full_config.use_peft:
|
if policy_config.use_peft:
|
||||||
from peft import PeftConfig, PeftModel
|
from peft import PeftConfig, PeftModel
|
||||||
|
|
||||||
peft_path = cfg.policy.pretrained_path
|
peft_path = policy_config.pretrained_path
|
||||||
peft_config = PeftConfig.from_pretrained(peft_path)
|
peft_config = PeftConfig.from_pretrained(peft_path)
|
||||||
policy = policy_class.from_pretrained(
|
policy = policy_class.from_pretrained(
|
||||||
pretrained_name_or_path=peft_config.base_model_name_or_path, config=full_config
|
pretrained_name_or_path=peft_config.base_model_name_or_path, config=policy_config
|
||||||
)
|
)
|
||||||
policy = PeftModel.from_pretrained(policy, peft_path, config=peft_config)
|
policy = PeftModel.from_pretrained(policy, peft_path, config=peft_config)
|
||||||
else:
|
else:
|
||||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=full_config)
|
policy = policy_class.from_pretrained(policy_config.pretrained_path, config=policy_config)
|
||||||
|
|
||||||
if is_rtc:
|
if is_rtc:
|
||||||
policy.config.rtc_config = cfg.inference.rtc
|
policy.config.rtc_config = cfg.inference.rtc
|
||||||
@@ -315,7 +308,9 @@ def build_rollout_context(
|
|||||||
# Validate visual features if no rename_map is active
|
# Validate visual features if no rename_map is active
|
||||||
rename_map = cfg.rename_map
|
rename_map = cfg.rename_map
|
||||||
if not rename_map:
|
if not rename_map:
|
||||||
expected_visuals = {k for k, v in full_config.input_features.items() if v.type == FeatureType.VISUAL}
|
expected_visuals = {
|
||||||
|
k for k, v in policy_config.input_features.items() if v.type == FeatureType.VISUAL
|
||||||
|
}
|
||||||
provided_visuals = {
|
provided_visuals = {
|
||||||
f"observation.images.{k}" for k, v in robot.observation_features.items() if isinstance(v, tuple)
|
f"observation.images.{k}" for k, v in robot.observation_features.items() if isinstance(v, tuple)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
|
|
||||||
"""Tests for the reward model base classes and registry."""
|
"""Tests for the reward model base classes and registry."""
|
||||||
|
|
||||||
|
import json
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
@@ -251,6 +252,79 @@ def test_train_pipeline_config_trainable_config_returns_policy_when_set():
|
|||||||
assert cfg.trainable_config.device == "cpu"
|
assert cfg.trainable_config.device == "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
def test_train_pipeline_config_from_pretrained_migrates_legacy_rabc_fields(tmp_path):
|
||||||
|
"""Legacy top-level RA-BC fields should be migrated into ``sample_weighting``."""
|
||||||
|
from lerobot.configs.default import DatasetConfig
|
||||||
|
from lerobot.configs.train import TRAIN_CONFIG_NAME, TrainPipelineConfig
|
||||||
|
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||||
|
|
||||||
|
cfg = TrainPipelineConfig(
|
||||||
|
dataset=DatasetConfig(repo_id="user/repo"),
|
||||||
|
policy=DiffusionConfig(device="cpu"),
|
||||||
|
)
|
||||||
|
cfg._save_pretrained(tmp_path)
|
||||||
|
|
||||||
|
config_path = tmp_path / TRAIN_CONFIG_NAME
|
||||||
|
with open(config_path) as f:
|
||||||
|
payload = json.load(f)
|
||||||
|
|
||||||
|
payload.pop("sample_weighting", None)
|
||||||
|
payload.update(
|
||||||
|
{
|
||||||
|
"use_rabc": True,
|
||||||
|
"rabc_progress_path": "hf://datasets/user/repo/sarm_progress.parquet",
|
||||||
|
"rabc_kappa": 0.05,
|
||||||
|
"rabc_epsilon": 1e-5,
|
||||||
|
"rabc_head_mode": "dense",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
with open(config_path, "w") as f:
|
||||||
|
json.dump(payload, f)
|
||||||
|
|
||||||
|
loaded = TrainPipelineConfig.from_pretrained(tmp_path)
|
||||||
|
|
||||||
|
assert loaded.sample_weighting is not None
|
||||||
|
assert loaded.sample_weighting.type == "rabc"
|
||||||
|
assert loaded.sample_weighting.progress_path == "hf://datasets/user/repo/sarm_progress.parquet"
|
||||||
|
assert loaded.sample_weighting.kappa == 0.05
|
||||||
|
assert loaded.sample_weighting.epsilon == 1e-5
|
||||||
|
assert loaded.sample_weighting.head_mode == "dense"
|
||||||
|
|
||||||
|
|
||||||
|
def test_train_pipeline_config_from_pretrained_strips_legacy_rabc_when_disabled(tmp_path):
|
||||||
|
"""Legacy RA-BC fields should be ignored when ``use_rabc`` was false."""
|
||||||
|
from lerobot.configs.default import DatasetConfig
|
||||||
|
from lerobot.configs.train import TRAIN_CONFIG_NAME, TrainPipelineConfig
|
||||||
|
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||||
|
|
||||||
|
cfg = TrainPipelineConfig(
|
||||||
|
dataset=DatasetConfig(repo_id="user/repo"),
|
||||||
|
policy=DiffusionConfig(device="cpu"),
|
||||||
|
)
|
||||||
|
cfg._save_pretrained(tmp_path)
|
||||||
|
|
||||||
|
config_path = tmp_path / TRAIN_CONFIG_NAME
|
||||||
|
with open(config_path) as f:
|
||||||
|
payload = json.load(f)
|
||||||
|
|
||||||
|
payload.pop("sample_weighting", None)
|
||||||
|
payload.update(
|
||||||
|
{
|
||||||
|
"use_rabc": False,
|
||||||
|
"rabc_progress_path": "hf://datasets/user/repo/sarm_progress.parquet",
|
||||||
|
"rabc_kappa": 0.05,
|
||||||
|
"rabc_epsilon": 1e-5,
|
||||||
|
"rabc_head_mode": "dense",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
with open(config_path, "w") as f:
|
||||||
|
json.dump(payload, f)
|
||||||
|
|
||||||
|
loaded = TrainPipelineConfig.from_pretrained(tmp_path)
|
||||||
|
|
||||||
|
assert loaded.sample_weighting is None
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# PreTrainedRewardModel hub upload: push_model_to_hub + generate_model_card.
|
# PreTrainedRewardModel hub upload: push_model_to_hub + generate_model_card.
|
||||||
# We test the generation side (offline) fully, and the upload side with HfApi
|
# We test the generation side (offline) fully, and the upload side with HfApi
|
||||||
|
|||||||
Reference in New Issue
Block a user