fix(train): migrate legacy RA-BC fields in train config loading (#3480)

This commit is contained in:
Khalil Meftah
2026-04-29 16:17:00 +02:00
committed by GitHub
parent 2236bbe7a3
commit cd6b43ea7a
2 changed files with 121 additions and 0 deletions
+47
View File
@@ -13,7 +13,9 @@
# limitations under the License. # limitations under the License.
import builtins import builtins
import datetime as dt import datetime as dt
import json
import os import os
import tempfile
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@@ -35,6 +37,42 @@ from .rewards import RewardModelConfig
TRAIN_CONFIG_NAME = "train_config.json" TRAIN_CONFIG_NAME = "train_config.json"
def _migrate_legacy_rabc_fields(config: dict[str, Any]) -> dict[str, Any] | None:
"""Return migrated payload for legacy RA-BC fields, or None when no migration is needed."""
legacy_fields = (
"use_rabc",
"rabc_progress_path",
"rabc_kappa",
"rabc_epsilon",
"rabc_head_mode",
)
if not any(key in config for key in legacy_fields):
return None
migrated_config = dict(config)
use_rabc = bool(migrated_config.pop("use_rabc", False))
rabc_progress_path = migrated_config.pop("rabc_progress_path", None)
rabc_kappa = migrated_config.pop("rabc_kappa", None)
rabc_epsilon = migrated_config.pop("rabc_epsilon", None)
rabc_head_mode = migrated_config.pop("rabc_head_mode", None)
# New configs may already define sample_weighting explicitly. In that case,
# legacy fields are ignored after being stripped from the payload.
if migrated_config.get("sample_weighting") is None and use_rabc:
sample_weighting: dict[str, Any] = {"type": "rabc"}
if rabc_progress_path is not None:
sample_weighting["progress_path"] = rabc_progress_path
if rabc_kappa is not None:
sample_weighting["kappa"] = rabc_kappa
if rabc_epsilon is not None:
sample_weighting["epsilon"] = rabc_epsilon
if rabc_head_mode is not None:
sample_weighting["head_mode"] = rabc_head_mode
migrated_config["sample_weighting"] = sample_weighting
return migrated_config
@dataclass @dataclass
class TrainPipelineConfig(HubMixin): class TrainPipelineConfig(HubMixin):
dataset: DatasetConfig dataset: DatasetConfig
@@ -218,6 +256,15 @@ class TrainPipelineConfig(HubMixin):
) from e ) from e
cli_args = kwargs.pop("cli_args", []) cli_args = kwargs.pop("cli_args", [])
if config_file is not None:
with open(config_file) as f:
config = json.load(f)
migrated_config = _migrate_legacy_rabc_fields(config)
if migrated_config is not None:
with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
json.dump(migrated_config, f)
config_file = f.name
with draccus.config_type("json"): with draccus.config_type("json"):
return draccus.parse(cls, config_file, args=cli_args) return draccus.parse(cls, config_file, args=cli_args)
+74
View File
@@ -14,6 +14,7 @@
"""Tests for the reward model base classes and registry.""" """Tests for the reward model base classes and registry."""
import json
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from types import SimpleNamespace from types import SimpleNamespace
@@ -251,6 +252,79 @@ def test_train_pipeline_config_trainable_config_returns_policy_when_set():
assert cfg.trainable_config.device == "cpu" assert cfg.trainable_config.device == "cpu"
def test_train_pipeline_config_from_pretrained_migrates_legacy_rabc_fields(tmp_path):
"""Legacy top-level RA-BC fields should be migrated into ``sample_weighting``."""
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TRAIN_CONFIG_NAME, TrainPipelineConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
cfg = TrainPipelineConfig(
dataset=DatasetConfig(repo_id="user/repo"),
policy=DiffusionConfig(device="cpu"),
)
cfg._save_pretrained(tmp_path)
config_path = tmp_path / TRAIN_CONFIG_NAME
with open(config_path) as f:
payload = json.load(f)
payload.pop("sample_weighting", None)
payload.update(
{
"use_rabc": True,
"rabc_progress_path": "hf://datasets/user/repo/sarm_progress.parquet",
"rabc_kappa": 0.05,
"rabc_epsilon": 1e-5,
"rabc_head_mode": "dense",
}
)
with open(config_path, "w") as f:
json.dump(payload, f)
loaded = TrainPipelineConfig.from_pretrained(tmp_path)
assert loaded.sample_weighting is not None
assert loaded.sample_weighting.type == "rabc"
assert loaded.sample_weighting.progress_path == "hf://datasets/user/repo/sarm_progress.parquet"
assert loaded.sample_weighting.kappa == 0.05
assert loaded.sample_weighting.epsilon == 1e-5
assert loaded.sample_weighting.head_mode == "dense"
def test_train_pipeline_config_from_pretrained_strips_legacy_rabc_when_disabled(tmp_path):
"""Legacy RA-BC fields should be ignored when ``use_rabc`` was false."""
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TRAIN_CONFIG_NAME, TrainPipelineConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
cfg = TrainPipelineConfig(
dataset=DatasetConfig(repo_id="user/repo"),
policy=DiffusionConfig(device="cpu"),
)
cfg._save_pretrained(tmp_path)
config_path = tmp_path / TRAIN_CONFIG_NAME
with open(config_path) as f:
payload = json.load(f)
payload.pop("sample_weighting", None)
payload.update(
{
"use_rabc": False,
"rabc_progress_path": "hf://datasets/user/repo/sarm_progress.parquet",
"rabc_kappa": 0.05,
"rabc_epsilon": 1e-5,
"rabc_head_mode": "dense",
}
)
with open(config_path, "w") as f:
json.dump(payload, f)
loaded = TrainPipelineConfig.from_pretrained(tmp_path)
assert loaded.sample_weighting is None
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# PreTrainedRewardModel hub upload: push_model_to_hub + generate_model_card. # PreTrainedRewardModel hub upload: push_model_to_hub + generate_model_card.
# We test the generation side (offline) fully, and the upload side with HfApi # We test the generation side (offline) fully, and the upload side with HfApi