mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
1aa937aad2
compute_sampler_state mapped a checkpointed step back to (epoch, start_index) using the *current* num_processes, but the number of sampler positions a step consumes scales with the world size that produced it. Resuming on a different GPU count therefore landed on the wrong epoch/offset, silently re-seeing or skipping data. Record num_processes in training_step.json at checkpoint time and feed the checkpoint's value into compute_sampler_state on resume, so the data order resumes at the right position regardless of the new world size. Warn when the world size changed (the global offset is correct, but per-rank sample-exactness needs the same topology). Old checkpoints without the field fall back to the current world size. Also document compute_sampler_state's assumptions explicitly: num_processes / batch_size must match the checkpointing run, and accelerate's even_batches=True padding is mirrored by the ceil(... / num_processes) term. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> Co-authored-by: Cursor <cursoragent@cursor.com>
127 lines
4.4 KiB
Python
127 lines
4.4 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from pathlib import Path
|
|
from unittest.mock import Mock, patch
|
|
|
|
from lerobot.common.train_utils import (
|
|
get_step_checkpoint_dir,
|
|
get_step_identifier,
|
|
load_training_num_processes,
|
|
load_training_state,
|
|
load_training_step,
|
|
save_checkpoint,
|
|
save_training_state,
|
|
save_training_step,
|
|
update_last_checkpoint,
|
|
)
|
|
from lerobot.utils.constants import (
|
|
CHECKPOINTS_DIR,
|
|
LAST_CHECKPOINT_LINK,
|
|
OPTIMIZER_PARAM_GROUPS,
|
|
OPTIMIZER_STATE,
|
|
RNG_STATE,
|
|
SCHEDULER_STATE,
|
|
TRAINING_STATE_DIR,
|
|
TRAINING_STEP,
|
|
)
|
|
|
|
|
|
def test_get_step_identifier():
    """get_step_identifier zero-pads the step into a fixed-width string identifier."""
    # (step, total_steps, expected identifier). The last case yields 7 digits
    # because total_steps itself has 7 digits.
    cases = (
        (5, 1000, "000005"),
        (123, 100_000, "000123"),
        (456789, 1_000_000, "0456789"),
    )
    for step, total_steps, expected in cases:
        assert get_step_identifier(step, total_steps) == expected
|
|
|
|
|
|
def test_get_step_checkpoint_dir():
    """A step's checkpoint dir is <output_dir>/<CHECKPOINTS_DIR>/<step identifier>."""
    root = Path("/checkpoints")
    expected = root / CHECKPOINTS_DIR / "000005"
    assert get_step_checkpoint_dir(root, 1000, 5) == expected
|
|
|
|
|
|
def test_save_load_training_step(tmp_path):
    """save_training_step writes TRAINING_STEP and load_training_step reads the same step back."""
    step = 5000
    save_training_step(step, tmp_path)
    assert (tmp_path / TRAINING_STEP).is_file()
    # The test name promises a round trip; the original only checked that the file exists.
    assert load_training_step(tmp_path) == step
|
|
|
|
|
|
def test_load_training_step(tmp_path):
    """load_training_step returns exactly the step that save_training_step wrote."""
    expected_step = 5000
    save_training_step(expected_step, tmp_path)
    assert load_training_step(tmp_path) == expected_step
|
|
|
|
|
|
def test_save_training_state_records_num_processes(tmp_path, optimizer, scheduler):
    """The num_processes passed to save_training_state can be read back via load_training_num_processes."""
    world_size = 4
    save_training_state(tmp_path, 10, optimizer, scheduler, num_processes=world_size)
    recorded = load_training_num_processes(tmp_path)
    assert recorded == world_size
|
|
|
|
|
|
def test_load_training_num_processes_absent_returns_none(tmp_path, optimizer, scheduler):
    """Back-compat: a state saved without num_processes loads and reports None.

    Omitting num_processes emulates a checkpoint written before the world size
    was recorded; such checkpoints must still load.
    """
    save_training_state(tmp_path, 10, optimizer, scheduler)
    num_processes = load_training_num_processes(tmp_path)
    assert num_processes is None
|
|
|
|
|
|
def test_update_last_checkpoint(tmp_path):
    """update_last_checkpoint creates a LAST_CHECKPOINT_LINK symlink beside the checkpoint, pointing at it."""
    checkpoint = tmp_path / "0005"
    checkpoint.mkdir()
    update_last_checkpoint(checkpoint)
    last_checkpoint = tmp_path / LAST_CHECKPOINT_LINK
    assert last_checkpoint.is_symlink()
    # Compare fully resolved paths on both sides. If any component of tmp_path
    # is itself a symlink, resolve() on the link returns the real path, which
    # would not equal the unresolved `checkpoint` path.
    assert last_checkpoint.resolve() == checkpoint.resolve()
|
|
|
|
|
|
@patch("lerobot.common.train_utils.save_training_state")
def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer):
    """Non-PEFT path: save_checkpoint saves policy and config and delegates to save_training_state."""
    policy = Mock()
    cfg = Mock()
    # A bare Mock returns a truthy child Mock for any attribute access. Without
    # this line, a `cfg.use_peft` check inside save_checkpoint (the PEFT test
    # sets use_peft=True to exercise it) would silently take the PEFT branch,
    # and the default path would go untested.
    cfg.use_peft = False
    save_checkpoint(tmp_path, 10, cfg, policy, optimizer)
    policy.save_pretrained.assert_called_once()
    cfg.save_pretrained.assert_called_once()
    mock_save_training_state.assert_called_once()
|
|
|
|
|
|
@patch("lerobot.common.train_utils.save_training_state")
def test_save_checkpoint_peft(mock_save_training_state, tmp_path, optimizer):
    """With cfg.use_peft=True, save_checkpoint also calls policy.config.save_pretrained exactly once."""
    # Configure the mocks through constructor kwargs instead of attribute assignment.
    policy = Mock(config=Mock(save_pretrained=Mock()))
    cfg = Mock(use_peft=True)

    save_checkpoint(tmp_path, 10, cfg, policy, optimizer)

    for saver in (policy.save_pretrained, cfg.save_pretrained, policy.config.save_pretrained):
        saver.assert_called_once()
    mock_save_training_state.assert_called_once()
|
|
|
|
|
|
def test_save_training_state(tmp_path, optimizer, scheduler):
    """save_training_state creates TRAINING_STATE_DIR containing every expected artifact file."""
    save_training_state(tmp_path, 10, optimizer, scheduler)
    state_dir = tmp_path / TRAINING_STATE_DIR
    assert state_dir.is_dir()
    expected_files = (
        TRAINING_STEP,
        RNG_STATE,
        OPTIMIZER_STATE,
        OPTIMIZER_PARAM_GROUPS,
        SCHEDULER_STATE,
    )
    for name in expected_files:
        assert (state_dir / name).is_file()
|
|
|
|
|
|
def test_save_load_training_state(tmp_path, optimizer, scheduler):
    """load_training_state returns the saved step and the same optimizer/scheduler instances passed in."""
    save_training_state(tmp_path, 10, optimizer, scheduler)
    step, returned_optimizer, returned_scheduler = load_training_state(tmp_path, optimizer, scheduler)
    assert step == 10
    # Identity, not equality: loading must reuse the caller's objects.
    assert returned_optimizer is optimizer
    assert returned_scheduler is scheduler
|