mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 16:27:04 +00:00
41166b39fb
* fix(datasets): expose a generator on EpisodeAwareSampler for distributed shuffle sync. In distributed training, accelerate can only synchronize the shuffle permutation across ranks when the sampler exposes a `generator` attribute. EpisodeAwareSampler shuffled via the global torch RNG, so disjoint batch shards relied on every rank's global CPU RNG staying in lockstep forever; any rank-asymmetric RNG consumption (e.g. eval rollouts on the main process only) silently desynced the permutations, and ranks trained on overlapping/missing samples. * fix(train): seed sampler generator and gate dataset download per node. - Pass a generator seeded with cfg.seed to EpisodeAwareSampler so accelerator.prepare registers it as the synchronized RNG and the shuffle order is reproducible. - Gate the initial make_dataset call on is_local_main_process instead of is_main_process: the global main process only exists on node 0, so on every other node all local ranks were downloading the dataset and building the Arrow cache concurrently.
164 lines
6.5 KiB
Python
164 lines
6.5 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import logging
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
|
|
|
from datasets import Dataset # noqa: E402
|
|
|
|
from lerobot.datasets.io_utils import (
|
|
hf_transform_to_torch,
|
|
)
|
|
from lerobot.datasets.sampler import EpisodeAwareSampler
|
|
|
|
|
|
def calculate_episode_data_index(hf_dataset: "Dataset") -> dict[str, torch.Tensor]:
    """Calculate episode data index for testing.

    Scans the ``episode_index`` column and records, for every contiguous run of
    equal episode indices, the inclusive start row ("from") and the exclusive
    end row ("to").

    Args:
        hf_dataset: Dataset supporting ``len()`` and ``hf_dataset["episode_index"]``.

    Returns:
        ``{"from": Tensor, "to": Tensor}``: two 1-D int64 tensors of equal length,
        one entry per contiguous episode run. An empty dataset yields empty int64
        tensors, so callers see the same dtype in every case.
    """
    num_rows = len(hf_dataset)
    if num_rows == 0:
        # Explicit dtype: a bare torch.tensor([]) defaults to float32, which would
        # differ from the int64 tensors returned for non-empty datasets.
        empty = torch.tensor([], dtype=torch.int64)
        return {"from": empty, "to": empty.clone()}

    starts: list[int] = []
    ends: list[int] = []
    current_episode = None
    for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
        if episode_idx != current_episode:
            # A new run begins here; it also closes the previous run (if any).
            if current_episode is not None:
                ends.append(idx)
            starts.append(idx)
            current_episode = episode_idx
    # The last run always extends to the end of the dataset.
    ends.append(num_rows)

    return {
        "from": torch.tensor(starts, dtype=torch.int64),
        "to": torch.tensor(ends, dtype=torch.int64),
    }
|
|
|
|
|
|
def test_drop_n_first_frames():
    """Dropping the first frame of each episode removes exactly those frames.

    Episodes hold frames {0, 1}, {2} and {3, 4, 5}; with one leading frame
    dropped per episode, only frames 1, 4 and 5 remain.
    """
    frames = Dataset.from_dict(
        {
            "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            "index": [0, 1, 2, 3, 4, 5],
            "episode_index": [0, 0, 1, 2, 2, 2],
        },
    )
    frames.set_transform(hf_transform_to_torch)
    bounds = calculate_episode_data_index(frames)
    expected = [1, 4, 5]

    sampler = EpisodeAwareSampler(bounds["from"], bounds["to"], drop_n_first_frames=1)

    assert sampler.indices == expected
    assert len(sampler) == len(expected)
    assert list(sampler) == expected
|
|
|
|
|
|
def test_drop_n_last_frames():
    """Dropping the last frame of each episode removes exactly those frames.

    Episodes hold frames {0, 1}, {2} and {3, 4, 5}; with one trailing frame
    dropped per episode, only frames 0, 3 and 4 remain.
    """
    frames = Dataset.from_dict(
        {
            "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            "index": [0, 1, 2, 3, 4, 5],
            "episode_index": [0, 0, 1, 2, 2, 2],
        },
    )
    frames.set_transform(hf_transform_to_torch)
    bounds = calculate_episode_data_index(frames)
    expected = [0, 3, 4]

    sampler = EpisodeAwareSampler(bounds["from"], bounds["to"], drop_n_last_frames=1)

    assert sampler.indices == expected
    assert len(sampler) == len(expected)
    assert list(sampler) == expected
|
|
|
|
|
|
def test_episode_indices_to_use():
    """Only frames from the selected episodes are sampled.

    Selecting episodes 0 ({0, 1}) and 2 ({3, 4, 5}) skips episode 1 (frame 2).
    """
    frames = Dataset.from_dict(
        {
            "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            "index": [0, 1, 2, 3, 4, 5],
            "episode_index": [0, 0, 1, 2, 2, 2],
        },
    )
    frames.set_transform(hf_transform_to_torch)
    bounds = calculate_episode_data_index(frames)
    selected_episodes = [0, 2]
    expected = [0, 1, 3, 4, 5]

    sampler = EpisodeAwareSampler(bounds["from"], bounds["to"], episode_indices_to_use=selected_episodes)

    assert sampler.indices == expected
    assert len(sampler) == len(expected)
    assert list(sampler) == expected
|
|
|
|
|
|
def test_shuffle():
    """Shuffling may reorder frame indices but must never add, drop or duplicate them."""
    dataset = Dataset.from_dict(
        {
            "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            "index": [0, 1, 2, 3, 4, 5],
            "episode_index": [0, 0, 1, 2, 2, 2],
        },
    )
    dataset.set_transform(hf_transform_to_torch)
    episode_data_index = calculate_episode_data_index(dataset)
    all_frames = [0, 1, 2, 3, 4, 5]

    sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], shuffle=False)
    assert sampler.indices == all_frames
    assert len(sampler) == 6
    assert list(sampler) == all_frames

    sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], shuffle=True)
    assert sampler.indices == all_frames
    assert len(sampler) == 6
    # Compare the sorted iteration rather than a set: a set comparison would not
    # notice an index being yielded twice, and `len(sampler)` only reports
    # `__len__`, not what iteration actually produces.
    shuffled = list(sampler)
    assert sorted(shuffled) == all_frames
|
|
|
|
|
|
def test_shuffle_with_generator_is_deterministic():
    """Same-seed generators yield the same permutation, independent of the global RNG.

    In distributed training accelerate synchronizes the sampler's own generator
    state (not the global torch RNG) across ranks; identical permutations are
    what keep each rank's batch shard disjoint.
    """
    seed = 42

    def seeded_sampler():
        return EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(seed))

    first, second = seeded_sampler(), seeded_sampler()
    assert list(first) == list(second)

    # Reseed the sampler's generator, then consume the global RNG the way
    # rank-asymmetric code (e.g. eval) would: the permutation must not change.
    probe = seeded_sampler()
    reference_order = list(probe)
    probe.generator.manual_seed(seed)
    torch.randperm(1000)
    assert list(probe) == reference_order
|
|
|
|
|
|
def test_generator_attribute_defaults_to_none():
    """The ``generator`` attribute exists (as ``None``) when no generator is passed."""
    # accelerate detects synchronizable samplers via `hasattr(sampler, "generator")`,
    # so the attribute must exist even when no generator is passed.
    sampler = EpisodeAwareSampler([0], [6], shuffle=True)
    assert hasattr(sampler, "generator")
    assert sampler.generator is None
    # Sorted comparison (not a set) so a duplicated index in the shuffle is caught.
    assert sorted(sampler) == [0, 1, 2, 3, 4, 5]
|
|
|
|
|
|
def test_negative_drop_first_frames_raises():
    """A negative ``drop_n_first_frames`` is rejected with a ValueError."""
    episode_starts, episode_ends = [0], [10]
    with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
        EpisodeAwareSampler(episode_starts, episode_ends, drop_n_first_frames=-1)
|
|
|
|
|
|
def test_negative_drop_last_frames_raises():
    """A negative ``drop_n_last_frames`` is rejected with a ValueError."""
    episode_starts, episode_ends = [0], [10]
    with pytest.raises(ValueError, match="drop_n_last_frames must be >= 0"):
        EpisodeAwareSampler(episode_starts, episode_ends, drop_n_last_frames=-1)
|
|
|
|
|
|
def test_all_episodes_dropped_raises():
    """Construction fails when dropping frames leaves no frame to sample."""
    # Three single-frame episodes [0, 1), [1, 2), [2, 3): dropping the first
    # frame of each removes every frame.
    starts = [0, 1, 2]
    ends = [start + 1 for start in starts]
    with pytest.raises(ValueError, match="No valid frames remain"):
        EpisodeAwareSampler(starts, ends, drop_n_first_frames=1)
|
|
|
|
|
|
def test_partial_episode_drop_warns(caplog):
    """An episode emptied by frame dropping is skipped with a warning; others are kept."""
    # Episode 0 spans [0, 1) and vanishes once its first frame is dropped;
    # episode 1 spans [1, 6) and keeps frames 2-5.
    with caplog.at_level(logging.WARNING, logger="lerobot.datasets.sampler"):
        sampler = EpisodeAwareSampler([0, 1], [1, 6], drop_n_first_frames=1)

    assert sampler.indices == list(range(2, 6))
    assert "Episode 0" in caplog.text
|