mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
feat(dataset): 2x faster dataloader via parallel decode, uint8 transport, and persistent workers (#3406)
* feat(dataset): 2x faster dataloader * fix(dataset): streaming return uint8 decode * fix(tests): adjust normalization step comparison * fix(dataset): with threadexecutor + False default * chore(dataset): make it a config * fix(test): account for uint8 in training path testing
This commit is contained in:
@@ -35,6 +35,9 @@ class DatasetConfig:
|
||||
revision: str | None = None
|
||||
use_imagenet_stats: bool = True
|
||||
video_backend: str = field(default_factory=get_safe_default_codec)
|
||||
# When True, video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
|
||||
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
||||
return_uint8: bool = False
|
||||
streaming: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
|
||||
@@ -56,6 +56,8 @@ class TrainPipelineConfig(HubMixin):
|
||||
# Number of workers for the dataloader.
|
||||
num_workers: int = 4
|
||||
batch_size: int = 8
|
||||
prefetch_factor: int = 4
|
||||
persistent_workers: bool = True
|
||||
steps: int = 100_000
|
||||
eval_freq: int = 20_000
|
||||
log_freq: int = 200
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
"""Private reader component for LeRobotDataset. Handles random-access reading (HF dataset, delta indices, video decoding)."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
@@ -49,6 +50,7 @@ class DatasetReader:
|
||||
video_backend: str,
|
||||
delta_timestamps: dict[str, list[float]] | None,
|
||||
image_transforms: Callable | None,
|
||||
return_uint8: bool = False,
|
||||
):
|
||||
"""Initialize the reader with metadata, filtering, and transform config.
|
||||
|
||||
@@ -73,6 +75,7 @@ class DatasetReader:
|
||||
self._tolerance_s = tolerance_s
|
||||
self._video_backend = video_backend
|
||||
self._image_transforms = image_transforms
|
||||
self._return_uint8 = return_uint8
|
||||
|
||||
self.hf_dataset: datasets.Dataset | None = None
|
||||
self._absolute_to_relative_idx: dict[int, int] | None = None
|
||||
@@ -233,16 +236,30 @@ class DatasetReader:
|
||||
Segmentation Fault.
|
||||
"""
|
||||
ep = self._meta.episodes[ep_idx]
|
||||
item = {}
|
||||
for vid_key, query_ts in query_timestamps.items():
|
||||
|
||||
def _decode_single(vid_key: str, query_ts: list[float]) -> tuple[str, torch.Tensor]:
|
||||
from_timestamp = ep[f"videos/{vid_key}/from_timestamp"]
|
||||
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
|
||||
|
||||
video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
|
||||
frames = decode_video_frames(video_path, shifted_query_ts, self._tolerance_s, self._video_backend)
|
||||
item[vid_key] = frames.squeeze(0)
|
||||
frames = decode_video_frames(
|
||||
video_path,
|
||||
shifted_query_ts,
|
||||
self._tolerance_s,
|
||||
self._video_backend,
|
||||
return_uint8=self._return_uint8,
|
||||
)
|
||||
return vid_key, frames.squeeze(0)
|
||||
|
||||
return item
|
||||
items = list(query_timestamps.items())
|
||||
|
||||
# Single camera: no threading overhead
|
||||
if len(items) <= 1:
|
||||
return {vid_key: _decode_single(vid_key, query_ts)[1] for vid_key, query_ts in items}
|
||||
|
||||
# Multi-camera: decode in parallel (video decoding releases the GIL)
|
||||
with ThreadPoolExecutor(max_workers=len(items)) as pool:
|
||||
futures = [pool.submit(_decode_single, k, ts) for k, ts in items]
|
||||
return dict(f.result() for f in futures)
|
||||
|
||||
def get_item(self, idx) -> dict:
|
||||
"""Core __getitem__ logic. Assumes hf_dataset is loaded.
|
||||
|
||||
@@ -92,6 +92,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
return_uint8=True,
|
||||
tolerance_s=cfg.tolerance_s,
|
||||
)
|
||||
else:
|
||||
@@ -104,6 +105,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
revision=cfg.dataset.revision,
|
||||
max_num_shards=cfg.num_workers,
|
||||
tolerance_s=cfg.tolerance_s,
|
||||
return_uint8=True,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
||||
|
||||
@@ -56,6 +56,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
force_cache_sync: bool = False,
|
||||
download_videos: bool = True,
|
||||
video_backend: str | None = None,
|
||||
return_uint8: bool = False,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
streaming_encoding: bool = False,
|
||||
@@ -202,6 +203,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self._video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||
self._return_uint8 = return_uint8
|
||||
self._batch_encoding_size = batch_encoding_size
|
||||
self._vcodec = resolve_vcodec(vcodec)
|
||||
self._encoder_threads = encoder_threads
|
||||
@@ -225,6 +227,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend=self._video_backend,
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
return_uint8=self._return_uint8,
|
||||
)
|
||||
|
||||
# Load actual data
|
||||
@@ -288,6 +291,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend=self._video_backend,
|
||||
delta_timestamps=self.delta_timestamps,
|
||||
image_transforms=self.image_transforms,
|
||||
return_uint8=self._return_uint8,
|
||||
)
|
||||
return self.reader
|
||||
|
||||
@@ -683,6 +687,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.delta_timestamps = None
|
||||
obj.episodes = None
|
||||
obj._video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||
obj._return_uint8 = False
|
||||
obj._batch_encoding_size = batch_encoding_size
|
||||
obj._vcodec = vcodec
|
||||
obj._encoder_threads = encoder_threads
|
||||
@@ -775,6 +780,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.delta_timestamps = None
|
||||
obj.episodes = None
|
||||
obj._video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||
obj._return_uint8 = False
|
||||
obj._batch_encoding_size = batch_encoding_size
|
||||
obj._vcodec = vcodec
|
||||
obj._encoder_threads = encoder_threads
|
||||
|
||||
@@ -251,6 +251,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
seed: int = 42,
|
||||
rng: np.random.Generator | None = None,
|
||||
shuffle: bool = True,
|
||||
return_uint8: bool = False,
|
||||
):
|
||||
"""Initialize a StreamingLeRobotDataset.
|
||||
|
||||
@@ -288,6 +289,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
|
||||
self.streaming = streaming
|
||||
self.buffer_size = buffer_size
|
||||
self._return_uint8 = return_uint8
|
||||
|
||||
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
|
||||
self.video_decoder_cache = None
|
||||
@@ -553,7 +555,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root
|
||||
video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}"
|
||||
frames = decode_video_frames_torchcodec(
|
||||
video_path, query_ts, self.tolerance_s, decoder_cache=self.video_decoder_cache
|
||||
video_path,
|
||||
query_ts,
|
||||
self.tolerance_s,
|
||||
decoder_cache=self.video_decoder_cache,
|
||||
return_uint8=self._return_uint8,
|
||||
)
|
||||
|
||||
item[video_key] = frames.squeeze(0) if len(query_ts) == 1 else frames
|
||||
|
||||
@@ -123,6 +123,7 @@ def decode_video_frames(
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
backend: str | None = None,
|
||||
return_uint8: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Decodes video frames using the specified backend.
|
||||
@@ -131,19 +132,23 @@ def decode_video_frames(
|
||||
video_path (Path): Path to the video file.
|
||||
timestamps (list[float]): List of timestamps to extract frames.
|
||||
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
|
||||
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav"..
|
||||
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav".
|
||||
return_uint8 (bool): If True, return raw uint8 frames without float32 normalization.
|
||||
This reduces memory for DataLoader IPC; normalization can be done on GPU afterward.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Decoded frames.
|
||||
torch.Tensor: Decoded frames (float32 in [0,1] by default, or uint8 if return_uint8=True).
|
||||
|
||||
Currently supports torchcodec on cpu and pyav.
|
||||
"""
|
||||
if backend is None:
|
||||
backend = get_safe_default_codec()
|
||||
if backend == "torchcodec":
|
||||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
|
||||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||
elif backend in ["pyav", "video_reader"]:
|
||||
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||
return decode_video_frames_torchvision(
|
||||
video_path, timestamps, tolerance_s, backend, return_uint8=return_uint8
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported video backend: {backend}")
|
||||
|
||||
@@ -154,6 +159,7 @@ def decode_video_frames_torchvision(
|
||||
tolerance_s: float,
|
||||
backend: str = "pyav",
|
||||
log_loaded_timestamps: bool = False,
|
||||
return_uint8: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Loads frames associated to the requested timestamps of a video
|
||||
|
||||
@@ -240,14 +246,17 @@ def decode_video_frames_torchvision(
|
||||
if log_loaded_timestamps:
|
||||
logger.info(f"{closest_ts=}")
|
||||
|
||||
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
||||
closest_frames = closest_frames.type(torch.float32) / 255
|
||||
|
||||
if len(timestamps) != len(closest_frames):
|
||||
raise FrameTimestampError(
|
||||
f"Number of retrieved frames ({len(closest_frames)}) does not match "
|
||||
f"number of queried timestamps ({len(timestamps)})"
|
||||
)
|
||||
|
||||
if return_uint8:
|
||||
return closest_frames
|
||||
|
||||
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
||||
closest_frames = closest_frames.type(torch.float32) / 255
|
||||
return closest_frames
|
||||
|
||||
|
||||
@@ -306,6 +315,7 @@ def decode_video_frames_torchcodec(
|
||||
tolerance_s: float,
|
||||
log_loaded_timestamps: bool = False,
|
||||
decoder_cache: VideoDecoderCache | None = None,
|
||||
return_uint8: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Loads frames associated with the requested timestamps of a video using torchcodec.
|
||||
|
||||
@@ -373,14 +383,16 @@ def decode_video_frames_torchcodec(
|
||||
if log_loaded_timestamps:
|
||||
logger.info(f"{closest_ts=}")
|
||||
|
||||
# convert to float32 in [0,1] range
|
||||
closest_frames = (closest_frames / 255.0).type(torch.float32)
|
||||
|
||||
if not len(timestamps) == len(closest_frames):
|
||||
raise FrameTimestampError(
|
||||
f"Retrieved timestamps differ from queried {set(closest_frames) - set(timestamps)}"
|
||||
)
|
||||
|
||||
if return_uint8:
|
||||
return closest_frames
|
||||
|
||||
# convert to float32 in [0,1] range
|
||||
closest_frames = (closest_frames / 255.0).type(torch.float32)
|
||||
return closest_frames
|
||||
|
||||
|
||||
|
||||
@@ -386,7 +386,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
sampler=sampler,
|
||||
pin_memory=device.type == "cuda",
|
||||
drop_last=False,
|
||||
prefetch_factor=2 if cfg.num_workers > 0 else None,
|
||||
prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
|
||||
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
|
||||
)
|
||||
|
||||
# Prepare everything with accelerator
|
||||
@@ -433,6 +434,9 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
for _ in range(step, cfg.steps):
|
||||
start_time = time.perf_counter()
|
||||
batch = next(dl_iter)
|
||||
for cam_key in dataset.meta.camera_keys:
|
||||
if cam_key in batch and batch[cam_key].dtype == torch.uint8:
|
||||
batch[cam_key] = batch[cam_key].to(dtype=torch.float32) / 255.0
|
||||
batch = preprocessor(batch)
|
||||
train_tracker.dataloading_s = time.perf_counter() - start_time
|
||||
|
||||
|
||||
@@ -52,6 +52,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
)
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor) and batch[key].dtype == torch.uint8:
|
||||
batch[key] = batch[key].to(dtype=torch.float32) / 255.0
|
||||
batch = preprocessor(batch)
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
@@ -82,6 +85,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
# indicating padding (those ending with "_is_pad")
|
||||
dataset.reader.delta_indices = None
|
||||
batch = next(iter(dataloader))
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor) and batch[key].dtype == torch.uint8:
|
||||
batch[key] = batch[key].to(dtype=torch.float32) / 255.0
|
||||
obs = {}
|
||||
for k in batch:
|
||||
# TODO: regenerate the safetensors
|
||||
|
||||
@@ -196,6 +196,8 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
||||
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
if batch[key].dtype == torch.uint8:
|
||||
batch[key] = batch[key].to(dtype=torch.float32) / 255.0
|
||||
batch[key] = batch[key].to(DEVICE, non_blocking=True)
|
||||
|
||||
# Test updating the policy (and test that it does not mutate the batch)
|
||||
|
||||
Reference in New Issue
Block a user