mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-26 20:57:28 +00:00
feat(train): resume training from a Hub checkpoint
Allow --config_path to be a Hub repo id when resuming, not only a local path. The latest checkpoint under checkpoints/<step>/ is downloaded into a fresh local run dir and resumed from there (optimizer, scheduler, RNG and data order restored as for a local resume). TrainPipelineConfig.from_pretrained falls back to the latest checkpoint's train_config.json when a repo has no root config (an interrupted run that only pushed checkpoints). The download is skipped when dispatching remotely so the executor (local machine or HF Jobs pod) performs it. - add find_latest_hub_checkpoint (utils/hub) and resolve_resume_checkpoint (common/train_utils), the symmetric download counterpart to push_checkpoint_to_hub - unit tests for both helpers and the from_pretrained fallback
This commit is contained in:
@@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
@@ -36,6 +36,7 @@ from lerobot.utils.constants import (
|
||||
TRAINING_STATE_DIR,
|
||||
TRAINING_STEP,
|
||||
)
|
||||
from lerobot.utils.hub import find_latest_hub_checkpoint
|
||||
from lerobot.utils.io_utils import load_json, write_json
|
||||
from lerobot.utils.random_utils import load_rng_state, save_rng_state
|
||||
|
||||
@@ -316,3 +317,29 @@ def push_checkpoint_to_hub(
|
||||
repo_type="model",
|
||||
exist_ok=True,
|
||||
)
|
||||
|
||||
|
||||
def resolve_resume_checkpoint(repo_id: str, output_dir: Path) -> Path:
|
||||
"""Download the latest checkpoint of a Hub training repo into a local run dir.
|
||||
|
||||
The symmetric counterpart to `push_checkpoint_to_hub`: given a model repo holding
|
||||
`checkpoints/<step>/{pretrained_model,training_state}` subtrees, download the highest-numbered step
|
||||
into `output_dir/checkpoints/<step>/`, recreate the local `last` symlink, and return that local
|
||||
checkpoint dir. Used to resume training from the Hub on a machine (or HF Jobs pod) that does not
|
||||
have the original local run dir.
|
||||
"""
|
||||
latest = find_latest_hub_checkpoint(repo_id)
|
||||
if latest is None:
|
||||
raise FileNotFoundError(
|
||||
f"No checkpoint found in '{repo_id}' under '{CHECKPOINTS_DIR}/'. "
|
||||
"Was the run trained with --save_checkpoint_to_hub?"
|
||||
)
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
allow_patterns=f"{latest}/*",
|
||||
local_dir=str(output_dir),
|
||||
)
|
||||
checkpoint_dir = output_dir / latest
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
return checkpoint_dir
|
||||
|
||||
@@ -26,7 +26,8 @@ from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot import envs
|
||||
from lerobot.optim import LRSchedulerConfig, OptimizerConfig
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.constants import PRETRAINED_MODEL_DIR
|
||||
from lerobot.utils.hub import HubMixin, find_latest_hub_checkpoint
|
||||
from lerobot.utils.sample_weighting import SampleWeightingConfig
|
||||
|
||||
from . import parser
|
||||
@@ -83,10 +84,11 @@ class TrainPipelineConfig(HubMixin):
|
||||
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
|
||||
output_dir: Path | None = None
|
||||
job_name: str | None = None
|
||||
# Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
|
||||
# `dir` is the directory of an existing run with at least one checkpoint in it.
|
||||
# Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
|
||||
# regardless of what's provided with the training command at the time of resumption.
|
||||
# Set `resume` to true to resume a previous run. Pass `--config_path` pointing at either a local
|
||||
# checkpoint's train_config.json or a Hub repo id holding `checkpoints/<step>/` subtrees (the
|
||||
# latest checkpoint is downloaded and resumed from). Note that when resuming, the default behavior
|
||||
# is to use the configuration from the checkpoint, regardless of what's provided with the training
|
||||
# command at the time of resumption (CLI `--*` flags still override).
|
||||
resume: bool = False
|
||||
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
||||
# AND for the evaluation environments.
|
||||
@@ -165,25 +167,44 @@ class TrainPipelineConfig(HubMixin):
|
||||
self._resolve_resume_checkpoint()
|
||||
|
||||
def _resolve_resume_checkpoint(self) -> None:
|
||||
"""Point the trainable config at the checkpoint named by `--config_path`."""
|
||||
"""Point the trainable config at the checkpoint named by `--config_path`.
|
||||
|
||||
`config_path` is either a local path (to a checkpoint's train_config.json or its
|
||||
pretrained_model/ dir) or a Hub repo id. For a Hub repo, the latest checkpoint is downloaded
|
||||
into a fresh local run dir and resumed from there. The download is skipped when dispatching to
|
||||
an HF Job (`job.is_remote`): the pod performs it when it runs the resume locally, and
|
||||
`submit_to_hf` resolves the source repo for the remote command.
|
||||
"""
|
||||
config_path = parser.parse_arg("config_path")
|
||||
if not config_path:
|
||||
raise ValueError(
|
||||
f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
|
||||
)
|
||||
|
||||
if not Path(config_path).resolve().exists():
|
||||
raise NotADirectoryError(
|
||||
f"{config_path=} is expected to be a local path. "
|
||||
"Resuming from the hub is not supported for now."
|
||||
)
|
||||
if Path(config_path).resolve().exists():
|
||||
policy_dir = Path(config_path).parent
|
||||
self.checkpoint_path = policy_dir.parent
|
||||
elif self.job.is_remote:
|
||||
return
|
||||
else:
|
||||
from lerobot.common.train_utils import resolve_resume_checkpoint
|
||||
|
||||
# `self.output_dir` was loaded from the checkpoint's config and points at the original
|
||||
# run's (now-absent) local dir. Resume into a fresh local dir instead, unless the user
|
||||
# passed --output_dir explicitly.
|
||||
cli_output_dir = parser.parse_arg("output_dir")
|
||||
if cli_output_dir:
|
||||
self.output_dir = Path(cli_output_dir)
|
||||
else:
|
||||
now = dt.datetime.now()
|
||||
self.output_dir = Path("outputs/train") / f"{now:%Y-%m-%d}/{now:%H-%M-%S}_resume"
|
||||
self.checkpoint_path = resolve_resume_checkpoint(config_path, self.output_dir)
|
||||
policy_dir = self.checkpoint_path / PRETRAINED_MODEL_DIR
|
||||
|
||||
policy_dir = Path(config_path).parent
|
||||
if self.policy is not None:
|
||||
self.policy.pretrained_path = policy_dir
|
||||
if self.reward_model is not None:
|
||||
self.reward_model.pretrained_path = str(policy_dir)
|
||||
self.checkpoint_path = policy_dir.parent
|
||||
|
||||
def validate(self) -> None:
|
||||
self._resolve_pretrained_from_cli()
|
||||
@@ -275,22 +296,30 @@ class TrainPipelineConfig(HubMixin):
|
||||
elif Path(model_id).is_file():
|
||||
config_file = model_id
|
||||
else:
|
||||
dl_kwargs = {
|
||||
"repo_id": model_id,
|
||||
"revision": revision,
|
||||
"cache_dir": cache_dir,
|
||||
"force_download": force_download,
|
||||
"proxies": proxies,
|
||||
"resume_download": resume_download,
|
||||
"token": token,
|
||||
"local_files_only": local_files_only,
|
||||
}
|
||||
try:
|
||||
config_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=TRAIN_CONFIG_NAME,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
config_file = hf_hub_download(filename=TRAIN_CONFIG_NAME, **dl_kwargs)
|
||||
except HfHubHTTPError as e:
|
||||
raise FileNotFoundError(
|
||||
f"{TRAIN_CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
|
||||
) from e
|
||||
# No root train_config.json: this is a repo of periodic checkpoints from an
|
||||
# interrupted run. Fall back to the latest checkpoint's config so the run can be
|
||||
# resumed straight from the repo with `--config_path=<repo>`.
|
||||
latest = find_latest_hub_checkpoint(model_id, token=token, revision=revision)
|
||||
if latest is None:
|
||||
raise FileNotFoundError(
|
||||
f"{TRAIN_CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
|
||||
) from e
|
||||
config_file = hf_hub_download(
|
||||
filename=f"{latest}/{PRETRAINED_MODEL_DIR}/{TRAIN_CONFIG_NAME}", **dl_kwargs
|
||||
)
|
||||
|
||||
cli_args = kwargs.pop("cli_args", [])
|
||||
# Legacy RA-BC migration only applies to framework-saved checkpoints (always JSON).
|
||||
|
||||
@@ -20,9 +20,33 @@ from typing import Any, TypeVar
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from lerobot.utils.constants import CHECKPOINTS_DIR
|
||||
|
||||
T = TypeVar("T", bound="HubMixin")
|
||||
|
||||
|
||||
def find_latest_hub_checkpoint(
|
||||
repo_id: str,
|
||||
*,
|
||||
token: str | bool | None = None,
|
||||
revision: str | None = None,
|
||||
) -> str | None:
|
||||
"""Repo-relative path of the most recent checkpoint in a training repo.
|
||||
|
||||
Training runs push checkpoints to ``checkpoints/<step>/`` (see
|
||||
``push_checkpoint_to_hub``). This lists those step dirs and returns
|
||||
``checkpoints/<highest-step>``, or ``None`` if the repo has no checkpoints.
|
||||
"""
|
||||
files = HfApi().list_repo_files(repo_id=repo_id, repo_type="model", revision=revision, token=token)
|
||||
prefix = f"{CHECKPOINTS_DIR}/"
|
||||
steps = {
|
||||
name for f in files if f.startswith(prefix) and (name := f[len(prefix) :].split("/", 1)[0]).isdigit()
|
||||
}
|
||||
if not steps:
|
||||
return None
|
||||
return f"{CHECKPOINTS_DIR}/{max(steps, key=int)}"
|
||||
|
||||
|
||||
class HubMixin:
|
||||
"""
|
||||
A Mixin containing the functionality to push an object to the hub.
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
|
||||
import lerobot.configs.train as tc
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
|
||||
class _FakeHTTPError(tc.HfHubHTTPError):
|
||||
"""HfHubHTTPError that can be raised without a real HTTP response object."""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
def test_from_pretrained_falls_back_to_latest_checkpoint_config(tmp_path, monkeypatch):
|
||||
"""A Hub repo with no root train_config.json (an interrupted run that only pushed
|
||||
checkpoints/) resolves via the latest checkpoint's config."""
|
||||
# A real train_config.json written by save_pretrained, to be returned by the fallback.
|
||||
parsed = tc.draccus.parse(TrainPipelineConfig, args=["--dataset.repo_id", "u/d"])
|
||||
cfg_file = tmp_path / "train_config.json"
|
||||
parsed._save_pretrained(tmp_path)
|
||||
assert cfg_file.is_file()
|
||||
|
||||
calls = []
|
||||
|
||||
def fake_hf_hub_download(filename=None, **kwargs):
|
||||
calls.append(filename)
|
||||
if filename == "train_config.json":
|
||||
raise _FakeHTTPError() # no root config
|
||||
if filename == "checkpoints/000010/pretrained_model/train_config.json":
|
||||
return str(cfg_file)
|
||||
raise AssertionError(f"unexpected filename {filename}")
|
||||
|
||||
monkeypatch.setattr(tc, "hf_hub_download", fake_hf_hub_download)
|
||||
monkeypatch.setattr(
|
||||
tc, "find_latest_hub_checkpoint", lambda repo_id, token=None, revision=None: "checkpoints/000010"
|
||||
)
|
||||
|
||||
loaded = TrainPipelineConfig.from_pretrained("user/interrupted-run")
|
||||
assert loaded.dataset.repo_id == "u/d"
|
||||
# Tried the root config first, then fell back to the latest checkpoint's config.
|
||||
assert calls == ["train_config.json", "checkpoints/000010/pretrained_model/train_config.json"]
|
||||
|
||||
|
||||
def test_from_pretrained_raises_when_no_root_config_and_no_checkpoints(monkeypatch):
|
||||
"""No root config AND no checkpoints → a clear FileNotFoundError, not the raw HTTP error."""
|
||||
|
||||
def fake_hf_hub_download(filename=None, **kwargs):
|
||||
raise _FakeHTTPError()
|
||||
|
||||
monkeypatch.setattr(tc, "hf_hub_download", fake_hf_hub_download)
|
||||
monkeypatch.setattr(tc, "find_latest_hub_checkpoint", lambda repo_id, token=None, revision=None: None)
|
||||
|
||||
with pytest.raises(FileNotFoundError, match="train_config.json not found"):
|
||||
TrainPipelineConfig.from_pretrained("user/empty-repo")
|
||||
@@ -0,0 +1,54 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from lerobot.utils.hub import find_latest_hub_checkpoint
|
||||
|
||||
|
||||
def _patch_list_files(monkeypatch, files):
|
||||
api = MagicMock()
|
||||
api.list_repo_files.return_value = files
|
||||
# HfApi is imported into lerobot.utils.hub at module load, so patch it there.
|
||||
monkeypatch.setattr("lerobot.utils.hub.HfApi", lambda *a, **k: api)
|
||||
return api
|
||||
|
||||
|
||||
def test_find_latest_hub_checkpoint_picks_highest_step(monkeypatch):
|
||||
_patch_list_files(
|
||||
monkeypatch,
|
||||
[
|
||||
"README.md",
|
||||
"checkpoints/000500/pretrained_model/model.safetensors",
|
||||
"checkpoints/000500/training_state/training_step.json",
|
||||
"checkpoints/020000/pretrained_model/model.safetensors",
|
||||
"checkpoints/001000/training_state/training_step.json",
|
||||
],
|
||||
)
|
||||
# Numeric max, not lexicographic — "020000" beats "001000"/"000500".
|
||||
assert find_latest_hub_checkpoint("u/run") == "checkpoints/020000"
|
||||
|
||||
|
||||
def test_find_latest_hub_checkpoint_ignores_non_step_entries(monkeypatch):
|
||||
_patch_list_files(
|
||||
monkeypatch,
|
||||
["checkpoints/last/pretrained_model/model.safetensors", "config.json"],
|
||||
)
|
||||
# "last" (a symlink target name) is not a numeric step → no resolvable checkpoint.
|
||||
assert find_latest_hub_checkpoint("u/run") is None
|
||||
|
||||
|
||||
def test_find_latest_hub_checkpoint_none_when_no_checkpoints(monkeypatch):
|
||||
_patch_list_files(monkeypatch, ["config.json", "model.safetensors"])
|
||||
assert find_latest_hub_checkpoint("u/run") is None
|
||||
@@ -17,6 +17,8 @@
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
get_step_identifier,
|
||||
@@ -188,3 +190,36 @@ def test_push_checkpoint_to_hub_defaults_to_hub_default_visibility(tmp_path, mon
|
||||
push_checkpoint_to_hub(ckpt, "user/run")
|
||||
api.create_repo.assert_called_once()
|
||||
assert api.create_repo.call_args.kwargs["private"] is None
|
||||
|
||||
|
||||
def test_resolve_resume_checkpoint_downloads_latest_and_links(tmp_path, monkeypatch):
|
||||
from lerobot.common import train_utils
|
||||
|
||||
out = tmp_path / "run"
|
||||
|
||||
def fake_snapshot_download(repo_id, repo_type, allow_patterns, local_dir):
|
||||
# Mimic the Hub layout the real download materializes locally.
|
||||
assert allow_patterns == "checkpoints/020000/*"
|
||||
(Path(local_dir) / "checkpoints" / "020000" / "pretrained_model").mkdir(parents=True)
|
||||
return local_dir
|
||||
|
||||
monkeypatch.setattr("lerobot.common.train_utils.snapshot_download", fake_snapshot_download)
|
||||
monkeypatch.setattr(
|
||||
"lerobot.common.train_utils.find_latest_hub_checkpoint", lambda repo_id: "checkpoints/020000"
|
||||
)
|
||||
|
||||
checkpoint_dir = train_utils.resolve_resume_checkpoint("u/run", out)
|
||||
|
||||
assert checkpoint_dir == out / CHECKPOINTS_DIR / "020000"
|
||||
last = out / CHECKPOINTS_DIR / LAST_CHECKPOINT_LINK
|
||||
assert last.is_symlink()
|
||||
# `last` points at the downloaded step dir.
|
||||
assert (last.parent / last.readlink()).resolve() == checkpoint_dir.resolve()
|
||||
|
||||
|
||||
def test_resolve_resume_checkpoint_raises_without_checkpoints(tmp_path, monkeypatch):
|
||||
from lerobot.common import train_utils
|
||||
|
||||
monkeypatch.setattr("lerobot.common.train_utils.find_latest_hub_checkpoint", lambda repo_id: None)
|
||||
with pytest.raises(FileNotFoundError, match="No checkpoint"):
|
||||
train_utils.resolve_resume_checkpoint("u/run", tmp_path / "run")
|
||||
|
||||
Reference in New Issue
Block a user