diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py index b05e96fde..be906edbd 100644 --- a/src/lerobot/configs/default.py +++ b/src/lerobot/configs/default.py @@ -35,6 +35,9 @@ class DatasetConfig: revision: str | None = None use_imagenet_stats: bool = True video_backend: str = field(default_factory=get_safe_default_codec) + # When True, video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0). + # This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion. + return_uint8: bool = False streaming: bool = False def __post_init__(self) -> None: diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index d754a0847..924bcf5bb 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -56,6 +56,8 @@ class TrainPipelineConfig(HubMixin): # Number of workers for the dataloader. num_workers: int = 4 batch_size: int = 8 + prefetch_factor: int = 4 + persistent_workers: bool = True steps: int = 100_000 eval_freq: int = 20_000 log_freq: int = 200 diff --git a/src/lerobot/datasets/dataset_reader.py b/src/lerobot/datasets/dataset_reader.py index 718b33b12..bd1298590 100644 --- a/src/lerobot/datasets/dataset_reader.py +++ b/src/lerobot/datasets/dataset_reader.py @@ -16,6 +16,7 @@ """Private reader component for LeRobotDataset. Handles random-access reading (HF dataset, delta indices, video decoding).""" from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor from pathlib import Path import datasets @@ -49,6 +50,7 @@ class DatasetReader: video_backend: str, delta_timestamps: dict[str, list[float]] | None, image_transforms: Callable | None, + return_uint8: bool = False, ): """Initialize the reader with metadata, filtering, and transform config. 
@@ -73,6 +75,7 @@ class DatasetReader: self._tolerance_s = tolerance_s self._video_backend = video_backend self._image_transforms = image_transforms + self._return_uint8 = return_uint8 self.hf_dataset: datasets.Dataset | None = None self._absolute_to_relative_idx: dict[int, int] | None = None @@ -233,16 +236,31 @@ class DatasetReader: Segmentation Fault. """ ep = self._meta.episodes[ep_idx] - item = {} - for vid_key, query_ts in query_timestamps.items(): + + def _decode_single(vid_key: str, query_ts: list[float]) -> tuple[str, torch.Tensor]: from_timestamp = ep[f"videos/{vid_key}/from_timestamp"] shifted_query_ts = [from_timestamp + ts for ts in query_ts] - video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key) - frames = decode_video_frames(video_path, shifted_query_ts, self._tolerance_s, self._video_backend) - item[vid_key] = frames.squeeze(0) + video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key) + frames = decode_video_frames( + video_path, + shifted_query_ts, + self._tolerance_s, + self._video_backend, + return_uint8=self._return_uint8, + ) + return vid_key, frames.squeeze(0) - return item + items = list(query_timestamps.items()) + + # Single camera: no threading overhead + if len(items) <= 1: + return {vid_key: _decode_single(vid_key, query_ts)[1] for vid_key, query_ts in items} + + # Multi-camera: decode in parallel (video decoding releases the GIL) + with ThreadPoolExecutor(max_workers=len(items)) as pool: + futures = [pool.submit(_decode_single, k, ts) for k, ts in items] + return dict(f.result() for f in futures) def get_item(self, idx) -> dict: """Core __getitem__ logic. Assumes hf_dataset is loaded. 
diff --git a/src/lerobot/datasets/factory.py b/src/lerobot/datasets/factory.py index 040cba5cb..73df3f04b 100644 --- a/src/lerobot/datasets/factory.py +++ b/src/lerobot/datasets/factory.py @@ -92,6 +92,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas image_transforms=image_transforms, revision=cfg.dataset.revision, video_backend=cfg.dataset.video_backend, + return_uint8=cfg.dataset.return_uint8, tolerance_s=cfg.tolerance_s, ) else: @@ -104,6 +105,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas revision=cfg.dataset.revision, max_num_shards=cfg.num_workers, tolerance_s=cfg.tolerance_s, + return_uint8=cfg.dataset.return_uint8, ) else: raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.") diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 7cda5d677..644ce14db 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -56,6 +56,7 @@ class LeRobotDataset(torch.utils.data.Dataset): force_cache_sync: bool = False, download_videos: bool = True, video_backend: str | None = None, + return_uint8: bool = False, batch_encoding_size: int = 1, vcodec: str = "libsvtav1", streaming_encoding: bool = False, @@ -202,6 +203,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.tolerance_s = tolerance_s self.revision = revision if revision else CODEBASE_VERSION self._video_backend = video_backend if video_backend else get_safe_default_codec() + self._return_uint8 = return_uint8 self._batch_encoding_size = batch_encoding_size self._vcodec = resolve_vcodec(vcodec) self._encoder_threads = encoder_threads @@ -225,6 +227,7 @@ class LeRobotDataset(torch.utils.data.Dataset): video_backend=self._video_backend, delta_timestamps=delta_timestamps, image_transforms=image_transforms, + return_uint8=self._return_uint8, ) # Load actual data @@ -288,6 +291,7 @@ class LeRobotDataset(torch.utils.data.Dataset): video_backend=self._video_backend, 
delta_timestamps=self.delta_timestamps, image_transforms=self.image_transforms, + return_uint8=self._return_uint8, ) return self.reader @@ -683,6 +687,7 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.delta_timestamps = None obj.episodes = None obj._video_backend = video_backend if video_backend is not None else get_safe_default_codec() + obj._return_uint8 = False obj._batch_encoding_size = batch_encoding_size obj._vcodec = vcodec obj._encoder_threads = encoder_threads @@ -775,6 +780,7 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.delta_timestamps = None obj.episodes = None obj._video_backend = video_backend if video_backend else get_safe_default_codec() + obj._return_uint8 = False obj._batch_encoding_size = batch_encoding_size obj._vcodec = vcodec obj._encoder_threads = encoder_threads diff --git a/src/lerobot/datasets/streaming_dataset.py b/src/lerobot/datasets/streaming_dataset.py index f47d71367..4de2ed69c 100644 --- a/src/lerobot/datasets/streaming_dataset.py +++ b/src/lerobot/datasets/streaming_dataset.py @@ -251,6 +251,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): seed: int = 42, rng: np.random.Generator | None = None, shuffle: bool = True, + return_uint8: bool = False, ): """Initialize a StreamingLeRobotDataset. 
@@ -288,6 +289,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): self.streaming = streaming self.buffer_size = buffer_size + self._return_uint8 = return_uint8 # We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown) self.video_decoder_cache = None @@ -553,7 +555,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}" frames = decode_video_frames_torchcodec( - video_path, query_ts, self.tolerance_s, decoder_cache=self.video_decoder_cache + video_path, + query_ts, + self.tolerance_s, + decoder_cache=self.video_decoder_cache, + return_uint8=self._return_uint8, ) item[video_key] = frames.squeeze(0) if len(query_ts) == 1 else frames diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index cabe592d0..158e68cdb 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -123,6 +123,7 @@ def decode_video_frames( timestamps: list[float], tolerance_s: float, backend: str | None = None, + return_uint8: bool = False, ) -> torch.Tensor: """ Decodes video frames using the specified backend. @@ -131,19 +132,23 @@ def decode_video_frames( video_path (Path): Path to the video file. timestamps (list[float]): List of timestamps to extract frames. tolerance_s (float): Allowed deviation in seconds for frame retrieval. - backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav".. + backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav". + return_uint8 (bool): If True, return raw uint8 frames without float32 normalization. + This reduces memory for DataLoader IPC; normalization can be done on GPU afterward. 
Returns: - torch.Tensor: Decoded frames. + torch.Tensor: Decoded frames (float32 in [0,1] by default, or uint8 if return_uint8=True). Currently supports torchcodec on cpu and pyav. """ if backend is None: backend = get_safe_default_codec() if backend == "torchcodec": - return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s) + return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8) elif backend in ["pyav", "video_reader"]: - return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend) + return decode_video_frames_torchvision( + video_path, timestamps, tolerance_s, backend, return_uint8=return_uint8 + ) else: raise ValueError(f"Unsupported video backend: {backend}") @@ -154,6 +159,7 @@ def decode_video_frames_torchvision( tolerance_s: float, backend: str = "pyav", log_loaded_timestamps: bool = False, + return_uint8: bool = False, ) -> torch.Tensor: """Loads frames associated to the requested timestamps of a video @@ -240,14 +246,17 @@ def decode_video_frames_torchvision( if log_loaded_timestamps: logger.info(f"{closest_ts=}") - # convert to the pytorch format which is float32 in [0,1] range (and channel first) - closest_frames = closest_frames.type(torch.float32) / 255 - if len(timestamps) != len(closest_frames): raise FrameTimestampError( f"Number of retrieved frames ({len(closest_frames)}) does not match " f"number of queried timestamps ({len(timestamps)})" ) + + if return_uint8: + return closest_frames + + # convert to the pytorch format which is float32 in [0,1] range (and channel first) + closest_frames = closest_frames.type(torch.float32) / 255 return closest_frames @@ -306,6 +315,7 @@ def decode_video_frames_torchcodec( tolerance_s: float, log_loaded_timestamps: bool = False, decoder_cache: VideoDecoderCache | None = None, + return_uint8: bool = False, ) -> torch.Tensor: """Loads frames associated with the requested timestamps of a video using torchcodec. 
@@ -373,14 +383,16 @@ def decode_video_frames_torchcodec( if log_loaded_timestamps: logger.info(f"{closest_ts=}") - # convert to float32 in [0,1] range - closest_frames = (closest_frames / 255.0).type(torch.float32) - if not len(timestamps) == len(closest_frames): raise FrameTimestampError( f"Retrieved timestamps differ from queried {set(closest_frames) - set(timestamps)}" ) + if return_uint8: + return closest_frames + + # convert to float32 in [0,1] range + closest_frames = (closest_frames / 255.0).type(torch.float32) return closest_frames diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index a862c640d..856006507 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -386,7 +386,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): sampler=sampler, pin_memory=device.type == "cuda", drop_last=False, - prefetch_factor=2 if cfg.num_workers > 0 else None, + prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None, + persistent_workers=cfg.persistent_workers and cfg.num_workers > 0, ) # Prepare everything with accelerator @@ -433,6 +434,9 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): for _ in range(step, cfg.steps): start_time = time.perf_counter() batch = next(dl_iter) + for cam_key in dataset.meta.camera_keys: + if cam_key in batch and batch[cam_key].dtype == torch.uint8: + batch[cam_key] = batch[cam_key].to(dtype=torch.float32) / 255.0 batch = preprocessor(batch) train_tracker.dataloading_s = time.perf_counter() - start_time diff --git a/tests/artifacts/policies/save_policy_to_safetensors.py b/tests/artifacts/policies/save_policy_to_safetensors.py index ffb3efd03..158e3e0ef 100644 --- a/tests/artifacts/policies/save_policy_to_safetensors.py +++ b/tests/artifacts/policies/save_policy_to_safetensors.py @@ -52,6 +52,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): ) batch = 
next(iter(dataloader)) + for key in batch: + if isinstance(batch[key], torch.Tensor) and batch[key].dtype == torch.uint8: + batch[key] = batch[key].to(dtype=torch.float32) / 255.0 batch = preprocessor(batch) loss, output_dict = policy.forward(batch) @@ -82,6 +85,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): # indicating padding (those ending with "_is_pad") dataset.reader.delta_indices = None batch = next(iter(dataloader)) + for key in batch: + if isinstance(batch[key], torch.Tensor) and batch[key].dtype == torch.uint8: + batch[key] = batch[key].to(dtype=torch.float32) / 255.0 obs = {} for k in batch: # TODO: regenerate the safetensors diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index 2d50446fe..e9388b3ed 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -196,6 +196,8 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs): for key in batch: if isinstance(batch[key], torch.Tensor): + if batch[key].dtype == torch.uint8: + batch[key] = batch[key].to(dtype=torch.float32) / 255.0 batch[key] = batch[key].to(DEVICE, non_blocking=True) # Test updating the policy (and test that it does not mutate the batch)