Merge remote-tracking branch 'origin/main' into feature/basic-peft-support

@@ -139,7 +139,6 @@ def test_async_inference_e2e(monkeypatch):
        policy_type="test",
        pretrained_name_or_path="test",
        actions_per_chunk=20,
        verify_robot_cameras=False,
    )

    client = RobotClient(client_config)
@@ -51,7 +51,6 @@ def robot_client():
        policy_type="test",
        pretrained_name_or_path="test",
        actions_per_chunk=20,
        verify_robot_cameras=False,
    )

    client = RobotClient(test_config)
@@ -181,6 +181,54 @@ def assert_dataset_iteration_works(aggr_ds):
    pass


def assert_video_timestamps_within_bounds(aggr_ds):
    """Test that all video timestamps are within valid bounds for their respective video files.

    This catches bugs where timestamps point to frames beyond the actual video length,
    which would cause "Invalid frame index" errors during data loading.
    """
    try:
        from torchcodec.decoders import VideoDecoder
    except ImportError:
        return

    for ep_idx in range(aggr_ds.num_episodes):
        ep = aggr_ds.meta.episodes[ep_idx]

        for vid_key in aggr_ds.meta.video_keys:
            from_ts = ep[f"videos/{vid_key}/from_timestamp"]
            to_ts = ep[f"videos/{vid_key}/to_timestamp"]
            video_path = aggr_ds.root / aggr_ds.meta.get_video_file_path(ep_idx, vid_key)

            if not video_path.exists():
                continue

            from_frame_idx = round(from_ts * aggr_ds.fps)
            to_frame_idx = round(to_ts * aggr_ds.fps)

            try:
                decoder = VideoDecoder(str(video_path))
                num_frames = len(decoder)

                # Verify timestamps don't exceed video bounds
                assert from_frame_idx >= 0, (
                    f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) < 0"
                )
                assert from_frame_idx < num_frames, (
                    f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) >= video frames ({num_frames})"
                )
                assert to_frame_idx <= num_frames, (
                    f"Episode {ep_idx}, {vid_key}: to_frame_idx ({to_frame_idx}) > video frames ({num_frames})"
                )
                assert from_frame_idx < to_frame_idx, (
                    f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) >= to_frame_idx ({to_frame_idx})"
                )
            except Exception as e:
                raise AssertionError(
                    f"Failed to verify timestamps for episode {ep_idx}, {vid_key}: {e}"
                ) from e
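
The helper converts each episode's time span into frame indices with round(ts * fps) and treats the span as half-open: from_frame_idx must name an existing frame, while to_frame_idx is an exclusive end and may equal the frame count. A minimal, self-contained sketch of that arithmetic with illustrative numbers:

    fps = 30
    num_frames = 90                        # frames actually present in the video file
    from_ts, to_ts = 1.0, 3.0              # episode occupies seconds [1.0, 3.0) of the file

    from_frame_idx = round(from_ts * fps)  # 30 -> first frame used by the episode
    to_frame_idx = round(to_ts * fps)      # 90 -> exclusive end, may equal num_frames

    assert 0 <= from_frame_idx < num_frames
    assert from_frame_idx < to_frame_idx <= num_frames
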

def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
    """Test basic aggregation functionality with standard parameters."""
    ds_0_num_frames = 400
@@ -227,6 +275,7 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
    assert_metadata_consistency(aggr_ds, ds_0, ds_1)
    assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
    assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
    assert_video_timestamps_within_bounds(aggr_ds)
    assert_dataset_iteration_works(aggr_ds)
@@ -277,6 +326,7 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
    assert_metadata_consistency(aggr_ds, ds_0, ds_1)
    assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
    assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
    assert_video_timestamps_within_bounds(aggr_ds)
    assert_dataset_iteration_works(aggr_ds)

    # Check that multiple files were actually created due to small size limits
@@ -290,3 +340,43 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
    if video_dir.exists():
        video_files = list(video_dir.rglob("*.mp4"))
        assert len(video_files) > 1, "Small file size limits should create multiple video files"


def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
    """Regression test for video timestamp bug when merging datasets.

    This test specifically checks that video timestamps are correctly calculated
    and accumulated when merging multiple datasets.
    """
    datasets = []
    for i in range(3):
        ds = lerobot_dataset_factory(
            root=tmp_path / f"regression_{i}",
            repo_id=f"{DUMMY_REPO_ID}_regression_{i}",
            total_episodes=2,
            total_frames=100,
        )
        datasets.append(ds)

    aggregate_datasets(
        repo_ids=[ds.repo_id for ds in datasets],
        roots=[ds.root for ds in datasets],
        aggr_repo_id=f"{DUMMY_REPO_ID}_regression_aggr",
        aggr_root=tmp_path / "regression_aggr",
    )

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "regression_aggr")
        aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_regression_aggr", root=tmp_path / "regression_aggr")

    assert_video_timestamps_within_bounds(aggr_ds)

    for i in range(len(aggr_ds)):
        item = aggr_ds[i]
        for key in aggr_ds.meta.video_keys:
            assert key in item, f"Video key {key} missing from item {i}"
            assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}"
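
The bug this test pins down lives in the bookkeeping that assigns each appended episode its [from_timestamp, to_timestamp) span inside a shared video file: every span must be offset by the total duration already written to that file. A minimal sketch of that accumulation (hypothetical helper name, not the actual aggregation code):

    def accumulate_spans(episode_durations_s):
        # Episodes are appended in order; each span starts where the previous one ended.
        spans, cursor = [], 0.0
        for duration in episode_durations_s:
            spans.append((cursor, cursor + duration))
            cursor += duration
        return spans

    # Three 2-second episodes -> [(0.0, 2.0), (2.0, 4.0), (4.0, 6.0)]
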
File diff suppressed because it is too large

@@ -806,6 +806,8 @@ def test_episode_index_distribution(tmp_path, empty_lerobot_dataset_factory):
            dataset.add_frame({"state": torch.randn(2), "task": f"task_{episode_idx}"})
        dataset.save_episode()

    dataset.finalize()

    # Load the dataset and check episode indices
    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
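
This hunk and the five that follow make the same change: a dataset.finalize() call is inserted between the last save_episode() and the reload. Condensed, the recording lifecycle these tests exercise looks like this (a comment-only sketch; `dataset` is the writable dataset created by the fixture):

    # for episode_idx in range(num_episodes):
    #     for _ in range(frames_per_episode):
    #         dataset.add_frame({...})   # buffer one frame of the current episode
    #     dataset.save_episode()         # flush the buffered episode to disk
    # dataset.finalize()                 # write closing metadata; required before reloading
    # loaded = LeRobotDataset(dataset.repo_id, root=dataset.root)
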
@@ -855,6 +857,8 @@ def test_multi_episode_metadata_consistency(tmp_path, empty_lerobot_dataset_factory):
            dataset.add_frame({"state": torch.randn(3), ACTION: torch.randn(2), "task": tasks[episode_idx]})
        dataset.save_episode()

    dataset.finalize()

    # Load and validate episode metadata
    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
@@ -893,6 +897,8 @@ def test_data_consistency_across_episodes(tmp_path, empty_lerobot_dataset_factory):
            dataset.add_frame({"state": torch.randn(1), "task": "consistency_test"})
        dataset.save_episode()

    dataset.finalize()

    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)

    # Check data consistency - no gaps or overlaps
@@ -944,6 +950,8 @@ def test_statistics_metadata_validation(tmp_path, empty_lerobot_dataset_factory):
            dataset.add_frame({"state": state_data, ACTION: action_data, "task": "stats_test"})
        dataset.save_episode()

    dataset.finalize()

    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)

    # Check that statistics exist for all features
@@ -989,6 +997,8 @@ def test_episode_boundary_integrity(tmp_path, empty_lerobot_dataset_factory):
            dataset.add_frame({"state": torch.tensor([float(frame_idx)]), "task": f"episode_{episode_idx}"})
        dataset.save_episode()

    dataset.finalize()

    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)

    # Test episode boundaries
@@ -1031,6 +1041,8 @@ def test_task_indexing_and_validation(tmp_path, empty_lerobot_dataset_factory):
        dataset.add_frame({"state": torch.randn(1), "task": task})
        dataset.save_episode()

    dataset.finalize()

    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)

    # Check that all unique tasks are in the tasks metadata
@@ -1056,3 +1068,134 @@ def test_task_indexing_and_validation(tmp_path, empty_lerobot_dataset_factory):

    # Check total number of tasks
    assert loaded_dataset.meta.total_tasks == len(unique_tasks)


def test_dataset_resume_recording(tmp_path, empty_lerobot_dataset_factory):
    """Test that resuming dataset recording preserves previously recorded episodes.

    This test validates the critical resume functionality by:
    1. Recording initial episodes and finalizing
    2. Reopening the dataset
    3. Recording additional episodes
    4. Verifying all data (old + new) is intact

    This specifically tests the bug fix where parquet files were being overwritten
    instead of appended to during resume.
    """
    features = {
        "observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]},
        "action": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]},
    }

    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)

    initial_episodes = 2
    frames_per_episode = 3

    for ep_idx in range(initial_episodes):
        for frame_idx in range(frames_per_episode):
            dataset.add_frame(
                {
                    "observation.state": torch.tensor([float(ep_idx), float(frame_idx)]),
                    "action": torch.tensor([0.5, 0.5]),
                    "task": f"task_{ep_idx}",
                }
            )
        dataset.save_episode()

    assert dataset.meta.total_episodes == initial_episodes
    assert dataset.meta.total_frames == initial_episodes * frames_per_episode

    dataset.finalize()
    initial_root = dataset.root
    initial_repo_id = dataset.repo_id
    del dataset

    dataset_verify = LeRobotDataset(initial_repo_id, root=initial_root, revision="v3.0")
    assert dataset_verify.meta.total_episodes == initial_episodes
    assert dataset_verify.meta.total_frames == initial_episodes * frames_per_episode
    assert len(dataset_verify.hf_dataset) == initial_episodes * frames_per_episode

    for idx in range(len(dataset_verify.hf_dataset)):
        item = dataset_verify[idx]
        expected_ep = idx // frames_per_episode
        expected_frame = idx % frames_per_episode
        assert item["episode_index"].item() == expected_ep
        assert item["frame_index"].item() == expected_frame
        assert item["index"].item() == idx
        assert item["observation.state"][0].item() == float(expected_ep)
        assert item["observation.state"][1].item() == float(expected_frame)

    del dataset_verify

    # Phase 3: Resume recording - add more episodes
    dataset_resumed = LeRobotDataset(initial_repo_id, root=initial_root, revision="v3.0")

    assert dataset_resumed.meta.total_episodes == initial_episodes
    assert dataset_resumed.meta.total_frames == initial_episodes * frames_per_episode
    assert dataset_resumed.latest_episode is None  # Not recording yet
    assert dataset_resumed.writer is None
    assert dataset_resumed.meta.writer is None

    additional_episodes = 2
    for ep_idx in range(initial_episodes, initial_episodes + additional_episodes):
        for frame_idx in range(frames_per_episode):
            dataset_resumed.add_frame(
                {
                    "observation.state": torch.tensor([float(ep_idx), float(frame_idx)]),
                    "action": torch.tensor([0.5, 0.5]),
                    "task": f"task_{ep_idx}",
                }
            )
        dataset_resumed.save_episode()

    total_episodes = initial_episodes + additional_episodes
    total_frames = total_episodes * frames_per_episode
    assert dataset_resumed.meta.total_episodes == total_episodes
    assert dataset_resumed.meta.total_frames == total_frames

    dataset_resumed.finalize()
    del dataset_resumed

    dataset_final = LeRobotDataset(initial_repo_id, root=initial_root, revision="v3.0")

    assert dataset_final.meta.total_episodes == total_episodes
    assert dataset_final.meta.total_frames == total_frames
    assert len(dataset_final.hf_dataset) == total_frames

    for idx in range(total_frames):
        item = dataset_final[idx]
        expected_ep = idx // frames_per_episode
        expected_frame = idx % frames_per_episode

        assert item["episode_index"].item() == expected_ep, (
            f"Frame {idx}: wrong episode_index. Expected {expected_ep}, got {item['episode_index'].item()}"
        )
        assert item["frame_index"].item() == expected_frame, (
            f"Frame {idx}: wrong frame_index. Expected {expected_frame}, got {item['frame_index'].item()}"
        )
        assert item["index"].item() == idx, (
            f"Frame {idx}: wrong index. Expected {idx}, got {item['index'].item()}"
        )

        # Verify data integrity
        assert item["observation.state"][0].item() == float(expected_ep), (
            f"Frame {idx}: wrong observation.state[0]. Expected {float(expected_ep)}, "
            f"got {item['observation.state'][0].item()}"
        )
        assert item["observation.state"][1].item() == float(expected_frame), (
            f"Frame {idx}: wrong observation.state[1]. Expected {float(expected_frame)}, "
            f"got {item['observation.state'][1].item()}"
        )

    assert len(dataset_final.meta.episodes) == total_episodes
    for ep_idx in range(total_episodes):
        ep_metadata = dataset_final.meta.episodes[ep_idx]
        assert ep_metadata["episode_index"] == ep_idx
        assert ep_metadata["length"] == frames_per_episode
        assert ep_metadata["tasks"] == [f"task_{ep_idx}"]

        expected_from = ep_idx * frames_per_episode
        expected_to = (ep_idx + 1) * frames_per_episode
        assert ep_metadata["dataset_from_index"] == expected_from
        assert ep_metadata["dataset_to_index"] == expected_to
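
The overwrite bug called out in the docstring comes down to how new episode rows reach an existing parquet file: parquet cannot be appended to in place, so a resume path must read the existing rows, concatenate, and rewrite (or start a new file part). An illustrative sketch with pandas and a hypothetical path, not the library's actual writer:

    import pandas as pd

    def append_rows(parquet_path, new_rows: pd.DataFrame) -> None:
        # Read existing rows, concatenate the new episode's rows, rewrite the file.
        existing = pd.read_parquet(parquet_path)
        combined = pd.concat([existing, new_rows], ignore_index=True)
        combined.to_parquet(parquet_path, index=False)
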
@@ -95,7 +95,6 @@ def test_get_policy_and_config_classes(policy_name: str):
@pytest.mark.parametrize(
    "ds_repo_id,env_name,env_kwargs,policy_name,policy_kwargs",
    [
        ("lerobot/xarm_lift_medium", "xarm", {}, "tdmpc", {"use_mpc": True}),
        ("lerobot/pusht", "pusht", {}, "diffusion", {}),
        ("lerobot/pusht", "pusht", {}, "vqbet", {}),
        ("lerobot/pusht", "pusht", {}, "act", {}),
@@ -328,8 +327,6 @@ def test_multikey_construction(multikey: bool):
        # TODO(alexander-soare): `policy.use_mpc=false` was previously the default in the config yaml but it
        # was changed to true. For some reason, tests would pass locally, but not in CI. So here we override
        # to test with `policy.use_mpc=false`.
        ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"),
        # ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"),
        # TODO(rcadene): the diffusion model used to normalize the image with mean=0.5 std=0.5, which is a hack;
        # it is not supposed to normalize the image at all. In our current codebase we don't normalize at all,
        # but there is still a minor difference that fails the test. However, when the image is normalized with
        # 0.5/0.5 in the current codebase, the test passes.
@@ -0,0 +1,211 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Multi-GPU Training Tests

This module tests multi-GPU training functionality with accelerate.
These tests are designed to run on machines with 2+ GPUs and are executed
in the nightly CI workflow.

The tests automatically generate accelerate configs and launch training
with subprocess to properly test the distributed training environment.
"""

import os
import subprocess
import tempfile
from pathlib import Path

import pytest
import torch

from lerobot.datasets.lerobot_dataset import LeRobotDataset


def get_num_available_gpus():
    """Returns the number of available GPUs."""
    if not torch.cuda.is_available():
        return 0
    return torch.cuda.device_count()


def download_dataset(repo_id, episodes):
    """
    Pre-download dataset to avoid race conditions in multi-GPU training.

    Args:
        repo_id: HuggingFace dataset repository ID
        episodes: List of episode indices to download
    """
    # Simply instantiating the dataset will download it
    _ = LeRobotDataset(repo_id, episodes=episodes)
    print(f"Dataset {repo_id} downloaded successfully")


def run_accelerate_training(config_args, num_processes=4, temp_dir=None):
    """
    Helper function to run training with accelerate launch.

    Args:
        config_args: List of config arguments to pass to lerobot_train.py
        num_processes: Number of processes (GPUs) to use
        temp_dir: Temporary directory for outputs

    Returns:
        subprocess.CompletedProcess result
    """
    config_path = Path(temp_dir) / "accelerate_config.yaml"

    # Write YAML config
    with open(config_path, "w") as f:
        f.write("compute_environment: LOCAL_MACHINE\n")
        f.write("distributed_type: MULTI_GPU\n")
        f.write("mixed_precision: 'no'\n")
        f.write(f"num_processes: {num_processes}\n")
        f.write("use_cpu: false\n")
        f.write("gpu_ids: all\n")
        f.write("downcast_bf16: 'no'\n")
        f.write("machine_rank: 0\n")
        f.write("main_training_function: main\n")
        f.write("num_machines: 1\n")
        f.write("rdzv_backend: static\n")
        f.write("same_network: true\n")

    cmd = [
        "accelerate",
        "launch",
        "--config_file",
        str(config_path),
        "-m",
        "lerobot.scripts.lerobot_train",
    ] + config_args

    result = subprocess.run(
        cmd,
        capture_output=True,
        text=True,
        env={**os.environ, "CUDA_VISIBLE_DEVICES": ",".join(map(str, range(num_processes)))},
    )

    return result
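
For reference, a call such as run_accelerate_training(config_args, num_processes=2, temp_dir=d) writes the YAML config into d and expands to roughly this command (paths illustrative):

    CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
        --config_file <temp_dir>/accelerate_config.yaml \
        -m lerobot.scripts.lerobot_train <config_args...>
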

@pytest.mark.skipif(
    get_num_available_gpus() < 2,
    reason="Multi-GPU tests require at least 2 GPUs",
)
class TestMultiGPUTraining:
    """Test suite for multi-GPU training functionality."""

    def test_basic_multi_gpu_training(self):
        """
        Test that basic multi-GPU training runs successfully.
        Verifies that the training completes without errors.
        """
        # Pre-download dataset to avoid race conditions
        download_dataset("lerobot/pusht", episodes=[0])

        with tempfile.TemporaryDirectory() as temp_dir:
            output_dir = Path(temp_dir) / "outputs"

            config_args = [
                "--dataset.repo_id=lerobot/pusht",
                "--dataset.episodes=[0]",
                "--policy.type=act",
                "--policy.device=cuda",
                "--policy.push_to_hub=false",
                f"--output_dir={output_dir}",
                "--batch_size=4",
                "--steps=10",
                "--eval_freq=-1",
                "--log_freq=5",
                "--save_freq=10",
                "--seed=42",
                "--num_workers=0",
            ]

            # Launch with 2 processes to match the 2-GPU minimum enforced by the skipif above
            result = run_accelerate_training(config_args, num_processes=2, temp_dir=temp_dir)

            # Check that training completed successfully
            assert result.returncode == 0, (
                f"Multi-GPU training failed with return code {result.returncode}\n"
                f"STDOUT:\n{result.stdout}\n"
                f"STDERR:\n{result.stderr}"
            )

            # Verify checkpoint was saved
            checkpoints_dir = output_dir / "checkpoints"
            assert checkpoints_dir.exists(), "Checkpoints directory was not created"

            # Verify that training completed
            assert "End of training" in result.stdout or "End of training" in result.stderr

    def test_checkpoint_saving_multi_gpu(self):
        """
        Test that checkpoints are correctly saved during multi-GPU training.
        Only the main process (rank 0) should save checkpoints.
        """
        # Pre-download dataset to avoid race conditions
        download_dataset("lerobot/pusht", episodes=[0])

        with tempfile.TemporaryDirectory() as temp_dir:
            output_dir = Path(temp_dir) / "outputs"

            config_args = [
                "--dataset.repo_id=lerobot/pusht",
                "--dataset.episodes=[0]",
                "--policy.type=act",
                "--policy.device=cuda",
                "--policy.push_to_hub=false",
                f"--output_dir={output_dir}",
                "--batch_size=4",
                "--steps=20",
                "--eval_freq=-1",
                "--log_freq=5",
                "--save_freq=10",
                "--seed=42",
                "--num_workers=0",
            ]

            result = run_accelerate_training(config_args, num_processes=2, temp_dir=temp_dir)

            assert result.returncode == 0, (
                f"Training failed:\nSTDOUT:\n{result.stdout}\n\nSTDERR:\n{result.stderr}"
            )

            # Verify checkpoint directory exists
            checkpoints_dir = output_dir / "checkpoints"
            assert checkpoints_dir.exists(), "Checkpoints directory not created"

            # Count checkpoint directories (should have checkpoints at steps 10 and 20)
            checkpoint_dirs = [d for d in checkpoints_dir.iterdir() if d.is_dir()]
            assert len(checkpoint_dirs) >= 1, f"Expected at least 1 checkpoint, found {len(checkpoint_dirs)}"

            # Verify checkpoint contents
            for checkpoint_dir in checkpoint_dirs:
                # Check for model files
                model_files = list(checkpoint_dir.rglob("*.safetensors"))
                assert len(model_files) > 0, f"No model files in checkpoint {checkpoint_dir}"

                # Check for training state
                training_state_dir = checkpoint_dir / "training_state"
                assert training_state_dir.exists(), f"No training state in checkpoint {checkpoint_dir}"

                # Verify optimizer state exists
                optimizer_state = training_state_dir / "optimizer_state.safetensors"
                assert optimizer_state.exists(), f"No optimizer state in checkpoint {checkpoint_dir}"
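
Taken together, these assertions imply the following on-disk layout for each saved checkpoint (a summary of what is checked, not an exhaustive listing; step numbers depend on save_freq):

    outputs/checkpoints/<step>/
        ... *.safetensors                 (model weights, located via rglob)
        training_state/
            optimizer_state.safetensors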