def resolve_resume_checkpoint(repo_id: str, output_dir: Path) -> Path:
    """Fetch the newest checkpoint of a Hub training repo into a local run dir.

    Mirror of `push_checkpoint_to_hub`: asks `find_latest_hub_checkpoint` for the repo-relative
    path of the most recent checkpoint in `repo_id`, downloads only the files under that path into
    `output_dir`, refreshes the local "last checkpoint" link via `update_last_checkpoint`, and
    returns `output_dir / <repo-relative checkpoint path>`.

    Args:
        repo_id: Hub model repo that holds the pushed checkpoints.
        output_dir: Local run directory the checkpoint subtree is downloaded into.

    Returns:
        The local directory of the downloaded checkpoint.

    Raises:
        FileNotFoundError: If `find_latest_hub_checkpoint` reports no checkpoint in the repo.
    """
    latest_rel_path = find_latest_hub_checkpoint(repo_id)
    if latest_rel_path is not None:
        # Restrict the snapshot to the selected checkpoint's subtree only.
        snapshot_download(
            repo_id=repo_id,
            repo_type="model",
            allow_patterns=f"{latest_rel_path}/*",
            local_dir=str(output_dir),
        )
        local_checkpoint = output_dir / latest_rel_path
        update_last_checkpoint(local_checkpoint)
        return local_checkpoint

    message = (
        f"No checkpoint found in '{repo_id}' under '{CHECKPOINTS_DIR}/'. "
        "Was the run trained with --save_checkpoint_to_hub?"
    )
    raise FileNotFoundError(message)
def _resolve_resume_checkpoint(self) -> None:
    """Point the trainable config at the checkpoint named by `--config_path`.

    `config_path` is one of:

    * a local path to a checkpoint's train config file (inside its pretrained-model dir),
    * a local path to that pretrained-model dir itself,
    * anything else, which is treated as a Hub repo id.

    For a Hub repo id the latest checkpoint is downloaded into a local run dir and resumed from
    there. That download is skipped when `self.job.is_remote` is set: in that case this method
    returns early and leaves `checkpoint_path` and the pretrained paths untouched, the resume being
    resolved where the job actually runs.

    Raises:
        ValueError: If no `config_path` was passed on the command line.
    """
    config_path = parser.parse_arg("config_path")
    if not config_path:
        raise ValueError(
            f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
        )

    local_path = Path(config_path)
    if local_path.resolve().exists():
        # A directory is the pretrained-model dir itself; a file is the config stored inside it.
        # Taking `.parent` unconditionally would shift both paths one level too high for a dir.
        policy_dir = local_path if local_path.is_dir() else local_path.parent
        self.checkpoint_path = policy_dir.parent
    elif self.job.is_remote:
        return
    else:
        # Imported lazily, as in the original code (kept function-scoped on purpose).
        from lerobot.common.train_utils import resolve_resume_checkpoint

        # `self.output_dir` was loaded from the checkpoint's config and points at the original
        # run's (now-absent) local dir. Resume into a fresh local dir instead, unless the user
        # passed --output_dir explicitly.
        cli_output_dir = parser.parse_arg("output_dir")
        if cli_output_dir:
            self.output_dir = Path(cli_output_dir)
        else:
            now = dt.datetime.now()
            self.output_dir = Path("outputs/train") / f"{now:%Y-%m-%d}/{now:%H-%M-%S}_resume"
        self.checkpoint_path = resolve_resume_checkpoint(config_path, self.output_dir)
        policy_dir = self.checkpoint_path / PRETRAINED_MODEL_DIR

    if self.policy is not None:
        self.policy.pretrained_path = policy_dir
    if self.reward_model is not None:
        # The reward model config receives the path as a string, unlike the policy config.
        self.reward_model.pretrained_path = str(policy_dir)
def find_latest_hub_checkpoint(
    repo_id: str,
    *,
    token: str | bool | None = None,
    revision: str | None = None,
) -> str | None:
    """Return the repo-relative path of the most recent checkpoint in a training repo.

    Lists the files of the model repo `repo_id` and collects the first path component found under
    ``<CHECKPOINTS_DIR>/`` whenever that component is a purely numeric step directory. The step
    with the highest numeric value wins (numeric, not lexicographic, comparison).

    Args:
        repo_id: Hub model repo to inspect.
        token: Forwarded to `HfApi.list_repo_files`.
        revision: Forwarded to `HfApi.list_repo_files`.

    Returns:
        ``<CHECKPOINTS_DIR>/<step>`` for the latest step, or ``None`` if no numeric step directory
        exists under ``<CHECKPOINTS_DIR>/``.
    """
    files = HfApi().list_repo_files(repo_id=repo_id, repo_type="model", revision=revision, token=token)
    prefix = f"{CHECKPOINTS_DIR}/"

    steps = set()
    for repo_path in files:
        if not repo_path.startswith(prefix):
            continue
        name, separator, _ = repo_path[len(prefix) :].partition("/")
        # `separator` is empty for a plain file sitting directly in the checkpoints dir; only a
        # directory (something with content below it) can be a step checkpoint.
        # `isascii()` guards `int()`: `str.isdigit()` alone also accepts characters such as "²"
        # that `int()` rejects with ValueError.
        if separator and name.isascii() and name.isdigit():
            steps.add(name)

    if not steps:
        return None
    # Secondary string key makes the pick deterministic if two names share a numeric value.
    latest = max(steps, key=lambda step: (int(step), step))
    return f"{CHECKPOINTS_DIR}/{latest}"
class _FakeHTTPError(tc.HfHubHTTPError):
    """Stand-in for HfHubHTTPError that needs no HTTP response object to be constructed."""

    def __init__(self):
        # Deliberately skip the parent initializer.
        return None


def test_from_pretrained_falls_back_to_latest_checkpoint_config(tmp_path, monkeypatch):
    """When the Hub repo has no root train_config.json (interrupted run that only pushed
    checkpoints/), the config of the latest checkpoint is used instead."""
    # Produce a genuine train_config.json on disk for the fallback download to hand back.
    saved_cfg = tc.draccus.parse(TrainPipelineConfig, args=["--dataset.repo_id", "u/d"])
    local_config = tmp_path / "train_config.json"
    saved_cfg._save_pretrained(tmp_path)
    assert local_config.is_file()

    root_name = "train_config.json"
    checkpoint_name = "checkpoints/000010/pretrained_model/train_config.json"
    requested = []

    def fake_hf_hub_download(filename=None, **kwargs):
        requested.append(filename)
        if filename == root_name:
            # Simulate the missing root config.
            raise _FakeHTTPError()
        if filename == checkpoint_name:
            return str(local_config)
        raise AssertionError(f"unexpected filename {filename}")

    def fake_find_latest(repo_id, token=None, revision=None):
        return "checkpoints/000010"

    monkeypatch.setattr(tc, "hf_hub_download", fake_hf_hub_download)
    monkeypatch.setattr(tc, "find_latest_hub_checkpoint", fake_find_latest)

    restored = TrainPipelineConfig.from_pretrained("user/interrupted-run")

    assert restored.dataset.repo_id == "u/d"
    # Root config is attempted first; the latest checkpoint's config is the fallback.
    assert requested == [root_name, checkpoint_name]


def test_from_pretrained_raises_when_no_root_config_and_no_checkpoints(monkeypatch):
    """Neither a root config nor any checkpoint: expect FileNotFoundError rather than the raw
    HTTP error."""

    def fake_hf_hub_download(filename=None, **kwargs):
        raise _FakeHTTPError()

    def fake_find_latest(repo_id, token=None, revision=None):
        return None

    monkeypatch.setattr(tc, "hf_hub_download", fake_hf_hub_download)
    monkeypatch.setattr(tc, "find_latest_hub_checkpoint", fake_find_latest)

    with pytest.raises(FileNotFoundError, match="train_config.json not found"):
        TrainPipelineConfig.from_pretrained("user/empty-repo")
def _patch_list_files(monkeypatch, files):
    """Replace `HfApi` in `lerobot.utils.hub` with a factory whose API lists `files`."""
    fake_api = MagicMock()
    fake_api.list_repo_files.return_value = files

    def make_api(*a, **k):
        return fake_api

    # The name must be patched where it is looked up, i.e. inside lerobot.utils.hub.
    monkeypatch.setattr("lerobot.utils.hub.HfApi", make_api)
    return fake_api


def test_find_latest_hub_checkpoint_picks_highest_step(monkeypatch):
    repo_files = [
        "README.md",
        "checkpoints/000500/pretrained_model/model.safetensors",
        "checkpoints/000500/training_state/training_step.json",
        "checkpoints/020000/pretrained_model/model.safetensors",
        "checkpoints/001000/training_state/training_step.json",
    ]
    _patch_list_files(monkeypatch, repo_files)

    latest = find_latest_hub_checkpoint("u/run")

    # Steps are compared as numbers: "020000" outranks both "001000" and "000500".
    assert latest == "checkpoints/020000"


def test_find_latest_hub_checkpoint_ignores_non_step_entries(monkeypatch):
    repo_files = ["checkpoints/last/pretrained_model/model.safetensors", "config.json"]
    _patch_list_files(monkeypatch, repo_files)

    latest = find_latest_hub_checkpoint("u/run")

    # "last" is not a numeric step name, so nothing resolvable remains.
    assert latest is None


def test_find_latest_hub_checkpoint_none_when_no_checkpoints(monkeypatch):
    repo_files = ["config.json", "model.safetensors"]
    _patch_list_files(monkeypatch, repo_files)

    latest = find_latest_hub_checkpoint("u/run")

    assert latest is None
def test_resolve_resume_checkpoint_raises_without_checkpoints(tmp_path, monkeypatch):
    """A repo for which no checkpoint is found must surface as FileNotFoundError."""
    from lerobot.common import train_utils

    def no_checkpoint(repo_id):
        return None

    monkeypatch.setattr("lerobot.common.train_utils.find_latest_hub_checkpoint", no_checkpoint)

    run_dir = tmp_path / "run"
    with pytest.raises(FileNotFoundError, match="No checkpoint"):
        train_utils.resolve_resume_checkpoint("u/run", run_dir)