feat(train): resume training from a Hub checkpoint

Allow --config_path to be a Hub repo id when resuming, not only a local path.
The latest checkpoint under checkpoints/<step>/ is downloaded into a fresh local
run dir and resumed from there (optimizer, scheduler, RNG and data order
restored as for a local resume). TrainPipelineConfig.from_pretrained falls back
to the latest checkpoint's train_config.json when a repo has no root config
(an interrupted run that only pushed checkpoints). The download is skipped when
dispatching remotely so the executor (local machine or HF Jobs pod) performs it.

- add find_latest_hub_checkpoint (utils/hub) and resolve_resume_checkpoint
  (common/train_utils), the symmetric download counterpart to
  push_checkpoint_to_hub
- unit tests for both helpers and the from_pretrained fallback
This commit is contained in:
Nicolas Rabault
2026-06-24 10:15:55 +02:00
parent 955b172585
commit 838ab9e234
6 changed files with 265 additions and 28 deletions
+28 -1
View File
@@ -15,7 +15,7 @@
# limitations under the License.
from pathlib import Path
from huggingface_hub import HfApi
from huggingface_hub import HfApi, snapshot_download
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
@@ -36,6 +36,7 @@ from lerobot.utils.constants import (
TRAINING_STATE_DIR,
TRAINING_STEP,
)
from lerobot.utils.hub import find_latest_hub_checkpoint
from lerobot.utils.io_utils import load_json, write_json
from lerobot.utils.random_utils import load_rng_state, save_rng_state
@@ -316,3 +317,29 @@ def push_checkpoint_to_hub(
repo_type="model",
exist_ok=True,
)
def resolve_resume_checkpoint(repo_id: str, output_dir: Path) -> Path:
    """Fetch the newest checkpoint of a Hub training repo into a local run dir.

    Download-side mirror of `push_checkpoint_to_hub`. The model repo is expected to hold
    `checkpoints/<step>/{pretrained_model,training_state}` subtrees; the highest-numbered step is
    downloaded into `output_dir/checkpoints/<step>/`, the local `last` symlink is recreated, and the
    local checkpoint dir is returned. This is what lets a machine (or HF Jobs pod) without the
    original local run dir resume training from the Hub.

    Raises FileNotFoundError when the repo holds no checkpoint.
    """
    hub_checkpoint = find_latest_hub_checkpoint(repo_id)
    if hub_checkpoint is not None:
        # Restrict the snapshot to the newest step's subtree; it lands under `output_dir` with the
        # same repo-relative layout.
        snapshot_download(
            repo_id=repo_id,
            repo_type="model",
            allow_patterns=f"{hub_checkpoint}/*",
            local_dir=str(output_dir),
        )
        local_checkpoint = output_dir / hub_checkpoint
        update_last_checkpoint(local_checkpoint)
        return local_checkpoint

    raise FileNotFoundError(
        f"No checkpoint found in '{repo_id}' under '{CHECKPOINTS_DIR}/'. "
        "Was the run trained with --save_checkpoint_to_hub?"
    )
+56 -27
View File
@@ -26,7 +26,8 @@ from huggingface_hub.errors import HfHubHTTPError
from lerobot import envs
from lerobot.optim import LRSchedulerConfig, OptimizerConfig
from lerobot.utils.hub import HubMixin
from lerobot.utils.constants import PRETRAINED_MODEL_DIR
from lerobot.utils.hub import HubMixin, find_latest_hub_checkpoint
from lerobot.utils.sample_weighting import SampleWeightingConfig
from . import parser
@@ -83,10 +84,11 @@ class TrainPipelineConfig(HubMixin):
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
output_dir: Path | None = None
job_name: str | None = None
# Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
# `dir` is the directory of an existing run with at least one checkpoint in it.
# Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
# regardless of what's provided with the training command at the time of resumption.
# Set `resume` to true to resume a previous run. Pass `--config_path` pointing at either a local
# checkpoint's train_config.json or a Hub repo id holding `checkpoints/<step>/` subtrees (the
# latest checkpoint is downloaded and resumed from). Note that when resuming, the default behavior
# is to use the configuration from the checkpoint, regardless of what's provided with the training
# command at the time of resumption (CLI `--*` flags still override).
resume: bool = False
# `seed` is used for training (eg: model initialization, dataset shuffling)
# AND for the evaluation environments.
@@ -165,25 +167,44 @@ class TrainPipelineConfig(HubMixin):
self._resolve_resume_checkpoint()
def _resolve_resume_checkpoint(self) -> None:
"""Point the trainable config at the checkpoint named by `--config_path`."""
"""Point the trainable config at the checkpoint named by `--config_path`.
`config_path` is either a local path (to a checkpoint's train_config.json or its
pretrained_model/ dir) or a Hub repo id. For a Hub repo, the latest checkpoint is downloaded
into a fresh local run dir and resumed from there. The download is skipped when dispatching to
an HF Job (`job.is_remote`): the pod performs it when it runs the resume locally, and
`submit_to_hf` resolves the source repo for the remote command.
"""
config_path = parser.parse_arg("config_path")
if not config_path:
raise ValueError(
f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
)
if not Path(config_path).resolve().exists():
raise NotADirectoryError(
f"{config_path=} is expected to be a local path. "
"Resuming from the hub is not supported for now."
)
if Path(config_path).resolve().exists():
policy_dir = Path(config_path).parent
self.checkpoint_path = policy_dir.parent
elif self.job.is_remote:
return
else:
from lerobot.common.train_utils import resolve_resume_checkpoint
# `self.output_dir` was loaded from the checkpoint's config and points at the original
# run's (now-absent) local dir. Resume into a fresh local dir instead, unless the user
# passed --output_dir explicitly.
cli_output_dir = parser.parse_arg("output_dir")
if cli_output_dir:
self.output_dir = Path(cli_output_dir)
else:
now = dt.datetime.now()
self.output_dir = Path("outputs/train") / f"{now:%Y-%m-%d}/{now:%H-%M-%S}_resume"
self.checkpoint_path = resolve_resume_checkpoint(config_path, self.output_dir)
policy_dir = self.checkpoint_path / PRETRAINED_MODEL_DIR
policy_dir = Path(config_path).parent
if self.policy is not None:
self.policy.pretrained_path = policy_dir
if self.reward_model is not None:
self.reward_model.pretrained_path = str(policy_dir)
self.checkpoint_path = policy_dir.parent
def validate(self) -> None:
self._resolve_pretrained_from_cli()
@@ -275,22 +296,30 @@ class TrainPipelineConfig(HubMixin):
elif Path(model_id).is_file():
config_file = model_id
else:
dl_kwargs = {
"repo_id": model_id,
"revision": revision,
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"resume_download": resume_download,
"token": token,
"local_files_only": local_files_only,
}
try:
config_file = hf_hub_download(
repo_id=model_id,
filename=TRAIN_CONFIG_NAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
config_file = hf_hub_download(filename=TRAIN_CONFIG_NAME, **dl_kwargs)
except HfHubHTTPError as e:
raise FileNotFoundError(
f"{TRAIN_CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
) from e
# No root train_config.json: this is a repo of periodic checkpoints from an
# interrupted run. Fall back to the latest checkpoint's config so the run can be
# resumed straight from the repo with `--config_path=<repo>`.
latest = find_latest_hub_checkpoint(model_id, token=token, revision=revision)
if latest is None:
raise FileNotFoundError(
f"{TRAIN_CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
) from e
config_file = hf_hub_download(
filename=f"{latest}/{PRETRAINED_MODEL_DIR}/{TRAIN_CONFIG_NAME}", **dl_kwargs
)
cli_args = kwargs.pop("cli_args", [])
# Legacy RA-BC migration only applies to framework-saved checkpoints (always JSON).
+24
View File
@@ -20,9 +20,33 @@ from typing import Any, TypeVar
from huggingface_hub import HfApi
from huggingface_hub.utils import validate_hf_hub_args
from lerobot.utils.constants import CHECKPOINTS_DIR
T = TypeVar("T", bound="HubMixin")
def find_latest_hub_checkpoint(
    repo_id: str,
    *,
    token: str | bool | None = None,
    revision: str | None = None,
) -> str | None:
    """Repo-relative path of the most recent checkpoint in a training repo.

    Training runs push checkpoints to ``checkpoints/<step>/`` (see
    ``push_checkpoint_to_hub``). This lists the files of the model repo, collects the
    first path component under ``checkpoints/`` whenever it is a decimal number, and
    returns ``checkpoints/<highest-step>`` (compared numerically), or ``None`` if the
    repo has no such entry. ``token`` and ``revision`` are forwarded to the Hub listing call.
    """
    files = HfApi().list_repo_files(repo_id=repo_id, repo_type="model", revision=revision, token=token)
    prefix = f"{CHECKPOINTS_DIR}/"
    # `isdecimal`, not `isdigit`: `isdigit` also accepts characters such as superscript digits
    # that `int()` rejects, which would make the `max(..., key=int)` below raise ValueError.
    steps = {
        name
        for f in files
        if f.startswith(prefix) and (name := f[len(prefix) :].split("/", 1)[0]).isdecimal()
    }
    if not steps:
        return None
    return f"{CHECKPOINTS_DIR}/{max(steps, key=int)}"
class HubMixin:
"""
A Mixin containing the functionality to push an object to the hub.
+68
View File
@@ -0,0 +1,68 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import lerobot.configs.train as tc
from lerobot.configs.train import TrainPipelineConfig
class _FakeHTTPError(tc.HfHubHTTPError):
    """Stand-in `HfHubHTTPError` that can be raised without a real HTTP response object."""

    def __init__(self):
        # Deliberately does nothing: the parent initializer is not invoked.
        return None
def test_from_pretrained_falls_back_to_latest_checkpoint_config(tmp_path, monkeypatch):
    """When a Hub repo has no root train_config.json (an interrupted run that only pushed
    checkpoints/), `from_pretrained` loads the latest checkpoint's config instead."""
    root_config = "train_config.json"
    checkpoint_config = "checkpoints/000010/pretrained_model/train_config.json"

    # Produce a genuine train_config.json through save_pretrained; the fallback download returns it.
    saved_config = tmp_path / "train_config.json"
    tc.draccus.parse(TrainPipelineConfig, args=["--dataset.repo_id", "u/d"])._save_pretrained(tmp_path)
    assert saved_config.is_file()

    requested = []

    def fake_hf_hub_download(filename=None, **kwargs):
        requested.append(filename)
        if filename == checkpoint_config:
            return str(saved_config)
        if filename != root_config:
            raise AssertionError(f"unexpected filename {filename}")
        raise _FakeHTTPError()  # no root config

    def fake_find_latest(repo_id, token=None, revision=None):
        return "checkpoints/000010"

    monkeypatch.setattr(tc, "hf_hub_download", fake_hf_hub_download)
    monkeypatch.setattr(tc, "find_latest_hub_checkpoint", fake_find_latest)

    loaded = TrainPipelineConfig.from_pretrained("user/interrupted-run")

    assert loaded.dataset.repo_id == "u/d"
    # The root config is attempted first; only then is the latest checkpoint's config fetched.
    assert requested == [root_config, checkpoint_config]
def test_from_pretrained_raises_when_no_root_config_and_no_checkpoints(monkeypatch):
    """With neither a root config nor any checkpoint, a clear FileNotFoundError is raised
    rather than the raw HTTP error."""

    def always_missing(filename=None, **kwargs):
        raise _FakeHTTPError()

    def no_checkpoint(repo_id, token=None, revision=None):
        return None

    monkeypatch.setattr(tc, "hf_hub_download", always_missing)
    monkeypatch.setattr(tc, "find_latest_hub_checkpoint", no_checkpoint)

    with pytest.raises(FileNotFoundError, match="train_config.json not found"):
        TrainPipelineConfig.from_pretrained("user/empty-repo")
+54
View File
@@ -0,0 +1,54 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock
from lerobot.utils.hub import find_latest_hub_checkpoint
def _patch_list_files(monkeypatch, files):
    """Swap `lerobot.utils.hub.HfApi` for a factory returning a mock whose
    `list_repo_files` yields `files`; return that mock."""
    fake_api = MagicMock()
    fake_api.list_repo_files.return_value = files

    def make_api(*args, **kwargs):
        return fake_api

    # Patch the name as bound inside lerobot.utils.hub, which imported HfApi at module load.
    monkeypatch.setattr("lerobot.utils.hub.HfApi", make_api)
    return fake_api
def test_find_latest_hub_checkpoint_picks_highest_step(monkeypatch):
    """The step dir with the numerically largest name wins, whatever its zero-padding."""
    _patch_list_files(
        monkeypatch,
        [
            "README.md",
            "checkpoints/000500/pretrained_model/model.safetensors",
            "checkpoints/000500/training_state/training_step.json",
            "checkpoints/020000/pretrained_model/model.safetensors",
            "checkpoints/001000/training_state/training_step.json",
            # Unpadded step: a lexicographic max would pick "9999" over "020000",
            # so this entry is what makes the test distinguish the two orderings.
            "checkpoints/9999/training_state/training_step.json",
        ],
    )
    # Numeric max, not lexicographic — 20000 beats 9999, 1000 and 500.
    assert find_latest_hub_checkpoint("u/run") == "checkpoints/020000"
def test_find_latest_hub_checkpoint_ignores_non_step_entries(monkeypatch):
    """Entries under checkpoints/ whose name is not a numeric step are not candidates."""
    repo_files = ["checkpoints/last/pretrained_model/model.safetensors", "config.json"]
    _patch_list_files(monkeypatch, repo_files)

    latest = find_latest_hub_checkpoint("u/run")

    # "last" (a symlink target name) is not a numeric step, so nothing is resolvable.
    assert latest is None
def test_find_latest_hub_checkpoint_none_when_no_checkpoints(monkeypatch):
    """A repo with no files under checkpoints/ yields None."""
    repo_files = ["config.json", "model.safetensors"]
    _patch_list_files(monkeypatch, repo_files)

    latest = find_latest_hub_checkpoint("u/run")

    assert latest is None
+35
View File
@@ -17,6 +17,8 @@
from pathlib import Path
from unittest.mock import MagicMock, Mock, patch
import pytest
from lerobot.common.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
@@ -188,3 +190,36 @@ def test_push_checkpoint_to_hub_defaults_to_hub_default_visibility(tmp_path, mon
push_checkpoint_to_hub(ckpt, "user/run")
api.create_repo.assert_called_once()
assert api.create_repo.call_args.kwargs["private"] is None
def test_resolve_resume_checkpoint_downloads_latest_and_links(tmp_path, monkeypatch):
    """The latest Hub step is downloaded under the run dir and `last` links to it."""
    from lerobot.common import train_utils

    run_dir = tmp_path / "run"

    def fake_find_latest(repo_id):
        return "checkpoints/020000"

    def fake_snapshot_download(repo_id, repo_type, allow_patterns, local_dir):
        # Only the latest step's subtree may be requested.
        assert allow_patterns == "checkpoints/020000/*"
        # Mimic the Hub layout the real download materializes locally.
        Path(local_dir, "checkpoints", "020000", "pretrained_model").mkdir(parents=True)
        return local_dir

    monkeypatch.setattr("lerobot.common.train_utils.snapshot_download", fake_snapshot_download)
    monkeypatch.setattr("lerobot.common.train_utils.find_latest_hub_checkpoint", fake_find_latest)

    checkpoint_dir = train_utils.resolve_resume_checkpoint("u/run", run_dir)

    expected_dir = run_dir / CHECKPOINTS_DIR / "020000"
    assert checkpoint_dir == expected_dir

    last_link = run_dir / CHECKPOINTS_DIR / LAST_CHECKPOINT_LINK
    assert last_link.is_symlink()
    # `last` points at the downloaded step dir.
    link_target = last_link.parent / last_link.readlink()
    assert link_target.resolve() == checkpoint_dir.resolve()
def test_resolve_resume_checkpoint_raises_without_checkpoints(tmp_path, monkeypatch):
    """A repo without any checkpoint makes the resolver raise FileNotFoundError."""
    from lerobot.common import train_utils

    def no_checkpoint(repo_id):
        return None

    monkeypatch.setattr("lerobot.common.train_utils.find_latest_hub_checkpoint", no_checkpoint)

    run_dir = tmp_path / "run"
    with pytest.raises(FileNotFoundError, match="No checkpoint"):
        train_utils.resolve_resume_checkpoint("u/run", run_dir)