feat(dataset): 2x faster dataloader via parallel decode, uint8 transport, and persistent workers (#3406)

* feat(dataset): 2xfaster dataloader

* fix(dataset): streaming return uint8 decode

* fix(tests): adjust normalization step comparison

* fix(dataset): with threadexecutor + False default

* chore(dataset): make it a config

* fix(test): account for uint8 in training path testing
This commit is contained in:
Steven Palma
2026-04-19 00:08:22 +02:00
committed by GitHub
parent 760220d532
commit a8b72d9615
10 changed files with 78 additions and 18 deletions
+3
View File
@@ -35,6 +35,9 @@ class DatasetConfig:
revision: str | None = None revision: str | None = None
use_imagenet_stats: bool = True use_imagenet_stats: bool = True
video_backend: str = field(default_factory=get_safe_default_codec) video_backend: str = field(default_factory=get_safe_default_codec)
# When True, video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
return_uint8: bool = False
streaming: bool = False streaming: bool = False
def __post_init__(self) -> None: def __post_init__(self) -> None:
+2
View File
@@ -56,6 +56,8 @@ class TrainPipelineConfig(HubMixin):
# Number of workers for the dataloader. # Number of workers for the dataloader.
num_workers: int = 4 num_workers: int = 4
batch_size: int = 8 batch_size: int = 8
prefetch_factor: int = 4
persistent_workers: bool = True
steps: int = 100_000 steps: int = 100_000
eval_freq: int = 20_000 eval_freq: int = 20_000
log_freq: int = 200 log_freq: int = 200
+23 -6
View File
@@ -16,6 +16,7 @@
"""Private reader component for LeRobotDataset. Handles random-access reading (HF dataset, delta indices, video decoding).""" """Private reader component for LeRobotDataset. Handles random-access reading (HF dataset, delta indices, video decoding)."""
from collections.abc import Callable from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path from pathlib import Path
import datasets import datasets
@@ -49,6 +50,7 @@ class DatasetReader:
video_backend: str, video_backend: str,
delta_timestamps: dict[str, list[float]] | None, delta_timestamps: dict[str, list[float]] | None,
image_transforms: Callable | None, image_transforms: Callable | None,
return_uint8: bool = False,
): ):
"""Initialize the reader with metadata, filtering, and transform config. """Initialize the reader with metadata, filtering, and transform config.
@@ -73,6 +75,7 @@ class DatasetReader:
self._tolerance_s = tolerance_s self._tolerance_s = tolerance_s
self._video_backend = video_backend self._video_backend = video_backend
self._image_transforms = image_transforms self._image_transforms = image_transforms
self._return_uint8 = return_uint8
self.hf_dataset: datasets.Dataset | None = None self.hf_dataset: datasets.Dataset | None = None
self._absolute_to_relative_idx: dict[int, int] | None = None self._absolute_to_relative_idx: dict[int, int] | None = None
@@ -233,16 +236,30 @@ class DatasetReader:
Segmentation Fault. Segmentation Fault.
""" """
ep = self._meta.episodes[ep_idx] ep = self._meta.episodes[ep_idx]
item = {}
for vid_key, query_ts in query_timestamps.items(): def _decode_single(vid_key: str, query_ts: list[float]) -> tuple[str, torch.Tensor]:
from_timestamp = ep[f"videos/{vid_key}/from_timestamp"] from_timestamp = ep[f"videos/{vid_key}/from_timestamp"]
shifted_query_ts = [from_timestamp + ts for ts in query_ts] shifted_query_ts = [from_timestamp + ts for ts in query_ts]
video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key) video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
frames = decode_video_frames(video_path, shifted_query_ts, self._tolerance_s, self._video_backend) frames = decode_video_frames(
item[vid_key] = frames.squeeze(0) video_path,
shifted_query_ts,
self._tolerance_s,
self._video_backend,
return_uint8=self._return_uint8,
)
return vid_key, frames.squeeze(0)
return item items = list(query_timestamps.items())
# Single camera: no threading overhead
if len(items) <= 1:
return {vid_key: _decode_single(vid_key, query_ts)[1] for vid_key, query_ts in items}
# Multi-camera: decode in parallel (video decoding releases the GIL)
with ThreadPoolExecutor(max_workers=len(items)) as pool:
futures = [pool.submit(_decode_single, k, ts) for k, ts in items]
return dict(f.result() for f in futures)
def get_item(self, idx) -> dict: def get_item(self, idx) -> dict:
"""Core __getitem__ logic. Assumes hf_dataset is loaded. """Core __getitem__ logic. Assumes hf_dataset is loaded.
+2
View File
@@ -92,6 +92,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
image_transforms=image_transforms, image_transforms=image_transforms,
revision=cfg.dataset.revision, revision=cfg.dataset.revision,
video_backend=cfg.dataset.video_backend, video_backend=cfg.dataset.video_backend,
return_uint8=True,
tolerance_s=cfg.tolerance_s, tolerance_s=cfg.tolerance_s,
) )
else: else:
@@ -104,6 +105,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
revision=cfg.dataset.revision, revision=cfg.dataset.revision,
max_num_shards=cfg.num_workers, max_num_shards=cfg.num_workers,
tolerance_s=cfg.tolerance_s, tolerance_s=cfg.tolerance_s,
return_uint8=True,
) )
else: else:
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.") raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
+6
View File
@@ -56,6 +56,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
force_cache_sync: bool = False, force_cache_sync: bool = False,
download_videos: bool = True, download_videos: bool = True,
video_backend: str | None = None, video_backend: str | None = None,
return_uint8: bool = False,
batch_encoding_size: int = 1, batch_encoding_size: int = 1,
vcodec: str = "libsvtav1", vcodec: str = "libsvtav1",
streaming_encoding: bool = False, streaming_encoding: bool = False,
@@ -202,6 +203,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.tolerance_s = tolerance_s self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION self.revision = revision if revision else CODEBASE_VERSION
self._video_backend = video_backend if video_backend else get_safe_default_codec() self._video_backend = video_backend if video_backend else get_safe_default_codec()
self._return_uint8 = return_uint8
self._batch_encoding_size = batch_encoding_size self._batch_encoding_size = batch_encoding_size
self._vcodec = resolve_vcodec(vcodec) self._vcodec = resolve_vcodec(vcodec)
self._encoder_threads = encoder_threads self._encoder_threads = encoder_threads
@@ -225,6 +227,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend=self._video_backend, video_backend=self._video_backend,
delta_timestamps=delta_timestamps, delta_timestamps=delta_timestamps,
image_transforms=image_transforms, image_transforms=image_transforms,
return_uint8=self._return_uint8,
) )
# Load actual data # Load actual data
@@ -288,6 +291,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend=self._video_backend, video_backend=self._video_backend,
delta_timestamps=self.delta_timestamps, delta_timestamps=self.delta_timestamps,
image_transforms=self.image_transforms, image_transforms=self.image_transforms,
return_uint8=self._return_uint8,
) )
return self.reader return self.reader
@@ -683,6 +687,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.delta_timestamps = None obj.delta_timestamps = None
obj.episodes = None obj.episodes = None
obj._video_backend = video_backend if video_backend is not None else get_safe_default_codec() obj._video_backend = video_backend if video_backend is not None else get_safe_default_codec()
obj._return_uint8 = False
obj._batch_encoding_size = batch_encoding_size obj._batch_encoding_size = batch_encoding_size
obj._vcodec = vcodec obj._vcodec = vcodec
obj._encoder_threads = encoder_threads obj._encoder_threads = encoder_threads
@@ -775,6 +780,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.delta_timestamps = None obj.delta_timestamps = None
obj.episodes = None obj.episodes = None
obj._video_backend = video_backend if video_backend else get_safe_default_codec() obj._video_backend = video_backend if video_backend else get_safe_default_codec()
obj._return_uint8 = False
obj._batch_encoding_size = batch_encoding_size obj._batch_encoding_size = batch_encoding_size
obj._vcodec = vcodec obj._vcodec = vcodec
obj._encoder_threads = encoder_threads obj._encoder_threads = encoder_threads
+7 -1
View File
@@ -251,6 +251,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
seed: int = 42, seed: int = 42,
rng: np.random.Generator | None = None, rng: np.random.Generator | None = None,
shuffle: bool = True, shuffle: bool = True,
return_uint8: bool = False,
): ):
"""Initialize a StreamingLeRobotDataset. """Initialize a StreamingLeRobotDataset.
@@ -288,6 +289,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
self.streaming = streaming self.streaming = streaming
self.buffer_size = buffer_size self.buffer_size = buffer_size
self._return_uint8 = return_uint8
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown) # We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
self.video_decoder_cache = None self.video_decoder_cache = None
@@ -553,7 +555,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root
video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}" video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}"
frames = decode_video_frames_torchcodec( frames = decode_video_frames_torchcodec(
video_path, query_ts, self.tolerance_s, decoder_cache=self.video_decoder_cache video_path,
query_ts,
self.tolerance_s,
decoder_cache=self.video_decoder_cache,
return_uint8=self._return_uint8,
) )
item[video_key] = frames.squeeze(0) if len(query_ts) == 1 else frames item[video_key] = frames.squeeze(0) if len(query_ts) == 1 else frames
+22 -10
View File
@@ -123,6 +123,7 @@ def decode_video_frames(
timestamps: list[float], timestamps: list[float],
tolerance_s: float, tolerance_s: float,
backend: str | None = None, backend: str | None = None,
return_uint8: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Decodes video frames using the specified backend. Decodes video frames using the specified backend.
@@ -131,19 +132,23 @@ def decode_video_frames(
video_path (Path): Path to the video file. video_path (Path): Path to the video file.
timestamps (list[float]): List of timestamps to extract frames. timestamps (list[float]): List of timestamps to extract frames.
tolerance_s (float): Allowed deviation in seconds for frame retrieval. tolerance_s (float): Allowed deviation in seconds for frame retrieval.
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav".. backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav".
return_uint8 (bool): If True, return raw uint8 frames without float32 normalization.
This reduces memory for DataLoader IPC; normalization can be done on GPU afterward.
Returns: Returns:
torch.Tensor: Decoded frames. torch.Tensor: Decoded frames (float32 in [0,1] by default, or uint8 if return_uint8=True).
Currently supports torchcodec on cpu and pyav. Currently supports torchcodec on cpu and pyav.
""" """
if backend is None: if backend is None:
backend = get_safe_default_codec() backend = get_safe_default_codec()
if backend == "torchcodec": if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s) return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
elif backend in ["pyav", "video_reader"]: elif backend in ["pyav", "video_reader"]:
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend) return decode_video_frames_torchvision(
video_path, timestamps, tolerance_s, backend, return_uint8=return_uint8
)
else: else:
raise ValueError(f"Unsupported video backend: {backend}") raise ValueError(f"Unsupported video backend: {backend}")
@@ -154,6 +159,7 @@ def decode_video_frames_torchvision(
tolerance_s: float, tolerance_s: float,
backend: str = "pyav", backend: str = "pyav",
log_loaded_timestamps: bool = False, log_loaded_timestamps: bool = False,
return_uint8: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
"""Loads frames associated to the requested timestamps of a video """Loads frames associated to the requested timestamps of a video
@@ -240,14 +246,17 @@ def decode_video_frames_torchvision(
if log_loaded_timestamps: if log_loaded_timestamps:
logger.info(f"{closest_ts=}") logger.info(f"{closest_ts=}")
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
closest_frames = closest_frames.type(torch.float32) / 255
if len(timestamps) != len(closest_frames): if len(timestamps) != len(closest_frames):
raise FrameTimestampError( raise FrameTimestampError(
f"Number of retrieved frames ({len(closest_frames)}) does not match " f"Number of retrieved frames ({len(closest_frames)}) does not match "
f"number of queried timestamps ({len(timestamps)})" f"number of queried timestamps ({len(timestamps)})"
) )
if return_uint8:
return closest_frames
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
closest_frames = closest_frames.type(torch.float32) / 255
return closest_frames return closest_frames
@@ -306,6 +315,7 @@ def decode_video_frames_torchcodec(
tolerance_s: float, tolerance_s: float,
log_loaded_timestamps: bool = False, log_loaded_timestamps: bool = False,
decoder_cache: VideoDecoderCache | None = None, decoder_cache: VideoDecoderCache | None = None,
return_uint8: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
"""Loads frames associated with the requested timestamps of a video using torchcodec. """Loads frames associated with the requested timestamps of a video using torchcodec.
@@ -373,14 +383,16 @@ def decode_video_frames_torchcodec(
if log_loaded_timestamps: if log_loaded_timestamps:
logger.info(f"{closest_ts=}") logger.info(f"{closest_ts=}")
# convert to float32 in [0,1] range
closest_frames = (closest_frames / 255.0).type(torch.float32)
if not len(timestamps) == len(closest_frames): if not len(timestamps) == len(closest_frames):
raise FrameTimestampError( raise FrameTimestampError(
f"Retrieved timestamps differ from queried {set(closest_frames) - set(timestamps)}" f"Retrieved timestamps differ from queried {set(closest_frames) - set(timestamps)}"
) )
if return_uint8:
return closest_frames
# convert to float32 in [0,1] range
closest_frames = (closest_frames / 255.0).type(torch.float32)
return closest_frames return closest_frames
+5 -1
View File
@@ -386,7 +386,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
sampler=sampler, sampler=sampler,
pin_memory=device.type == "cuda", pin_memory=device.type == "cuda",
drop_last=False, drop_last=False,
prefetch_factor=2 if cfg.num_workers > 0 else None, prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
) )
# Prepare everything with accelerator # Prepare everything with accelerator
@@ -433,6 +434,9 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
for _ in range(step, cfg.steps): for _ in range(step, cfg.steps):
start_time = time.perf_counter() start_time = time.perf_counter()
batch = next(dl_iter) batch = next(dl_iter)
for cam_key in dataset.meta.camera_keys:
if cam_key in batch and batch[cam_key].dtype == torch.uint8:
batch[cam_key] = batch[cam_key].to(dtype=torch.float32) / 255.0
batch = preprocessor(batch) batch = preprocessor(batch)
train_tracker.dataloading_s = time.perf_counter() - start_time train_tracker.dataloading_s = time.perf_counter() - start_time
@@ -52,6 +52,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
) )
batch = next(iter(dataloader)) batch = next(iter(dataloader))
for key in batch:
if isinstance(batch[key], torch.Tensor) and batch[key].dtype == torch.uint8:
batch[key] = batch[key].to(dtype=torch.float32) / 255.0
batch = preprocessor(batch) batch = preprocessor(batch)
loss, output_dict = policy.forward(batch) loss, output_dict = policy.forward(batch)
@@ -82,6 +85,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
# indicating padding (those ending with "_is_pad") # indicating padding (those ending with "_is_pad")
dataset.reader.delta_indices = None dataset.reader.delta_indices = None
batch = next(iter(dataloader)) batch = next(iter(dataloader))
for key in batch:
if isinstance(batch[key], torch.Tensor) and batch[key].dtype == torch.uint8:
batch[key] = batch[key].to(dtype=torch.float32) / 255.0
obs = {} obs = {}
for k in batch: for k in batch:
# TODO: regenerate the safetensors # TODO: regenerate the safetensors
+2
View File
@@ -196,6 +196,8 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
for key in batch: for key in batch:
if isinstance(batch[key], torch.Tensor): if isinstance(batch[key], torch.Tensor):
if batch[key].dtype == torch.uint8:
batch[key] = batch[key].to(dtype=torch.float32) / 255.0
batch[key] = batch[key].to(DEVICE, non_blocking=True) batch[key] = batch[key].to(DEVICE, non_blocking=True)
# Test updating the policy (and test that it does not mutate the batch) # Test updating the policy (and test that it does not mutate the batch)