mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 05:59:52 +00:00
0e6114ac36
* fix(train): restrict legacy RA-BC migration to JSON checkpoints only
_migrate_legacy_rabc_fields was called for all config files, causing
json.load to raise DecodeError when a YAML/TOML config was passed to
lerobot-train for a new training run. Guard the block with an
.endswith(".json") check so migration only runs when resuming from
a JSON checkpoint.
279 lines
12 KiB
Python
279 lines
12 KiB
Python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import builtins
|
|
import datetime as dt
|
|
import json
|
|
import os
|
|
import tempfile
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import draccus
|
|
from huggingface_hub import hf_hub_download
|
|
from huggingface_hub.errors import HfHubHTTPError
|
|
|
|
from lerobot import envs
|
|
from lerobot.configs import parser
|
|
from lerobot.optim import LRSchedulerConfig, OptimizerConfig
|
|
from lerobot.utils.hub import HubMixin
|
|
from lerobot.utils.sample_weighting import SampleWeightingConfig
|
|
|
|
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
|
from .policies import PreTrainedConfig
|
|
from .rewards import RewardModelConfig
|
|
|
|
TRAIN_CONFIG_NAME = "train_config.json"
|
|
|
|
|
|
def _migrate_legacy_rabc_fields(config: dict[str, Any]) -> dict[str, Any] | None:
|
|
"""Return migrated payload for legacy RA-BC fields, or None when no migration is needed."""
|
|
legacy_fields = (
|
|
"use_rabc",
|
|
"rabc_progress_path",
|
|
"rabc_kappa",
|
|
"rabc_epsilon",
|
|
"rabc_head_mode",
|
|
)
|
|
if not any(key in config for key in legacy_fields):
|
|
return None
|
|
|
|
migrated_config = dict(config)
|
|
use_rabc = bool(migrated_config.pop("use_rabc", False))
|
|
rabc_progress_path = migrated_config.pop("rabc_progress_path", None)
|
|
rabc_kappa = migrated_config.pop("rabc_kappa", None)
|
|
rabc_epsilon = migrated_config.pop("rabc_epsilon", None)
|
|
rabc_head_mode = migrated_config.pop("rabc_head_mode", None)
|
|
|
|
# New configs may already define sample_weighting explicitly. In that case,
|
|
# legacy fields are ignored after being stripped from the payload.
|
|
if migrated_config.get("sample_weighting") is None and use_rabc:
|
|
sample_weighting: dict[str, Any] = {"type": "rabc"}
|
|
if rabc_progress_path is not None:
|
|
sample_weighting["progress_path"] = rabc_progress_path
|
|
if rabc_kappa is not None:
|
|
sample_weighting["kappa"] = rabc_kappa
|
|
if rabc_epsilon is not None:
|
|
sample_weighting["epsilon"] = rabc_epsilon
|
|
if rabc_head_mode is not None:
|
|
sample_weighting["head_mode"] = rabc_head_mode
|
|
migrated_config["sample_weighting"] = sample_weighting
|
|
|
|
return migrated_config
|
|
|
|
|
|
@dataclass
|
|
class TrainPipelineConfig(HubMixin):
|
|
dataset: DatasetConfig
|
|
env: envs.EnvConfig | None = None
|
|
policy: PreTrainedConfig | None = None
|
|
reward_model: RewardModelConfig | None = None
|
|
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
|
|
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
|
|
output_dir: Path | None = None
|
|
job_name: str | None = None
|
|
# Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
|
|
# `dir` is the directory of an existing run with at least one checkpoint in it.
|
|
# Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
|
|
# regardless of what's provided with the training command at the time of resumption.
|
|
resume: bool = False
|
|
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
|
# AND for the evaluation environments.
|
|
seed: int | None = 1000
|
|
# Set to True to use deterministic cuDNN algorithms for reproducibility.
|
|
# This disables cudnn.benchmark and may reduce training speed by ~10-20 percent.
|
|
cudnn_deterministic: bool = False
|
|
# Number of workers for the dataloader.
|
|
num_workers: int = 4
|
|
batch_size: int = 8
|
|
prefetch_factor: int = 4
|
|
persistent_workers: bool = True
|
|
steps: int = 100_000
|
|
eval_freq: int = 20_000
|
|
log_freq: int = 200
|
|
tolerance_s: float = 1e-4
|
|
save_checkpoint: bool = True
|
|
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
|
|
save_freq: int = 20_000
|
|
use_policy_training_preset: bool = True
|
|
optimizer: OptimizerConfig | None = None
|
|
scheduler: LRSchedulerConfig | None = None
|
|
eval: EvalConfig = field(default_factory=EvalConfig)
|
|
wandb: WandBConfig = field(default_factory=WandBConfig)
|
|
peft: PeftConfig | None = None
|
|
|
|
# Sample weighting configuration (e.g., for RA-BC training)
|
|
sample_weighting: SampleWeightingConfig | None = None
|
|
|
|
# Rename map for the observation to override the image and state keys
|
|
rename_map: dict[str, str] = field(default_factory=dict)
|
|
checkpoint_path: Path | None = field(init=False, default=None)
|
|
|
|
@property
|
|
def is_reward_model_training(self) -> bool:
|
|
"""True when the config targets a reward model rather than a policy."""
|
|
return self.reward_model is not None
|
|
|
|
@property
|
|
def trainable_config(self) -> PreTrainedConfig | RewardModelConfig:
|
|
"""Return whichever config (policy or reward_model) is active."""
|
|
if self.is_reward_model_training:
|
|
return self.reward_model # type: ignore[return-value]
|
|
return self.policy # type: ignore[return-value]
|
|
|
|
def validate(self) -> None:
|
|
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
|
policy_path = parser.get_path_arg("policy")
|
|
reward_model_path = parser.get_path_arg("reward_model")
|
|
|
|
if reward_model_path:
|
|
cli_overrides = parser.get_cli_overrides("reward_model")
|
|
self.reward_model = RewardModelConfig.from_pretrained(
|
|
reward_model_path, cli_overrides=cli_overrides
|
|
)
|
|
self.reward_model.pretrained_path = str(Path(reward_model_path))
|
|
elif policy_path:
|
|
cli_overrides = parser.get_cli_overrides("policy")
|
|
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
|
self.policy.pretrained_path = Path(policy_path)
|
|
elif self.resume:
|
|
config_path = parser.parse_arg("config_path")
|
|
if not config_path:
|
|
raise ValueError(
|
|
f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
|
|
)
|
|
|
|
if not Path(config_path).resolve().exists():
|
|
raise NotADirectoryError(
|
|
f"{config_path=} is expected to be a local path. "
|
|
"Resuming from the hub is not supported for now."
|
|
)
|
|
|
|
policy_dir = Path(config_path).parent
|
|
if self.policy is not None:
|
|
self.policy.pretrained_path = policy_dir
|
|
if self.reward_model is not None:
|
|
self.reward_model.pretrained_path = str(policy_dir)
|
|
self.checkpoint_path = policy_dir.parent
|
|
|
|
if self.policy is None and self.reward_model is None:
|
|
raise ValueError(
|
|
"Neither policy nor reward_model is configured. "
|
|
"Please specify one with `--policy.path` or `--reward_model.path`."
|
|
)
|
|
|
|
active_cfg = self.trainable_config
|
|
if not self.job_name:
|
|
if self.env is None:
|
|
self.job_name = f"{active_cfg.type}"
|
|
else:
|
|
self.job_name = f"{self.env.type}_{active_cfg.type}"
|
|
|
|
if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
|
|
raise FileExistsError(
|
|
f"Output directory {self.output_dir} already exists and resume is {self.resume}. "
|
|
f"Please change your output directory so that {self.output_dir} is not overwritten."
|
|
)
|
|
elif not self.output_dir:
|
|
now = dt.datetime.now()
|
|
train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
|
|
self.output_dir = Path("outputs/train") / train_dir
|
|
|
|
if isinstance(self.dataset.repo_id, list):
|
|
raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
|
|
|
|
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
|
|
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
|
|
elif self.use_policy_training_preset and not self.resume:
|
|
self.optimizer = active_cfg.get_optimizer_preset()
|
|
self.scheduler = active_cfg.get_scheduler_preset()
|
|
|
|
if hasattr(active_cfg, "push_to_hub") and active_cfg.push_to_hub and not active_cfg.repo_id:
|
|
raise ValueError("'repo_id' argument missing. Please specify it to push the model to the hub.")
|
|
|
|
@classmethod
|
|
def __get_path_fields__(cls) -> list[str]:
|
|
"""Keys for draccus pretrained-path loading."""
|
|
return ["policy", "reward_model"]
|
|
|
|
def to_dict(self) -> dict[str, Any]:
|
|
return draccus.encode(self) # type: ignore[no-any-return] # because of the third-party library draccus uses Any as the return type
|
|
|
|
def _save_pretrained(self, save_directory: Path) -> None:
|
|
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
|
draccus.dump(self, f, indent=4)
|
|
|
|
@classmethod
|
|
def from_pretrained(
|
|
cls: builtins.type["TrainPipelineConfig"],
|
|
pretrained_name_or_path: str | Path,
|
|
*,
|
|
force_download: bool = False,
|
|
resume_download: bool | None = None,
|
|
proxies: dict[Any, Any] | None = None,
|
|
token: str | bool | None = None,
|
|
cache_dir: str | Path | None = None,
|
|
local_files_only: bool = False,
|
|
revision: str | None = None,
|
|
**kwargs: Any,
|
|
) -> "TrainPipelineConfig":
|
|
model_id = str(pretrained_name_or_path)
|
|
config_file: str | None = None
|
|
if Path(model_id).is_dir():
|
|
if TRAIN_CONFIG_NAME in os.listdir(model_id):
|
|
config_file = os.path.join(model_id, TRAIN_CONFIG_NAME)
|
|
else:
|
|
print(f"{TRAIN_CONFIG_NAME} not found in {Path(model_id).resolve()}")
|
|
elif Path(model_id).is_file():
|
|
config_file = model_id
|
|
else:
|
|
try:
|
|
config_file = hf_hub_download(
|
|
repo_id=model_id,
|
|
filename=TRAIN_CONFIG_NAME,
|
|
revision=revision,
|
|
cache_dir=cache_dir,
|
|
force_download=force_download,
|
|
proxies=proxies,
|
|
resume_download=resume_download,
|
|
token=token,
|
|
local_files_only=local_files_only,
|
|
)
|
|
except HfHubHTTPError as e:
|
|
raise FileNotFoundError(
|
|
f"{TRAIN_CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
|
|
) from e
|
|
|
|
cli_args = kwargs.pop("cli_args", [])
|
|
# Legacy RA-BC migration only applies to framework-saved checkpoints (always JSON).
|
|
# Hand-written YAML/TOML configs are expected to use the current sample_weighting schema.
|
|
if config_file is not None and config_file.endswith(".json"):
|
|
with open(config_file) as f:
|
|
config = json.load(f)
|
|
migrated_config = _migrate_legacy_rabc_fields(config)
|
|
if migrated_config is not None:
|
|
with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
|
|
json.dump(migrated_config, f)
|
|
config_file = f.name
|
|
|
|
with draccus.config_type("json"):
|
|
return draccus.parse(cls, config_file, args=cli_args)
|
|
|
|
|
|
@dataclass(kw_only=True)
|
|
class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
|
# NOTE: In RL, we don't need an offline dataset
|
|
# TODO: Make `TrainPipelineConfig.dataset` optional
|
|
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
|