From 971d7ea85b775e9933323f31a30afee5100ba153 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Fri, 12 Jun 2026 18:04:45 +0200 Subject: [PATCH] fix(is_depth): adding missing docstrings and is_depth arguments in video decoding functions Co-authored-by: Wensi (Vince) Ai <59036629+wensi-ai@users.noreply.github.com> --- src/lerobot/datasets/video_utils.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index 2664c7ca0..6607b27d5 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -70,28 +70,29 @@ def decode_video_frames( backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav". The legacy value "video_reader" is accepted for one release as an alias for "pyav" and will be removed in a future version. - return_uint8 (bool): If True, return raw uint8 frames without float32 normalization. + return_uint8 (bool): For RGB videos, if True return raw uint8 frames without float32 normalization. This reduces memory for DataLoader IPC; normalization can be done on GPU afterward. + is_depth (bool): Set to True if the video is a depth map (1 channel, uint12). Returns: - torch.Tensor: Decoded frames (float32 in [0,1] by default, or uint8 if return_uint8=True). + torch.Tensor: Decoded frames (RGB: float32 in [0,1] by default, or uint8 if return_uint8=True, Depth: uint12). Currently supports torchcodec on cpu and pyav. """ if backend != "pyav" and is_depth: logger.warning("Decoding depth maps is only supported with the 'pyav' backend.") # We do not actually return uint8 here, but we avoid the 255 normalization step. 
- return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=True, is_depth=True) + return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=False, is_depth=True) if backend is None: backend = get_safe_default_video_backend() if backend == "torchcodec": return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8) elif backend == "pyav": - return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8) + return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8, is_depth=is_depth) elif backend == "video_reader": logger.warning("backend='video_reader' is deprecated and now aliases to 'pyav'.") - return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8) + return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8, is_depth=is_depth) else: raise ValueError(f"Unsupported video backend: {backend}") @@ -121,8 +122,9 @@ def decode_video_frames_pyav( tolerance_s: Allowed deviation in seconds between a queried timestamp and the closest decoded frame. log_loaded_timestamps: When True, log every decoded frame's timestamp at INFO level. - return_uint8: When True, return raw uint8 frames (C, H, W). Otherwise, return float32 in - [0, 1] range. + return_uint8: For RGB videos, if True return raw uint8 frames (C, H, W). + Otherwise, return float32 in [0, 1] range. + is_depth: Set to True if the video is a depth map (1 channel, uint12). Returns: torch.Tensor of shape (len(timestamps), C, H, W). @@ -201,7 +203,7 @@ def decode_video_frames_pyav( f"number of queried timestamps ({len(timestamps)})" ) - if return_uint8: + if return_uint8 or is_depth: return closest_frames # convert to the pytorch format which is float32 in [0,1] range (and channel first)