fix(train): migrate legacy RA-BC fields in train config loading (#3480)

This commit is contained in:
Khalil Meftah
2026-04-29 16:17:00 +02:00
committed by GitHub
parent 2236bbe7a3
commit cd6b43ea7a
2 changed files with 121 additions and 0 deletions
+74
View File
@@ -14,6 +14,7 @@
"""Tests for the reward model base classes and registry."""
import json
from dataclasses import dataclass
from pathlib import Path
from types import SimpleNamespace
@@ -251,6 +252,79 @@ def test_train_pipeline_config_trainable_config_returns_policy_when_set():
assert cfg.trainable_config.device == "cpu"
def test_train_pipeline_config_from_pretrained_migrates_legacy_rabc_fields(tmp_path):
    """Legacy top-level RA-BC fields should be migrated into ``sample_weighting``."""
    from lerobot.configs.default import DatasetConfig
    from lerobot.configs.train import TRAIN_CONFIG_NAME, TrainPipelineConfig
    from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig

    # Persist a fresh config, then rewrite it on disk to mimic a legacy payload
    # that carries top-level RA-BC keys instead of a ``sample_weighting`` section.
    config = TrainPipelineConfig(
        dataset=DatasetConfig(repo_id="user/repo"),
        policy=DiffusionConfig(device="cpu"),
    )
    config._save_pretrained(tmp_path)

    config_file = tmp_path / TRAIN_CONFIG_NAME
    payload = json.loads(config_file.read_text())
    payload.pop("sample_weighting", None)
    legacy_fields = {
        "use_rabc": True,
        "rabc_progress_path": "hf://datasets/user/repo/sarm_progress.parquet",
        "rabc_kappa": 0.05,
        "rabc_epsilon": 1e-5,
        "rabc_head_mode": "dense",
    }
    payload.update(legacy_fields)
    config_file.write_text(json.dumps(payload))

    loaded = TrainPipelineConfig.from_pretrained(tmp_path)

    # The loader should have folded the legacy keys into ``sample_weighting``.
    weighting = loaded.sample_weighting
    assert weighting is not None
    assert weighting.type == "rabc"
    assert weighting.progress_path == "hf://datasets/user/repo/sarm_progress.parquet"
    assert weighting.kappa == 0.05
    assert weighting.epsilon == 1e-5
    assert weighting.head_mode == "dense"
def test_train_pipeline_config_from_pretrained_strips_legacy_rabc_when_disabled(tmp_path):
    """Legacy RA-BC fields should be ignored when ``use_rabc`` was false."""
    from lerobot.configs.default import DatasetConfig
    from lerobot.configs.train import TRAIN_CONFIG_NAME, TrainPipelineConfig
    from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig

    # Persist a fresh config, then rewrite it on disk as a legacy payload where
    # RA-BC was explicitly disabled — its companion keys must then be dropped.
    config = TrainPipelineConfig(
        dataset=DatasetConfig(repo_id="user/repo"),
        policy=DiffusionConfig(device="cpu"),
    )
    config._save_pretrained(tmp_path)

    config_file = tmp_path / TRAIN_CONFIG_NAME
    payload = json.loads(config_file.read_text())
    payload.pop("sample_weighting", None)
    legacy_fields = {
        "use_rabc": False,
        "rabc_progress_path": "hf://datasets/user/repo/sarm_progress.parquet",
        "rabc_kappa": 0.05,
        "rabc_epsilon": 1e-5,
        "rabc_head_mode": "dense",
    }
    payload.update(legacy_fields)
    config_file.write_text(json.dumps(payload))

    loaded = TrainPipelineConfig.from_pretrained(tmp_path)

    # Disabled RA-BC must not materialize a ``sample_weighting`` config.
    assert loaded.sample_weighting is None
# ---------------------------------------------------------------------------
# PreTrainedRewardModel hub upload: push_model_to_hub + generate_model_card.
# We test the generation side (offline) fully, and the upload side with HfApi