Files
lerobot/src/lerobot/configs/train.py
T
masato-ka 0e6114ac36 fix(train): restrict legacy RA-BC migration to JSON checkpoints only (#3490)
* fix(train): restrict legacy RA-BC migration to JSON checkpoints only

_migrate_legacy_rabc_fields was called for all config files, causing
json.load to raise a JSONDecodeError when a YAML/TOML config was passed to
lerobot-train for a new training run. Guard the block with an
.endswith(".json") check so migration only runs when resuming from
a JSON checkpoint.
2026-05-08 20:27:01 +02:00

279 lines
12 KiB
Python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import builtins
import datetime as dt
import json
import os
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import draccus
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import HfHubHTTPError
from lerobot import envs
from lerobot.configs import parser
from lerobot.optim import LRSchedulerConfig, OptimizerConfig
from lerobot.utils.hub import HubMixin
from lerobot.utils.sample_weighting import SampleWeightingConfig
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
from .policies import PreTrainedConfig
from .rewards import RewardModelConfig
# File name under which a TrainPipelineConfig is written by `_save_pretrained` and
# looked up by `from_pretrained` (inside a local directory or a Hub repo).
TRAIN_CONFIG_NAME = "train_config.json"
def _migrate_legacy_rabc_fields(config: dict[str, Any]) -> dict[str, Any] | None:
"""Return migrated payload for legacy RA-BC fields, or None when no migration is needed."""
legacy_fields = (
"use_rabc",
"rabc_progress_path",
"rabc_kappa",
"rabc_epsilon",
"rabc_head_mode",
)
if not any(key in config for key in legacy_fields):
return None
migrated_config = dict(config)
use_rabc = bool(migrated_config.pop("use_rabc", False))
rabc_progress_path = migrated_config.pop("rabc_progress_path", None)
rabc_kappa = migrated_config.pop("rabc_kappa", None)
rabc_epsilon = migrated_config.pop("rabc_epsilon", None)
rabc_head_mode = migrated_config.pop("rabc_head_mode", None)
# New configs may already define sample_weighting explicitly. In that case,
# legacy fields are ignored after being stripped from the payload.
if migrated_config.get("sample_weighting") is None and use_rabc:
sample_weighting: dict[str, Any] = {"type": "rabc"}
if rabc_progress_path is not None:
sample_weighting["progress_path"] = rabc_progress_path
if rabc_kappa is not None:
sample_weighting["kappa"] = rabc_kappa
if rabc_epsilon is not None:
sample_weighting["epsilon"] = rabc_epsilon
if rabc_head_mode is not None:
sample_weighting["head_mode"] = rabc_head_mode
migrated_config["sample_weighting"] = sample_weighting
return migrated_config
@dataclass
class TrainPipelineConfig(HubMixin):
    """Top-level configuration of a training run, targeting either a policy or a reward model.

    Typically built by draccus from the command line or from a saved ``train_config.json``
    (see ``from_pretrained``), then finalized in place by ``validate``.
    """

    dataset: DatasetConfig
    env: envs.EnvConfig | None = None
    policy: PreTrainedConfig | None = None
    reward_model: RewardModelConfig | None = None
    # Set `output_dir` to where you would like to save all of the run outputs. If you run another training
    # session with the same value for `output_dir` its contents will be overwritten unless you set `resume` to true.
    output_dir: Path | None = None
    job_name: str | None = None
    # Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
    # `output_dir` is the directory of an existing run with at least one checkpoint in it.
    # Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
    # regardless of what's provided with the training command at the time of resumption.
    resume: bool = False
    # `seed` is used for training (eg: model initialization, dataset shuffling)
    # AND for the evaluation environments.
    seed: int | None = 1000
    # Set to True to use deterministic cuDNN algorithms for reproducibility.
    # This disables cudnn.benchmark and may reduce training speed by ~10-20 percent.
    cudnn_deterministic: bool = False
    # Number of workers for the dataloader.
    num_workers: int = 4
    batch_size: int = 8
    prefetch_factor: int = 4
    persistent_workers: bool = True
    steps: int = 100_000
    eval_freq: int = 20_000
    log_freq: int = 200
    tolerance_s: float = 1e-4
    save_checkpoint: bool = True
    # Checkpoint is saved every `save_freq` training iterations and after the last training step.
    save_freq: int = 20_000
    use_policy_training_preset: bool = True
    optimizer: OptimizerConfig | None = None
    scheduler: LRSchedulerConfig | None = None
    eval: EvalConfig = field(default_factory=EvalConfig)
    wandb: WandBConfig = field(default_factory=WandBConfig)
    peft: PeftConfig | None = None
    # Sample weighting configuration (e.g., for RA-BC training)
    sample_weighting: SampleWeightingConfig | None = None
    # Rename map for the observation to override the image and state keys
    rename_map: dict[str, str] = field(default_factory=dict)
    # Set by `validate()` when resuming: parent of the directory holding the resumed config file.
    checkpoint_path: Path | None = field(init=False, default=None)

    @property
    def is_reward_model_training(self) -> bool:
        """True when the config targets a reward model rather than a policy."""
        return self.reward_model is not None

    @property
    def trainable_config(self) -> PreTrainedConfig | RewardModelConfig:
        """Return whichever config (policy or reward_model) is active."""
        if self.is_reward_model_training:
            return self.reward_model  # type: ignore[return-value]
        return self.policy  # type: ignore[return-value]

    def validate(self) -> None:
        """Resolve pretrained/resumed sub-configs and fill in derived fields in place.

        Loads ``reward_model``/``policy`` from ``--reward_model.path``/``--policy.path`` when
        given (reward model takes precedence). When resuming, it points them at the checkpoint
        directory instead. It then derives ``job_name``, ``output_dir`` and the optimizer and
        scheduler presets.

        Raises:
            ValueError: resuming without ``config_path``; neither policy nor reward model
                configured; presets disabled without an optimizer/scheduler; or
                ``push_to_hub`` requested without a ``repo_id``.
            NotADirectoryError: the resume ``config_path`` does not exist locally.
            FileExistsError: ``output_dir`` already exists and ``resume`` is false.
            NotImplementedError: ``dataset.repo_id`` is a list (multi-dataset).
        """
        # HACK: We parse again the cli args here to get the pretrained paths if there was some.
        policy_path = parser.get_path_arg("policy")
        reward_model_path = parser.get_path_arg("reward_model")
        if reward_model_path:
            cli_overrides = parser.get_cli_overrides("reward_model")
            self.reward_model = RewardModelConfig.from_pretrained(
                reward_model_path, cli_overrides=cli_overrides
            )
            self.reward_model.pretrained_path = str(Path(reward_model_path))
        elif policy_path:
            cli_overrides = parser.get_cli_overrides("policy")
            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
            self.policy.pretrained_path = Path(policy_path)
        elif self.resume:
            config_path = parser.parse_arg("config_path")
            if not config_path:
                raise ValueError(
                    f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
                )
            if not Path(config_path).resolve().exists():
                raise NotADirectoryError(
                    f"{config_path=} is expected to be a local path. "
                    "Resuming from the hub is not supported for now."
                )
            # The pretrained weights/config are loaded from the directory holding the config file.
            policy_dir = Path(config_path).parent
            if self.policy is not None:
                self.policy.pretrained_path = policy_dir
            if self.reward_model is not None:
                self.reward_model.pretrained_path = str(policy_dir)
            self.checkpoint_path = policy_dir.parent

        if self.policy is None and self.reward_model is None:
            raise ValueError(
                "Neither policy nor reward_model is configured. "
                "Please specify one with `--policy.path` or `--reward_model.path`."
            )

        active_cfg = self.trainable_config

        if not self.job_name:
            if self.env is None:
                self.job_name = f"{active_cfg.type}"
            else:
                self.job_name = f"{self.env.type}_{active_cfg.type}"

        if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
            raise FileExistsError(
                f"Output directory {self.output_dir} already exists and resume is {self.resume}. "
                f"Please change your output directory so that {self.output_dir} is not overwritten."
            )
        elif not self.output_dir:
            now = dt.datetime.now()
            train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
            self.output_dir = Path("outputs/train") / train_dir

        # `dataset` is optional in subclasses (e.g. RL training without an offline dataset),
        # so it must not be dereferenced unconditionally.
        if self.dataset is not None and isinstance(self.dataset.repo_id, list):
            raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")

        if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
            raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
        elif self.use_policy_training_preset and not self.resume:
            # When resuming, the optimizer/scheduler come from the checkpoint config instead.
            self.optimizer = active_cfg.get_optimizer_preset()
            self.scheduler = active_cfg.get_scheduler_preset()

        # Not every trainable config defines push_to_hub/repo_id; treat missing ones as unset.
        if getattr(active_cfg, "push_to_hub", False) and not getattr(active_cfg, "repo_id", None):
            raise ValueError("'repo_id' argument missing. Please specify it to push the model to the hub.")

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """Keys for draccus pretrained-path loading."""
        return ["policy", "reward_model"]

    def to_dict(self) -> dict[str, Any]:
        """Encode this config into plain Python containers via draccus."""
        return draccus.encode(self)  # type: ignore[no-any-return] # because of the third-party library draccus uses Any as the return type

    def _save_pretrained(self, save_directory: Path) -> None:
        """Write this config as JSON to ``save_directory / TRAIN_CONFIG_NAME``."""
        with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
            draccus.dump(self, f, indent=4)

    @classmethod
    def from_pretrained(
        cls: builtins.type["TrainPipelineConfig"],
        pretrained_name_or_path: str | Path,
        *,
        force_download: bool = False,
        resume_download: bool | None = None,
        proxies: dict[Any, Any] | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        **kwargs: Any,
    ) -> "TrainPipelineConfig":
        """Load a training config from a local directory/file or a Hugging Face Hub repo.

        Args:
            pretrained_name_or_path: A directory containing ``TRAIN_CONFIG_NAME``, a path to a
                config file, or a Hub repo id.
            force_download, resume_download, proxies, token, cache_dir, local_files_only,
            revision: Forwarded to ``hf_hub_download`` when loading from the Hub.
            **kwargs: May contain ``cli_args`` (a list of CLI-style overrides) that draccus
                applies on top of the file.

        Returns:
            The parsed config. JSON files are first passed through the legacy RA-BC migration.

        Raises:
            FileNotFoundError: the config could not be downloaded from the Hub.
        """
        model_id = str(pretrained_name_or_path)
        config_file: str | None = None
        if Path(model_id).is_dir():
            if TRAIN_CONFIG_NAME in os.listdir(model_id):
                config_file = os.path.join(model_id, TRAIN_CONFIG_NAME)
            else:
                print(f"{TRAIN_CONFIG_NAME} not found in {Path(model_id).resolve()}")
        elif Path(model_id).is_file():
            config_file = model_id
        else:
            try:
                config_file = hf_hub_download(
                    repo_id=model_id,
                    filename=TRAIN_CONFIG_NAME,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
            except HfHubHTTPError as e:
                raise FileNotFoundError(
                    f"{TRAIN_CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
                ) from e

        cli_args = kwargs.pop("cli_args", [])

        # Legacy RA-BC migration only applies to framework-saved checkpoints (always JSON).
        # Hand-written YAML/TOML configs are expected to use the current sample_weighting schema.
        migrated_file: str | None = None
        if config_file is not None and config_file.endswith(".json"):
            with open(config_file) as f:
                config = json.load(f)
            migrated_config = _migrate_legacy_rabc_fields(config)
            if migrated_config is not None:
                # draccus parses from a path, so the migrated payload goes to a temporary file
                # that is cleaned up once parsing is done (see `finally` below).
                with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as tmp:
                    json.dump(migrated_config, tmp)
                migrated_file = tmp.name
                config_file = migrated_file

        try:
            with draccus.config_type("json"):
                return draccus.parse(cls, config_file, args=cli_args)
        finally:
            if migrated_file is not None:
                try:
                    os.remove(migrated_file)
                except OSError:
                    # Best-effort cleanup: a leftover temp file must not mask the parse result.
                    pass
@dataclass(kw_only=True)
class TrainRLServerPipelineConfig(TrainPipelineConfig):
    """Training config for the RL server, where the offline dataset may be omitted.

    NOTE(review): `dataset` defaults to None here, so make sure the inherited
    `validate()` tolerates `dataset=None` before relying on this default.
    """

    # NOTE: In RL, we don't need an offline dataset
    # TODO: Make `TrainPipelineConfig.dataset` optional
    dataset: DatasetConfig | None = None  # type: ignore[assignment] # because the parent class has made its type non-optional