mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
fix(train): migrate legacy RA-BC fields in train config loading (#3480)
This commit is contained in:
@@ -13,7 +13,9 @@
|
||||
# limitations under the License.
|
||||
import builtins
|
||||
import datetime as dt
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -35,6 +37,42 @@ from .rewards import RewardModelConfig
|
||||
# File name of the JSON-serialized training config inside a saved config directory.
TRAIN_CONFIG_NAME = "train_config.json"
|
||||
|
||||
|
||||
def _migrate_legacy_rabc_fields(config: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Return migrated payload for legacy RA-BC fields, or None when no migration is needed."""
|
||||
legacy_fields = (
|
||||
"use_rabc",
|
||||
"rabc_progress_path",
|
||||
"rabc_kappa",
|
||||
"rabc_epsilon",
|
||||
"rabc_head_mode",
|
||||
)
|
||||
if not any(key in config for key in legacy_fields):
|
||||
return None
|
||||
|
||||
migrated_config = dict(config)
|
||||
use_rabc = bool(migrated_config.pop("use_rabc", False))
|
||||
rabc_progress_path = migrated_config.pop("rabc_progress_path", None)
|
||||
rabc_kappa = migrated_config.pop("rabc_kappa", None)
|
||||
rabc_epsilon = migrated_config.pop("rabc_epsilon", None)
|
||||
rabc_head_mode = migrated_config.pop("rabc_head_mode", None)
|
||||
|
||||
# New configs may already define sample_weighting explicitly. In that case,
|
||||
# legacy fields are ignored after being stripped from the payload.
|
||||
if migrated_config.get("sample_weighting") is None and use_rabc:
|
||||
sample_weighting: dict[str, Any] = {"type": "rabc"}
|
||||
if rabc_progress_path is not None:
|
||||
sample_weighting["progress_path"] = rabc_progress_path
|
||||
if rabc_kappa is not None:
|
||||
sample_weighting["kappa"] = rabc_kappa
|
||||
if rabc_epsilon is not None:
|
||||
sample_weighting["epsilon"] = rabc_epsilon
|
||||
if rabc_head_mode is not None:
|
||||
sample_weighting["head_mode"] = rabc_head_mode
|
||||
migrated_config["sample_weighting"] = sample_weighting
|
||||
|
||||
return migrated_config
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainPipelineConfig(HubMixin):
|
||||
dataset: DatasetConfig
|
||||
@@ -218,6 +256,15 @@ class TrainPipelineConfig(HubMixin):
|
||||
) from e
|
||||
|
||||
cli_args = kwargs.pop("cli_args", [])
|
||||
if config_file is not None:
|
||||
with open(config_file) as f:
|
||||
config = json.load(f)
|
||||
migrated_config = _migrate_legacy_rabc_fields(config)
|
||||
if migrated_config is not None:
|
||||
with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
|
||||
json.dump(migrated_config, f)
|
||||
config_file = f.name
|
||||
|
||||
with draccus.config_type("json"):
|
||||
return draccus.parse(cls, config_file, args=cli_args)
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
"""Tests for the reward model base classes and registry."""
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
@@ -251,6 +252,79 @@ def test_train_pipeline_config_trainable_config_returns_policy_when_set():
|
||||
assert cfg.trainable_config.device == "cpu"
|
||||
|
||||
|
||||
def test_train_pipeline_config_from_pretrained_migrates_legacy_rabc_fields(tmp_path):
    """An enabled legacy RA-BC setup must load as an equivalent ``sample_weighting``."""
    from lerobot.configs.default import DatasetConfig
    from lerobot.configs.train import TRAIN_CONFIG_NAME, TrainPipelineConfig
    from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig

    # Serialize a current-format config, then rewrite it on disk so it looks
    # like one saved before ``sample_weighting`` replaced the top-level fields.
    TrainPipelineConfig(
        dataset=DatasetConfig(repo_id="user/repo"),
        policy=DiffusionConfig(device="cpu"),
    )._save_pretrained(tmp_path)

    saved_file = tmp_path / TRAIN_CONFIG_NAME
    with open(saved_file) as reader:
        legacy_payload = json.load(reader)

    legacy_payload.pop("sample_weighting", None)
    legacy_payload["use_rabc"] = True
    legacy_payload["rabc_progress_path"] = "hf://datasets/user/repo/sarm_progress.parquet"
    legacy_payload["rabc_kappa"] = 0.05
    legacy_payload["rabc_epsilon"] = 1e-5
    legacy_payload["rabc_head_mode"] = "dense"

    with open(saved_file, "w") as writer:
        json.dump(legacy_payload, writer)

    weighting = TrainPipelineConfig.from_pretrained(tmp_path).sample_weighting

    assert weighting is not None
    assert weighting.type == "rabc"
    assert weighting.progress_path == "hf://datasets/user/repo/sarm_progress.parquet"
    assert weighting.kappa == 0.05
    assert weighting.epsilon == 1e-5
    assert weighting.head_mode == "dense"
|
||||
|
||||
|
||||
def test_train_pipeline_config_from_pretrained_strips_legacy_rabc_when_disabled(tmp_path):
    """With ``use_rabc`` false, legacy RA-BC fields are dropped and nothing is migrated."""
    from lerobot.configs.default import DatasetConfig
    from lerobot.configs.train import TRAIN_CONFIG_NAME, TrainPipelineConfig
    from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig

    # Serialize a current-format config, then rewrite it on disk with the
    # legacy top-level fields present but RA-BC switched off.
    TrainPipelineConfig(
        dataset=DatasetConfig(repo_id="user/repo"),
        policy=DiffusionConfig(device="cpu"),
    )._save_pretrained(tmp_path)

    saved_file = tmp_path / TRAIN_CONFIG_NAME
    with open(saved_file) as reader:
        legacy_payload = json.load(reader)

    legacy_payload.pop("sample_weighting", None)
    legacy_payload["use_rabc"] = False
    legacy_payload["rabc_progress_path"] = "hf://datasets/user/repo/sarm_progress.parquet"
    legacy_payload["rabc_kappa"] = 0.05
    legacy_payload["rabc_epsilon"] = 1e-5
    legacy_payload["rabc_head_mode"] = "dense"

    with open(saved_file, "w") as writer:
        json.dump(legacy_payload, writer)

    reloaded = TrainPipelineConfig.from_pretrained(tmp_path)

    assert reloaded.sample_weighting is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PreTrainedRewardModel hub upload: push_model_to_hub + generate_model_card.
|
||||
# We test the generation side (offline) fully, and the upload side with HfApi
|
||||
|
||||
Reference in New Issue
Block a user