mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-12 05:59:53 +00:00
Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 69e8ab38bd | |||
| 926fb9c31e | |||
| 3f6909fb63 | |||
| 3d0e8681f3 | |||
| 1e2057e3be | |||
| c094f40868 | |||
| 1aa937aad2 | |||
| 7a62235bac | |||
| 81f0ca9ce4 | |||
| 29ca0f53d9 | |||
| b2d5d4ccfc | |||
| 32b0d7d1ef | |||
| 7416b714c0 | |||
| 6fa495c6b0 | |||
| 72e093dbff | |||
| 3d262a6c9e |
@@ -49,8 +49,19 @@ def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Pa
|
|||||||
return output_dir / CHECKPOINTS_DIR / step_identifier
|
return output_dir / CHECKPOINTS_DIR / step_identifier
|
||||||
|
|
||||||
|
|
||||||
def save_training_step(step: int, save_dir: Path) -> None:
|
def save_training_step(
|
||||||
write_json({"step": step}, save_dir / TRAINING_STEP)
|
step: int, save_dir: Path, num_processes: int | None = None, batch_size: int | None = None
|
||||||
|
) -> None:
|
||||||
|
state: dict = {"step": step}
|
||||||
|
# num_processes and batch_size are recorded so a resumed run can detect a changed world size or
|
||||||
|
# batch size: the sampler's resume offset is computed from the (num_processes, batch_size) that
|
||||||
|
# produced `step`, since both scale how many sampler positions a step consumes (see
|
||||||
|
# compute_sampler_state).
|
||||||
|
if num_processes is not None:
|
||||||
|
state["num_processes"] = num_processes
|
||||||
|
if batch_size is not None:
|
||||||
|
state["batch_size"] = batch_size
|
||||||
|
write_json(state, save_dir / TRAINING_STEP)
|
||||||
|
|
||||||
|
|
||||||
def load_training_step(save_dir: Path) -> int:
|
def load_training_step(save_dir: Path) -> int:
|
||||||
@@ -58,6 +69,16 @@ def load_training_step(save_dir: Path) -> int:
|
|||||||
return training_step["step"]
|
return training_step["step"]
|
||||||
|
|
||||||
|
|
||||||
|
def load_training_num_processes(checkpoint_dir: Path) -> int | None:
|
||||||
|
"""World size recorded at checkpoint time, or None for checkpoints written before it was stored."""
|
||||||
|
return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("num_processes")
|
||||||
|
|
||||||
|
|
||||||
|
def load_training_batch_size(checkpoint_dir: Path) -> int | None:
|
||||||
|
"""Per-process batch size recorded at checkpoint time, or None for older checkpoints."""
|
||||||
|
return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("batch_size")
|
||||||
|
|
||||||
|
|
||||||
def update_last_checkpoint(checkpoint_dir: Path) -> Path:
|
def update_last_checkpoint(checkpoint_dir: Path) -> Path:
|
||||||
last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
|
last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
|
||||||
if last_checkpoint_dir.is_symlink():
|
if last_checkpoint_dir.is_symlink():
|
||||||
@@ -75,6 +96,8 @@ def save_checkpoint(
|
|||||||
scheduler: LRScheduler | None = None,
|
scheduler: LRScheduler | None = None,
|
||||||
preprocessor: PolicyProcessorPipeline | None = None,
|
preprocessor: PolicyProcessorPipeline | None = None,
|
||||||
postprocessor: PolicyProcessorPipeline | None = None,
|
postprocessor: PolicyProcessorPipeline | None = None,
|
||||||
|
num_processes: int | None = None,
|
||||||
|
batch_size: int | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""This function creates the following directory structure:
|
"""This function creates the following directory structure:
|
||||||
|
|
||||||
@@ -100,6 +123,10 @@ def save_checkpoint(
|
|||||||
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
|
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
|
||||||
preprocessor: The preprocessor/pipeline to save. Defaults to None.
|
preprocessor: The preprocessor/pipeline to save. Defaults to None.
|
||||||
postprocessor: The postprocessor/pipeline to save. Defaults to None.
|
postprocessor: The postprocessor/pipeline to save. Defaults to None.
|
||||||
|
num_processes (int | None, optional): Distributed world size to record for sample-exact
|
||||||
|
resume. Defaults to None (not recorded).
|
||||||
|
batch_size (int | None, optional): Per-process batch size to record for sample-exact
|
||||||
|
resume. Defaults to None (not recorded).
|
||||||
"""
|
"""
|
||||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||||
policy.save_pretrained(pretrained_dir)
|
policy.save_pretrained(pretrained_dir)
|
||||||
@@ -112,7 +139,9 @@ def save_checkpoint(
|
|||||||
preprocessor.save_pretrained(pretrained_dir)
|
preprocessor.save_pretrained(pretrained_dir)
|
||||||
if postprocessor is not None:
|
if postprocessor is not None:
|
||||||
postprocessor.save_pretrained(pretrained_dir)
|
postprocessor.save_pretrained(pretrained_dir)
|
||||||
save_training_state(checkpoint_dir, step, optimizer, scheduler)
|
save_training_state(
|
||||||
|
checkpoint_dir, step, optimizer, scheduler, num_processes=num_processes, batch_size=batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def save_training_state(
|
def save_training_state(
|
||||||
@@ -120,6 +149,8 @@ def save_training_state(
|
|||||||
train_step: int,
|
train_step: int,
|
||||||
optimizer: Optimizer | None = None,
|
optimizer: Optimizer | None = None,
|
||||||
scheduler: LRScheduler | None = None,
|
scheduler: LRScheduler | None = None,
|
||||||
|
num_processes: int | None = None,
|
||||||
|
batch_size: int | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Saves the training step, optimizer state, scheduler state, and rng state.
|
Saves the training step, optimizer state, scheduler state, and rng state.
|
||||||
@@ -131,10 +162,12 @@ def save_training_state(
|
|||||||
Defaults to None.
|
Defaults to None.
|
||||||
scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict.
|
scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
|
num_processes (int | None, optional): Distributed world size to record. Defaults to None.
|
||||||
|
batch_size (int | None, optional): Per-process batch size to record. Defaults to None.
|
||||||
"""
|
"""
|
||||||
save_dir = checkpoint_dir / TRAINING_STATE_DIR
|
save_dir = checkpoint_dir / TRAINING_STATE_DIR
|
||||||
save_dir.mkdir(parents=True, exist_ok=True)
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
save_training_step(train_step, save_dir)
|
save_training_step(train_step, save_dir, num_processes=num_processes, batch_size=batch_size)
|
||||||
save_rng_state(save_dir)
|
save_rng_state(save_dir)
|
||||||
if optimizer is not None:
|
if optimizer is not None:
|
||||||
save_optimizer_state(optimizer, save_dir)
|
save_optimizer_state(optimizer, save_dir)
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ from .lerobot_dataset import LeRobotDataset
|
|||||||
from .multi_dataset import MultiLeRobotDataset
|
from .multi_dataset import MultiLeRobotDataset
|
||||||
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||||
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
|
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
|
||||||
from .sampler import EpisodeAwareSampler
|
from .sampler import EpisodeAwareSampler, compute_sampler_state
|
||||||
from .streaming_dataset import StreamingLeRobotDataset
|
from .streaming_dataset import StreamingLeRobotDataset
|
||||||
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
|
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
|
||||||
from .video_utils import VideoEncodingManager
|
from .video_utils import VideoEncodingManager
|
||||||
@@ -82,6 +82,7 @@ __all__ = [
|
|||||||
"aggregate_stats",
|
"aggregate_stats",
|
||||||
"convert_image_to_video_dataset",
|
"convert_image_to_video_dataset",
|
||||||
"create_initial_features",
|
"create_initial_features",
|
||||||
|
"compute_sampler_state",
|
||||||
"create_lerobot_dataset_card",
|
"create_lerobot_dataset_card",
|
||||||
"column_for_style",
|
"column_for_style",
|
||||||
"delete_episodes",
|
"delete_episodes",
|
||||||
|
|||||||
+122
-38
@@ -14,14 +14,36 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class EpisodeAwareSampler:
|
class EpisodeAwareSampler:
|
||||||
|
"""Sampler over episode frames that stores only per-episode boundaries.
|
||||||
|
|
||||||
|
Logical positions map to frame indices on the fly (O(num_episodes) construction memory)
|
||||||
|
instead of materializing a Python list of every frame index.
|
||||||
|
|
||||||
|
Each epoch is shuffled with a `torch.randperm` seeded from `(seed, epoch)`, so the data order
|
||||||
|
is a pure function of `(seed, epoch)`: it reproduces on every rank without synchronizing the
|
||||||
|
global RNG (no `generator` to sync across distributed ranks), and `state_dict` /
|
||||||
|
`load_state_dict` resume a run sample-exactly by regenerating the epoch's permutation and
|
||||||
|
continuing from the saved offset. Each call to `__iter__` advances the epoch. During a
|
||||||
|
resumed epoch, `__len__` still reports the full length.
|
||||||
|
|
||||||
|
Epoch advancement: `__iter__` eagerly advances the epoch, and `set_epoch` / `load_state_dict`
|
||||||
|
set it explicitly. Within a single run callers should rely on exactly one of these mechanisms,
|
||||||
|
not both: advancing the epoch by hand *and* letting `__iter__` auto-advance over the same
|
||||||
|
iterations would skip or repeat epochs. The training loop drives it purely through `__iter__`
|
||||||
|
(via `cycle`); `set_epoch` / `load_state_dict` are used only to (re)position before iteration
|
||||||
|
starts (e.g. on resume or in tests).
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dataset_from_indices: list[int],
|
dataset_from_indices: list[int],
|
||||||
@@ -30,63 +52,125 @@ class EpisodeAwareSampler:
|
|||||||
drop_n_first_frames: int = 0,
|
drop_n_first_frames: int = 0,
|
||||||
drop_n_last_frames: int = 0,
|
drop_n_last_frames: int = 0,
|
||||||
shuffle: bool = False,
|
shuffle: bool = False,
|
||||||
generator: torch.Generator | None = None,
|
seed: int = 0,
|
||||||
):
|
):
|
||||||
"""Sampler that optionally incorporates episode boundary information.
|
"""
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset_from_indices: List of indices containing the start of each episode in the dataset.
|
dataset_from_indices: Start index of each episode in the dataset.
|
||||||
dataset_to_indices: List of indices containing the end of each episode in the dataset.
|
dataset_to_indices: End index of each episode in the dataset.
|
||||||
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
|
episode_indices_to_use: Episode indices to use; None means all.
|
||||||
Assumes that episodes are indexed from 0 to N-1.
|
drop_n_first_frames: Frames to drop from the start of each episode.
|
||||||
drop_n_first_frames: Number of frames to drop from the start of each episode.
|
drop_n_last_frames: Frames to drop from the end of each episode.
|
||||||
drop_n_last_frames: Number of frames to drop from the end of each episode.
|
|
||||||
shuffle: Whether to shuffle the indices.
|
shuffle: Whether to shuffle the indices.
|
||||||
generator: Generator used for shuffling. Exposing this attribute (even when None) lets
|
seed: Seed the permutation is derived from (together with the epoch).
|
||||||
`accelerate` register it as the synchronized RNG in distributed training, so
|
|
||||||
every rank draws the same permutation and batch shards stay disjoint. When
|
|
||||||
None, shuffling falls back to the global torch RNG.
|
|
||||||
"""
|
"""
|
||||||
if drop_n_first_frames < 0:
|
if drop_n_first_frames < 0:
|
||||||
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
|
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
|
||||||
if drop_n_last_frames < 0:
|
if drop_n_last_frames < 0:
|
||||||
raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}")
|
raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}")
|
||||||
|
|
||||||
indices = []
|
from_indices = np.asarray(dataset_from_indices, dtype=np.int64)
|
||||||
for episode_idx, (start_index, end_index) in enumerate(
|
to_indices = np.asarray(dataset_to_indices, dtype=np.int64)
|
||||||
zip(dataset_from_indices, dataset_to_indices, strict=True)
|
if from_indices.shape != to_indices.shape:
|
||||||
):
|
raise ValueError(
|
||||||
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
|
f"dataset_from_indices and dataset_to_indices must have the same length, "
|
||||||
ep_length = end_index - start_index
|
f"got {len(from_indices)} and {len(to_indices)}"
|
||||||
if drop_n_first_frames + drop_n_last_frames >= ep_length:
|
)
|
||||||
logger.warning(
|
|
||||||
"Episode %d has %d frames but drop_n_first_frames=%d and "
|
|
||||||
"drop_n_last_frames=%d removes all frames. Skipping.",
|
|
||||||
episode_idx,
|
|
||||||
ep_length,
|
|
||||||
drop_n_first_frames,
|
|
||||||
drop_n_last_frames,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames))
|
|
||||||
|
|
||||||
if not indices:
|
used = np.ones(len(from_indices), dtype=bool)
|
||||||
|
if episode_indices_to_use is not None:
|
||||||
|
used = np.zeros(len(from_indices), dtype=bool)
|
||||||
|
used[np.asarray(episode_indices_to_use, dtype=np.int64)] = True
|
||||||
|
|
||||||
|
starts = from_indices + drop_n_first_frames
|
||||||
|
lengths = to_indices - drop_n_last_frames - starts
|
||||||
|
for episode_idx in np.flatnonzero(used & (lengths <= 0)):
|
||||||
|
logger.warning(
|
||||||
|
"Episode %d has %d frames but drop_n_first_frames=%d and "
|
||||||
|
"drop_n_last_frames=%d removes all frames. Skipping.",
|
||||||
|
episode_idx,
|
||||||
|
to_indices[episode_idx] - from_indices[episode_idx],
|
||||||
|
drop_n_first_frames,
|
||||||
|
drop_n_last_frames,
|
||||||
|
)
|
||||||
|
used &= lengths > 0
|
||||||
|
if not used.any():
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. "
|
"No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. "
|
||||||
"All episodes were either filtered out or had too few frames."
|
"All episodes were either filtered out or had too few frames."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.indices = indices
|
self._starts = starts[used]
|
||||||
|
self._cum_lengths = np.cumsum(lengths[used])
|
||||||
|
self._num_frames = int(self._cum_lengths[-1])
|
||||||
self.shuffle = shuffle
|
self.shuffle = shuffle
|
||||||
self.generator = generator
|
self.seed = seed
|
||||||
|
self._epoch = 0
|
||||||
|
self._start_index = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def indices(self) -> list[int]:
|
||||||
|
"""Materialized frame indices in unshuffled order; O(num_frames), introspection only."""
|
||||||
|
return [self._frame_index(k) for k in range(self._num_frames)]
|
||||||
|
|
||||||
|
def set_epoch(self, epoch: int) -> None:
|
||||||
|
self._epoch = epoch
|
||||||
|
|
||||||
|
def state_dict(self) -> dict:
|
||||||
|
return {"epoch": self._epoch, "start_index": self._start_index}
|
||||||
|
|
||||||
|
def load_state_dict(self, state: dict) -> None:
|
||||||
|
self._epoch = state["epoch"]
|
||||||
|
self._start_index = state["start_index"]
|
||||||
|
|
||||||
|
def _epoch_generator(self, epoch: int) -> torch.Generator:
|
||||||
|
# Derive a per-epoch seed from (seed, epoch) so the permutation is a pure function of both
|
||||||
|
# and reproduces identically on every rank without touching the global RNG.
|
||||||
|
epoch_seed = int(np.random.SeedSequence([self.seed, epoch]).generate_state(1, dtype=np.uint64)[0])
|
||||||
|
return torch.Generator().manual_seed(epoch_seed)
|
||||||
|
|
||||||
|
def _frame_index(self, position: int) -> int:
|
||||||
|
episode = int(np.searchsorted(self._cum_lengths, position, side="right"))
|
||||||
|
position_in_episode = position - (int(self._cum_lengths[episode - 1]) if episode > 0 else 0)
|
||||||
|
return int(self._starts[episode]) + position_in_episode
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[int]:
|
def __iter__(self) -> Iterator[int]:
|
||||||
|
# Advance epoch state eagerly, not on first consumption of the generator.
|
||||||
|
epoch, start = self._epoch, self._start_index
|
||||||
|
self._epoch += 1
|
||||||
|
self._start_index = 0
|
||||||
|
return self._iter_epoch(epoch, start)
|
||||||
|
|
||||||
|
def _iter_epoch(self, epoch: int, start: int) -> Iterator[int]:
|
||||||
if self.shuffle:
|
if self.shuffle:
|
||||||
for i in torch.randperm(len(self.indices), generator=self.generator):
|
order = torch.randperm(self._num_frames, generator=self._epoch_generator(epoch))
|
||||||
yield self.indices[i]
|
for k in range(start, self._num_frames):
|
||||||
|
yield self._frame_index(int(order[k]))
|
||||||
else:
|
else:
|
||||||
for i in self.indices:
|
for k in range(start, self._num_frames):
|
||||||
yield i
|
yield self._frame_index(k)
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return len(self.indices)
|
return self._num_frames
|
||||||
|
|
||||||
|
|
||||||
|
def compute_sampler_state(step: int, num_frames: int, batch_size: int, num_processes: int) -> dict:
|
||||||
|
"""Map an optimization step to an `EpisodeAwareSampler` state for sample-exact resume.
|
||||||
|
|
||||||
|
Under accelerate's batch sharding, one step consumes `batch_size * num_processes` sampler
|
||||||
|
positions and each rank sees `ceil(ceil(num_frames / batch_size) / num_processes)` batches
|
||||||
|
per epoch (`even_batches` padding included). The start index provably stays below
|
||||||
|
`num_frames`; the `min` is defensive.
|
||||||
|
|
||||||
|
Assumptions (resume is only sample-exact when they hold):
|
||||||
|
- `num_processes` and `batch_size` match the run that wrote the checkpoint. Both scale how
|
||||||
|
many positions a step consumes, so the epoch/offset are wrong if either changed. The
|
||||||
|
caller passes the checkpoint's `num_processes` and `batch_size` and warns on a mismatch.
|
||||||
|
- accelerate uses `even_batches=True` (its default). The `ceil(... / num_processes)` term
|
||||||
|
mirrors that padding; with `even_batches=False` the per-epoch batch count differs and
|
||||||
|
the boundary is off.
|
||||||
|
"""
|
||||||
|
batches_per_epoch = math.ceil(math.ceil(num_frames / batch_size) / num_processes)
|
||||||
|
epoch, batches_into_epoch = divmod(step, batches_per_epoch)
|
||||||
|
start_index = min(batches_into_epoch * batch_size * num_processes, num_frames)
|
||||||
|
return {"epoch": epoch, "start_index": start_index}
|
||||||
|
|||||||
@@ -36,6 +36,8 @@ from tqdm import tqdm
|
|||||||
from lerobot.common.train_utils import (
|
from lerobot.common.train_utils import (
|
||||||
get_step_checkpoint_dir,
|
get_step_checkpoint_dir,
|
||||||
get_step_identifier,
|
get_step_identifier,
|
||||||
|
load_training_batch_size,
|
||||||
|
load_training_num_processes,
|
||||||
load_training_state,
|
load_training_state,
|
||||||
save_checkpoint,
|
save_checkpoint,
|
||||||
update_last_checkpoint,
|
update_last_checkpoint,
|
||||||
@@ -43,7 +45,7 @@ from lerobot.common.train_utils import (
|
|||||||
from lerobot.common.wandb_utils import WandBLogger
|
from lerobot.common.wandb_utils import WandBLogger
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from lerobot.datasets import EpisodeAwareSampler, make_dataset
|
from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state, make_dataset
|
||||||
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
|
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
|
||||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||||
@@ -232,18 +234,17 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
|
||||||
# Dataset loading synchronization: each node's local main process downloads first to avoid
|
# Dataset loading synchronization: the global main process downloads once to the shared
|
||||||
# race conditions (the global main process only exists on node 0, so gating on it would let
|
# dataset root, then a barrier lets every other rank read the already-populated copy.
|
||||||
# all ranks of the other nodes download and build the Arrow cache concurrently).
|
# LeRobotDataset skips its snapshot_download when try_load() succeeds, so no rank re-downloads.
|
||||||
if accelerator.is_local_main_process:
|
if is_main_process:
|
||||||
if is_main_process:
|
logging.info("Creating dataset")
|
||||||
logging.info("Creating dataset")
|
|
||||||
dataset = make_dataset(cfg)
|
dataset = make_dataset(cfg)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
# Now all other processes can safely load the dataset from the local cache
|
# Other ranks read from the shared copy populated by the main process.
|
||||||
if not accelerator.is_local_main_process:
|
if not is_main_process:
|
||||||
dataset = make_dataset(cfg)
|
dataset = make_dataset(cfg)
|
||||||
|
|
||||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||||
@@ -387,22 +388,47 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||||
|
|
||||||
# create dataloader for offline training
|
# create dataloader for offline training
|
||||||
if hasattr(active_cfg, "drop_n_last_frames"):
|
if not cfg.dataset.streaming:
|
||||||
|
# All non-streaming (map-style) datasets use EpisodeAwareSampler.
|
||||||
|
# The order is a pure function of (seed, epoch), so every rank independently produces the
|
||||||
|
# same permutation. accelerate then shards it disjointly across ranks via BatchSamplerShard
|
||||||
|
# without needing a `generator` attribute to synchronize an RNG, and resume is sample-exact.
|
||||||
shuffle = False
|
shuffle = False
|
||||||
# A dedicated generator (rather than the global torch RNG) lets accelerator.prepare
|
|
||||||
# synchronize the shuffle permutation across ranks, keeping batch shards disjoint even
|
|
||||||
# when ranks consume the global RNG asymmetrically (e.g. eval on the main process only).
|
|
||||||
sampler_generator = torch.Generator()
|
|
||||||
if cfg.seed is not None:
|
|
||||||
sampler_generator.manual_seed(cfg.seed)
|
|
||||||
sampler = EpisodeAwareSampler(
|
sampler = EpisodeAwareSampler(
|
||||||
dataset.meta.episodes["dataset_from_index"],
|
dataset.meta.episodes["dataset_from_index"],
|
||||||
dataset.meta.episodes["dataset_to_index"],
|
dataset.meta.episodes["dataset_to_index"],
|
||||||
episode_indices_to_use=dataset.episodes,
|
episode_indices_to_use=dataset.episodes,
|
||||||
drop_n_last_frames=active_cfg.drop_n_last_frames,
|
drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0),
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
generator=sampler_generator,
|
seed=cfg.seed if cfg.seed is not None else 0,
|
||||||
)
|
)
|
||||||
|
if cfg.resume and step > 0:
|
||||||
|
# The resume offset depends on the (num_processes, batch_size) that produced `step`, so
|
||||||
|
# use the values recorded in the checkpoint (falling back to the current ones for older
|
||||||
|
# ckpts that did not store them).
|
||||||
|
saved_num_processes = load_training_num_processes(cfg.checkpoint_path)
|
||||||
|
saved_batch_size = load_training_batch_size(cfg.checkpoint_path)
|
||||||
|
ckpt_num_processes = saved_num_processes or accelerator.num_processes
|
||||||
|
ckpt_batch_size = saved_batch_size or cfg.batch_size
|
||||||
|
if is_main_process and saved_num_processes not in (None, accelerator.num_processes):
|
||||||
|
logging.warning(
|
||||||
|
f"Resuming with num_processes={accelerator.num_processes} but the checkpoint was "
|
||||||
|
f"written with num_processes={saved_num_processes}. The data order resumes at the "
|
||||||
|
"right epoch/offset, but per-rank sample-exactness requires the same world size."
|
||||||
|
)
|
||||||
|
if is_main_process and saved_batch_size not in (None, cfg.batch_size):
|
||||||
|
logging.warning(
|
||||||
|
f"Resuming with batch_size={cfg.batch_size} but the checkpoint was written with "
|
||||||
|
f"batch_size={saved_batch_size}. The data order resumes at the right epoch/offset, "
|
||||||
|
"but per-rank sample-exactness requires the same batch size."
|
||||||
|
)
|
||||||
|
sampler_state = compute_sampler_state(step, len(sampler), ckpt_batch_size, ckpt_num_processes)
|
||||||
|
sampler.load_state_dict(sampler_state)
|
||||||
|
if is_main_process:
|
||||||
|
logging.info(
|
||||||
|
f"Resuming data order at epoch {sampler_state['epoch']}, "
|
||||||
|
f"sample {sampler_state['start_index']}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
shuffle = True
|
shuffle = True
|
||||||
sampler = None
|
sampler = None
|
||||||
@@ -521,6 +547,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
scheduler=lr_scheduler,
|
scheduler=lr_scheduler,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
postprocessor=postprocessor,
|
postprocessor=postprocessor,
|
||||||
|
num_processes=accelerator.num_processes,
|
||||||
|
batch_size=cfg.batch_size,
|
||||||
)
|
)
|
||||||
update_last_checkpoint(checkpoint_dir)
|
update_last_checkpoint(checkpoint_dir)
|
||||||
if wandb_logger:
|
if wandb_logger:
|
||||||
|
|||||||
@@ -114,28 +114,17 @@ def test_shuffle():
|
|||||||
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
||||||
|
|
||||||
|
|
||||||
def test_shuffle_with_generator_is_deterministic():
|
def test_shuffle_is_reproducible_across_instances():
|
||||||
# Two samplers shuffling with same-seed generators must yield identical permutations.
|
# The order is a pure function of (seed, epoch), so two fresh samplers (e.g. two ranks)
|
||||||
# This is what keeps batch shards disjoint across ranks in distributed training, where
|
# produce the same permutation without any generator synchronization.
|
||||||
# accelerate synchronizes the sampler's generator state instead of the global torch RNG.
|
sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, seed=42)
|
||||||
sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
|
sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, seed=42)
|
||||||
sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
|
epoch_0 = list(sampler_a)
|
||||||
assert list(sampler_a) == list(sampler_b)
|
assert list(sampler_b) == epoch_0
|
||||||
|
|
||||||
# Desyncing the global RNG must not affect the permutation.
|
# Desyncing the global RNG must not affect the permutation.
|
||||||
sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
|
sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, seed=42)
|
||||||
order_before = list(sampler_c)
|
|
||||||
sampler_c.generator.manual_seed(42)
|
|
||||||
torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would
|
torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would
|
||||||
assert list(sampler_c) == order_before
|
assert list(sampler_c) == epoch_0
|
||||||
|
|
||||||
|
|
||||||
def test_generator_attribute_defaults_to_none():
|
|
||||||
# accelerate detects synchronizable samplers via `hasattr(sampler, "generator")`,
|
|
||||||
# so the attribute must exist even when no generator is passed.
|
|
||||||
sampler = EpisodeAwareSampler([0], [6], shuffle=True)
|
|
||||||
assert sampler.generator is None
|
|
||||||
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
|
||||||
|
|
||||||
|
|
||||||
def test_negative_drop_first_frames_raises():
|
def test_negative_drop_first_frames_raises():
|
||||||
@@ -161,3 +150,87 @@ def test_partial_episode_drop_warns(caplog):
|
|||||||
# Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5
|
# Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5
|
||||||
assert sampler.indices == [2, 3, 4, 5]
|
assert sampler.indices == [2, 3, 4, 5]
|
||||||
assert "Episode 0" in caplog.text
|
assert "Episode 0" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
# --- seeded (seed, epoch) shuffling, resume, and state ---
|
||||||
|
|
||||||
|
from lerobot.datasets.sampler import compute_sampler_state # noqa: E402
|
||||||
|
|
||||||
|
EPISODE_BOUNDS = ([0, 2, 3], [2, 3, 6]) # episodes of 2, 1 and 3 frames
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_frames", [1, 2, 3, 37, 64, 100])
|
||||||
|
def test_deterministic_sampler_shuffle_is_permutation(num_frames):
|
||||||
|
for seed in (0, 1, 1234):
|
||||||
|
sampler = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=seed)
|
||||||
|
assert sorted(sampler) == list(range(num_frames))
|
||||||
|
|
||||||
|
|
||||||
|
def test_deterministic_sampler_epochs_reproduce_and_differ():
|
||||||
|
sampler_a = EpisodeAwareSampler([0], [100], shuffle=True, seed=42)
|
||||||
|
sampler_b = EpisodeAwareSampler([0], [100], shuffle=True, seed=42)
|
||||||
|
epoch_0 = list(sampler_a)
|
||||||
|
assert list(sampler_b) == epoch_0 # same (seed, epoch) -> same order on any process
|
||||||
|
epoch_1 = list(sampler_a) # __iter__ auto-advances the epoch
|
||||||
|
assert epoch_1 != epoch_0
|
||||||
|
assert sorted(epoch_1) == sorted(epoch_0)
|
||||||
|
sampler_a.set_epoch(0)
|
||||||
|
assert list(sampler_a) == epoch_0
|
||||||
|
assert list(EpisodeAwareSampler([0], [100], shuffle=True, seed=7)) != epoch_0
|
||||||
|
|
||||||
|
|
||||||
|
def test_deterministic_sampler_resume_mid_epoch():
|
||||||
|
reference = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
|
||||||
|
epoch_0 = list(reference)
|
||||||
|
epoch_1 = list(reference)
|
||||||
|
for start in (0, 1, 4, len(epoch_0)):
|
||||||
|
resumed = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
|
||||||
|
resumed.load_state_dict({"epoch": 0, "start_index": start})
|
||||||
|
assert list(resumed) == epoch_0[start:]
|
||||||
|
# the resumed sampler continues into the same epoch 1 as the uninterrupted one
|
||||||
|
assert list(resumed) == epoch_1
|
||||||
|
|
||||||
|
|
||||||
|
def test_deterministic_sampler_construction_stores_only_boundaries():
|
||||||
|
# Construction is O(num_episodes), not O(num_frames): a million-frame single episode
|
||||||
|
# instantiates from just its boundaries without materializing a per-frame index list.
|
||||||
|
num_frames = 1_000_000
|
||||||
|
sampler = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
|
||||||
|
assert len(sampler) == num_frames
|
||||||
|
assert sampler._starts.shape == (1,) and sampler._cum_lengths.shape == (1,)
|
||||||
|
|
||||||
|
|
||||||
|
def test_deterministic_sampler_resume_is_exact_at_scale():
|
||||||
|
# Seeded randperm makes resume sample-exact at non-trivial sizes: regenerating the epoch's
|
||||||
|
# permutation and slicing from the saved offset reproduces the remaining order exactly.
|
||||||
|
num_frames = 100_000
|
||||||
|
reference = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
|
||||||
|
epoch_0 = list(reference)
|
||||||
|
assert sorted(epoch_0) == list(range(num_frames))
|
||||||
|
start = num_frames - 5
|
||||||
|
resumed = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
|
||||||
|
resumed.load_state_dict({"epoch": 0, "start_index": start})
|
||||||
|
assert list(resumed) == epoch_0[start:]
|
||||||
|
|
||||||
|
|
||||||
|
def test_compute_sampler_state():
|
||||||
|
# 100 frames, batch 10, 2 ranks -> 10 underlying batches, 5 per rank per epoch.
|
||||||
|
assert compute_sampler_state(step=0, num_frames=100, batch_size=10, num_processes=2) == {
|
||||||
|
"epoch": 0,
|
||||||
|
"start_index": 0,
|
||||||
|
}
|
||||||
|
# step 7 -> epoch 1, 2 per-rank batches in = 2 * 10 * 2 = 40 samples in
|
||||||
|
assert compute_sampler_state(step=7, num_frames=100, batch_size=10, num_processes=2) == {
|
||||||
|
"epoch": 1,
|
||||||
|
"start_index": 40,
|
||||||
|
}
|
||||||
|
# uneven epoch: 95 frames -> 10 underlying batches (last short), still 5 per rank
|
||||||
|
assert compute_sampler_state(step=12, num_frames=95, batch_size=10, num_processes=2) == {
|
||||||
|
"epoch": 2,
|
||||||
|
"start_index": 40,
|
||||||
|
}
|
||||||
|
# uneven sharding: 105 frames -> 11 underlying batches, 6 per rank (even_batches pads)
|
||||||
|
assert compute_sampler_state(step=11, num_frames=105, batch_size=10, num_processes=2) == {
|
||||||
|
"epoch": 1,
|
||||||
|
"start_index": 100,
|
||||||
|
}
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ from unittest.mock import Mock, patch
|
|||||||
from lerobot.common.train_utils import (
|
from lerobot.common.train_utils import (
|
||||||
get_step_checkpoint_dir,
|
get_step_checkpoint_dir,
|
||||||
get_step_identifier,
|
get_step_identifier,
|
||||||
|
load_training_batch_size,
|
||||||
|
load_training_num_processes,
|
||||||
load_training_state,
|
load_training_state,
|
||||||
load_training_step,
|
load_training_step,
|
||||||
save_checkpoint,
|
save_checkpoint,
|
||||||
@@ -63,6 +65,28 @@ def test_load_training_step(tmp_path):
|
|||||||
assert loaded_step == step
|
assert loaded_step == step
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_training_state_records_num_processes(tmp_path, optimizer, scheduler):
|
||||||
|
save_training_state(tmp_path, 10, optimizer, scheduler, num_processes=4)
|
||||||
|
assert load_training_num_processes(tmp_path) == 4
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_training_num_processes_absent_returns_none(tmp_path, optimizer, scheduler):
|
||||||
|
# Checkpoints written before the world size was recorded must still load (back-compat).
|
||||||
|
save_training_state(tmp_path, 10, optimizer, scheduler)
|
||||||
|
assert load_training_num_processes(tmp_path) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_training_state_records_batch_size(tmp_path, optimizer, scheduler):
|
||||||
|
save_training_state(tmp_path, 10, optimizer, scheduler, batch_size=32)
|
||||||
|
assert load_training_batch_size(tmp_path) == 32
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_training_batch_size_absent_returns_none(tmp_path, optimizer, scheduler):
|
||||||
|
# Checkpoints written before the batch size was recorded must still load (back-compat).
|
||||||
|
save_training_state(tmp_path, 10, optimizer, scheduler)
|
||||||
|
assert load_training_batch_size(tmp_path) is None
|
||||||
|
|
||||||
|
|
||||||
def test_update_last_checkpoint(tmp_path):
|
def test_update_last_checkpoint(tmp_path):
|
||||||
checkpoint = tmp_path / "0005"
|
checkpoint = tmp_path / "0005"
|
||||||
checkpoint.mkdir()
|
checkpoint.mkdir()
|
||||||
|
|||||||
Reference in New Issue
Block a user