Files
lerobot/tests/utils/test_hub.py
T
Nicolas Rabault 838ab9e234 feat(train): resume training from a Hub checkpoint
Allow --config_path to be a Hub repo id when resuming, not only a local path.
The latest checkpoint under checkpoints/<step>/ is downloaded into a fresh local
run dir and resumed from there (optimizer, scheduler, RNG and data order
restored as for a local resume). TrainPipelineConfig.from_pretrained falls back
to the latest checkpoint's train_config.json when a repo has no root config
(an interrupted run that only pushed checkpoints). The download is skipped when
dispatching remotely so the executor (local machine or HF Jobs pod) performs it.

- add find_latest_hub_checkpoint (utils/hub) and resolve_resume_checkpoint
  (common/train_utils), the symmetric download counterpart to
  push_checkpoint_to_hub
- unit tests for both helpers and the from_pretrained fallback
2026-06-25 21:50:39 +02:00

55 lines
2.1 KiB
Python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock
from lerobot.utils.hub import find_latest_hub_checkpoint
def _patch_list_files(monkeypatch, files):
api = MagicMock()
api.list_repo_files.return_value = files
# HfApi is imported into lerobot.utils.hub at module load, so patch it there.
monkeypatch.setattr("lerobot.utils.hub.HfApi", lambda *a, **k: api)
return api
def test_find_latest_hub_checkpoint_picks_highest_step(monkeypatch):
_patch_list_files(
monkeypatch,
[
"README.md",
"checkpoints/000500/pretrained_model/model.safetensors",
"checkpoints/000500/training_state/training_step.json",
"checkpoints/020000/pretrained_model/model.safetensors",
"checkpoints/001000/training_state/training_step.json",
],
)
# Numeric max, not lexicographic — "020000" beats "001000"/"000500".
assert find_latest_hub_checkpoint("u/run") == "checkpoints/020000"
def test_find_latest_hub_checkpoint_ignores_non_step_entries(monkeypatch):
_patch_list_files(
monkeypatch,
["checkpoints/last/pretrained_model/model.safetensors", "config.json"],
)
# "last" (a symlink target name) is not a numeric step → no resolvable checkpoint.
assert find_latest_hub_checkpoint("u/run") is None
def test_find_latest_hub_checkpoint_none_when_no_checkpoints(monkeypatch):
_patch_list_files(monkeypatch, ["config.json", "model.safetensors"])
assert find_latest_hub_checkpoint("u/run") is None