mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 00:29:52 +00:00
fix(train): migrate legacy RA-BC fields in train config loading (#3480)
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
|
||||
"""Tests for the reward model base classes and registry."""
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
@@ -251,6 +252,79 @@ def test_train_pipeline_config_trainable_config_returns_policy_when_set():
|
||||
assert cfg.trainable_config.device == "cpu"
|
||||
|
||||
|
||||
def test_train_pipeline_config_from_pretrained_migrates_legacy_rabc_fields(tmp_path):
|
||||
"""Legacy top-level RA-BC fields should be migrated into ``sample_weighting``."""
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TRAIN_CONFIG_NAME, TrainPipelineConfig
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
|
||||
cfg = TrainPipelineConfig(
|
||||
dataset=DatasetConfig(repo_id="user/repo"),
|
||||
policy=DiffusionConfig(device="cpu"),
|
||||
)
|
||||
cfg._save_pretrained(tmp_path)
|
||||
|
||||
config_path = tmp_path / TRAIN_CONFIG_NAME
|
||||
with open(config_path) as f:
|
||||
payload = json.load(f)
|
||||
|
||||
payload.pop("sample_weighting", None)
|
||||
payload.update(
|
||||
{
|
||||
"use_rabc": True,
|
||||
"rabc_progress_path": "hf://datasets/user/repo/sarm_progress.parquet",
|
||||
"rabc_kappa": 0.05,
|
||||
"rabc_epsilon": 1e-5,
|
||||
"rabc_head_mode": "dense",
|
||||
}
|
||||
)
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(payload, f)
|
||||
|
||||
loaded = TrainPipelineConfig.from_pretrained(tmp_path)
|
||||
|
||||
assert loaded.sample_weighting is not None
|
||||
assert loaded.sample_weighting.type == "rabc"
|
||||
assert loaded.sample_weighting.progress_path == "hf://datasets/user/repo/sarm_progress.parquet"
|
||||
assert loaded.sample_weighting.kappa == 0.05
|
||||
assert loaded.sample_weighting.epsilon == 1e-5
|
||||
assert loaded.sample_weighting.head_mode == "dense"
|
||||
|
||||
|
||||
def test_train_pipeline_config_from_pretrained_strips_legacy_rabc_when_disabled(tmp_path):
|
||||
"""Legacy RA-BC fields should be ignored when ``use_rabc`` was false."""
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TRAIN_CONFIG_NAME, TrainPipelineConfig
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
|
||||
cfg = TrainPipelineConfig(
|
||||
dataset=DatasetConfig(repo_id="user/repo"),
|
||||
policy=DiffusionConfig(device="cpu"),
|
||||
)
|
||||
cfg._save_pretrained(tmp_path)
|
||||
|
||||
config_path = tmp_path / TRAIN_CONFIG_NAME
|
||||
with open(config_path) as f:
|
||||
payload = json.load(f)
|
||||
|
||||
payload.pop("sample_weighting", None)
|
||||
payload.update(
|
||||
{
|
||||
"use_rabc": False,
|
||||
"rabc_progress_path": "hf://datasets/user/repo/sarm_progress.parquet",
|
||||
"rabc_kappa": 0.05,
|
||||
"rabc_epsilon": 1e-5,
|
||||
"rabc_head_mode": "dense",
|
||||
}
|
||||
)
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(payload, f)
|
||||
|
||||
loaded = TrainPipelineConfig.from_pretrained(tmp_path)
|
||||
|
||||
assert loaded.sample_weighting is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PreTrainedRewardModel hub upload: push_model_to_hub + generate_model_card.
|
||||
# We test the generation side (offline) fully, and the upload side with HfApi
|
||||
|
||||
Reference in New Issue
Block a user