mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
fix(train): migrate legacy RA-BC fields in train config loading (#3480)
This commit is contained in:
@@ -13,7 +13,9 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import builtins
|
import builtins
|
||||||
import datetime as dt
|
import datetime as dt
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -35,6 +37,42 @@ from .rewards import RewardModelConfig
|
|||||||
TRAIN_CONFIG_NAME = "train_config.json"
|
TRAIN_CONFIG_NAME = "train_config.json"
|
||||||
|
|
||||||
|
|
||||||
|
def _migrate_legacy_rabc_fields(config: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
"""Return migrated payload for legacy RA-BC fields, or None when no migration is needed."""
|
||||||
|
legacy_fields = (
|
||||||
|
"use_rabc",
|
||||||
|
"rabc_progress_path",
|
||||||
|
"rabc_kappa",
|
||||||
|
"rabc_epsilon",
|
||||||
|
"rabc_head_mode",
|
||||||
|
)
|
||||||
|
if not any(key in config for key in legacy_fields):
|
||||||
|
return None
|
||||||
|
|
||||||
|
migrated_config = dict(config)
|
||||||
|
use_rabc = bool(migrated_config.pop("use_rabc", False))
|
||||||
|
rabc_progress_path = migrated_config.pop("rabc_progress_path", None)
|
||||||
|
rabc_kappa = migrated_config.pop("rabc_kappa", None)
|
||||||
|
rabc_epsilon = migrated_config.pop("rabc_epsilon", None)
|
||||||
|
rabc_head_mode = migrated_config.pop("rabc_head_mode", None)
|
||||||
|
|
||||||
|
# New configs may already define sample_weighting explicitly. In that case,
|
||||||
|
# legacy fields are ignored after being stripped from the payload.
|
||||||
|
if migrated_config.get("sample_weighting") is None and use_rabc:
|
||||||
|
sample_weighting: dict[str, Any] = {"type": "rabc"}
|
||||||
|
if rabc_progress_path is not None:
|
||||||
|
sample_weighting["progress_path"] = rabc_progress_path
|
||||||
|
if rabc_kappa is not None:
|
||||||
|
sample_weighting["kappa"] = rabc_kappa
|
||||||
|
if rabc_epsilon is not None:
|
||||||
|
sample_weighting["epsilon"] = rabc_epsilon
|
||||||
|
if rabc_head_mode is not None:
|
||||||
|
sample_weighting["head_mode"] = rabc_head_mode
|
||||||
|
migrated_config["sample_weighting"] = sample_weighting
|
||||||
|
|
||||||
|
return migrated_config
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainPipelineConfig(HubMixin):
|
class TrainPipelineConfig(HubMixin):
|
||||||
dataset: DatasetConfig
|
dataset: DatasetConfig
|
||||||
@@ -218,6 +256,15 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
) from e
|
) from e
|
||||||
|
|
||||||
cli_args = kwargs.pop("cli_args", [])
|
cli_args = kwargs.pop("cli_args", [])
|
||||||
|
if config_file is not None:
|
||||||
|
with open(config_file) as f:
|
||||||
|
config = json.load(f)
|
||||||
|
migrated_config = _migrate_legacy_rabc_fields(config)
|
||||||
|
if migrated_config is not None:
|
||||||
|
with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
|
||||||
|
json.dump(migrated_config, f)
|
||||||
|
config_file = f.name
|
||||||
|
|
||||||
with draccus.config_type("json"):
|
with draccus.config_type("json"):
|
||||||
return draccus.parse(cls, config_file, args=cli_args)
|
return draccus.parse(cls, config_file, args=cli_args)
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
|
|
||||||
"""Tests for the reward model base classes and registry."""
|
"""Tests for the reward model base classes and registry."""
|
||||||
|
|
||||||
|
import json
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
@@ -251,6 +252,79 @@ def test_train_pipeline_config_trainable_config_returns_policy_when_set():
|
|||||||
assert cfg.trainable_config.device == "cpu"
|
assert cfg.trainable_config.device == "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
def test_train_pipeline_config_from_pretrained_migrates_legacy_rabc_fields(tmp_path):
|
||||||
|
"""Legacy top-level RA-BC fields should be migrated into ``sample_weighting``."""
|
||||||
|
from lerobot.configs.default import DatasetConfig
|
||||||
|
from lerobot.configs.train import TRAIN_CONFIG_NAME, TrainPipelineConfig
|
||||||
|
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||||
|
|
||||||
|
cfg = TrainPipelineConfig(
|
||||||
|
dataset=DatasetConfig(repo_id="user/repo"),
|
||||||
|
policy=DiffusionConfig(device="cpu"),
|
||||||
|
)
|
||||||
|
cfg._save_pretrained(tmp_path)
|
||||||
|
|
||||||
|
config_path = tmp_path / TRAIN_CONFIG_NAME
|
||||||
|
with open(config_path) as f:
|
||||||
|
payload = json.load(f)
|
||||||
|
|
||||||
|
payload.pop("sample_weighting", None)
|
||||||
|
payload.update(
|
||||||
|
{
|
||||||
|
"use_rabc": True,
|
||||||
|
"rabc_progress_path": "hf://datasets/user/repo/sarm_progress.parquet",
|
||||||
|
"rabc_kappa": 0.05,
|
||||||
|
"rabc_epsilon": 1e-5,
|
||||||
|
"rabc_head_mode": "dense",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
with open(config_path, "w") as f:
|
||||||
|
json.dump(payload, f)
|
||||||
|
|
||||||
|
loaded = TrainPipelineConfig.from_pretrained(tmp_path)
|
||||||
|
|
||||||
|
assert loaded.sample_weighting is not None
|
||||||
|
assert loaded.sample_weighting.type == "rabc"
|
||||||
|
assert loaded.sample_weighting.progress_path == "hf://datasets/user/repo/sarm_progress.parquet"
|
||||||
|
assert loaded.sample_weighting.kappa == 0.05
|
||||||
|
assert loaded.sample_weighting.epsilon == 1e-5
|
||||||
|
assert loaded.sample_weighting.head_mode == "dense"
|
||||||
|
|
||||||
|
|
||||||
|
def test_train_pipeline_config_from_pretrained_strips_legacy_rabc_when_disabled(tmp_path):
|
||||||
|
"""Legacy RA-BC fields should be ignored when ``use_rabc`` was false."""
|
||||||
|
from lerobot.configs.default import DatasetConfig
|
||||||
|
from lerobot.configs.train import TRAIN_CONFIG_NAME, TrainPipelineConfig
|
||||||
|
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||||
|
|
||||||
|
cfg = TrainPipelineConfig(
|
||||||
|
dataset=DatasetConfig(repo_id="user/repo"),
|
||||||
|
policy=DiffusionConfig(device="cpu"),
|
||||||
|
)
|
||||||
|
cfg._save_pretrained(tmp_path)
|
||||||
|
|
||||||
|
config_path = tmp_path / TRAIN_CONFIG_NAME
|
||||||
|
with open(config_path) as f:
|
||||||
|
payload = json.load(f)
|
||||||
|
|
||||||
|
payload.pop("sample_weighting", None)
|
||||||
|
payload.update(
|
||||||
|
{
|
||||||
|
"use_rabc": False,
|
||||||
|
"rabc_progress_path": "hf://datasets/user/repo/sarm_progress.parquet",
|
||||||
|
"rabc_kappa": 0.05,
|
||||||
|
"rabc_epsilon": 1e-5,
|
||||||
|
"rabc_head_mode": "dense",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
with open(config_path, "w") as f:
|
||||||
|
json.dump(payload, f)
|
||||||
|
|
||||||
|
loaded = TrainPipelineConfig.from_pretrained(tmp_path)
|
||||||
|
|
||||||
|
assert loaded.sample_weighting is None
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# PreTrainedRewardModel hub upload: push_model_to_hub + generate_model_card.
|
# PreTrainedRewardModel hub upload: push_model_to_hub + generate_model_card.
|
||||||
# We test the generation side (offline) fully, and the upload side with HfApi
|
# We test the generation side (offline) fully, and the upload side with HfApi
|
||||||
|
|||||||
Reference in New Issue
Block a user