mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-26 20:57:28 +00:00
838ab9e234
Allow --config_path to be a Hub repo id when resuming, not only a local path. The latest checkpoint under checkpoints/<step>/ is downloaded into a fresh local run dir and resumed from there (optimizer, scheduler, RNG and data order restored as for a local resume). TrainPipelineConfig.from_pretrained falls back to the latest checkpoint's train_config.json when a repo has no root config (an interrupted run that only pushed checkpoints). The download is skipped when dispatching remotely so the executor (local machine or HF Jobs pod) performs it. - add find_latest_hub_checkpoint (utils/hub) and resolve_resume_checkpoint (common/train_utils), the symmetric download counterpart to push_checkpoint_to_hub - unit tests for both helpers and the from_pretrained fallback
55 lines
2.1 KiB
Python
55 lines
2.1 KiB
Python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
from unittest.mock import MagicMock
|
|
|
|
from lerobot.utils.hub import find_latest_hub_checkpoint
|
|
|
|
|
|
def _patch_list_files(monkeypatch, files):
    """Replace ``lerobot.utils.hub.HfApi`` with a stub whose file listing is ``files``.

    Args:
        monkeypatch: Object exposing ``setattr(target, value)`` with a dotted
            string target (e.g. pytest's ``monkeypatch`` fixture).
        files: Value that the stub's ``list_repo_files`` call returns.

    Returns:
        The ``MagicMock`` standing in for the ``HfApi`` instance, so callers can
        inspect the calls made on it.
    """
    hub_api_stub = MagicMock()
    hub_api_stub.list_repo_files.return_value = files

    def _fake_hf_api(*args, **kwargs):
        # Whatever arguments the constructor receives, hand back the same stub.
        return hub_api_stub

    # HfApi is imported into lerobot.utils.hub at module load, so the name has
    # to be replaced in that module's namespace.
    monkeypatch.setattr("lerobot.utils.hub.HfApi", _fake_hf_api)
    return hub_api_stub
|
|
|
|
|
|
def test_find_latest_hub_checkpoint_picks_highest_step(monkeypatch):
    """The checkpoint directory with the highest step number is returned."""
    # Steps are listed out of order (020000 appears before 001000), so the
    # expected answer does not follow from listing order.
    repo_files = [
        "README.md",
        "checkpoints/000500/pretrained_model/model.safetensors",
        "checkpoints/000500/training_state/training_step.json",
        "checkpoints/020000/pretrained_model/model.safetensors",
        "checkpoints/001000/training_state/training_step.json",
    ]
    _patch_list_files(monkeypatch, repo_files)

    latest = find_latest_hub_checkpoint("u/run")

    # "020000" is the largest step among "000500"/"001000"/"020000".
    # NOTE(review): these names are zero-padded to equal width, so numeric and
    # lexicographic ordering agree here; this case alone does not tell the two
    # apart — consider a mixed-width case if the helper accepts such names.
    assert latest == "checkpoints/020000"
|
|
|
|
|
|
def test_find_latest_hub_checkpoint_ignores_non_step_entries(monkeypatch):
    """A ``checkpoints/`` entry whose name is not a step number is not returned."""
    repo_files = ["checkpoints/last/pretrained_model/model.safetensors", "config.json"]
    _patch_list_files(monkeypatch, repo_files)

    latest = find_latest_hub_checkpoint("u/run")

    # The only checkpoint directory is named "last", which is not a numeric
    # step, so there is nothing to resolve.
    assert latest is None
|
|
|
|
|
|
def test_find_latest_hub_checkpoint_none_when_no_checkpoints(monkeypatch):
    """A repo listing with no ``checkpoints/`` entries yields ``None``."""
    root_only_files = ["config.json", "model.safetensors"]
    _patch_list_files(monkeypatch, root_only_files)

    latest = find_latest_hub_checkpoint("u/run")

    assert latest is None
|