From e51d45dd2c12e79ccfb523a7dddb747b272cc3ee Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Tue, 19 May 2026 23:46:28 +0200 Subject: [PATCH] feat(depth): wire DatasetReader to decode_depth_frames --- src/lerobot/datasets/dataset_reader.py | 18 ++++++++++++++++++ src/lerobot/datasets/video_utils.py | 19 +++++++++++++++---- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/src/lerobot/datasets/dataset_reader.py b/src/lerobot/datasets/dataset_reader.py index bd1298590..e6419c18e 100644 --- a/src/lerobot/datasets/dataset_reader.py +++ b/src/lerobot/datasets/dataset_reader.py @@ -23,6 +23,7 @@ import datasets import torch from .dataset_metadata import LeRobotDatasetMetadata +from .depth_utils import dequantize_depth from .feature_utils import ( check_delta_timestamps, get_delta_indices, @@ -86,6 +87,18 @@ class DatasetReader: check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s) self.delta_indices = get_delta_indices(delta_timestamps, meta.fps) + if self._meta.depth_keys: + # TODO(CarolinePascal): make this decent, this is awful. + self._dequantize_depth_configs = { + vid_key: { + "depth_min": self._meta.features[vid_key]["info"]["video.depth_min"], + "depth_max": self._meta.features[vid_key]["info"]["video.depth_max"], + "shift": self._meta.features[vid_key]["info"]["video.shift"], + "use_log": self._meta.features[vid_key]["info"]["video.use_log"], + } + for vid_key in self._meta.depth_keys + } + def try_load(self) -> bool: """Attempt to load from local cache. Returns True if data is sufficient.""" try: @@ -247,7 +260,12 @@ class DatasetReader: self._tolerance_s, self._video_backend, return_uint8=self._return_uint8, + is_depth=vid_key in self._meta.depth_keys, ) + if vid_key in self._meta.depth_keys: + frames = dequantize_depth( + frames, **self._dequantize_depth_configs[vid_key], output_tensor=True + ) return vid_key, frames.squeeze(0) items = list(query_timestamps.items()) diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index dd7cf7ee7..6330ae447 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -44,7 +44,7 @@ from lerobot.configs import ( ) from lerobot.utils.import_utils import get_safe_default_video_backend -from .depth_utils import quantize_depth +from .depth_utils import DEPTH_PIX_FMT, quantize_depth logger = logging.getLogger(__name__) @@ -55,6 +55,7 @@ def decode_video_frames( tolerance_s: float, backend: str | None = None, return_uint8: bool = False, + is_depth: bool = False, ) -> torch.Tensor: """ Decodes video frames using the specified backend. @@ -74,6 +75,11 @@ def decode_video_frames( Currently supports torchcodec on cpu and pyav. """ + if backend != "pyav" and is_depth: + logger.warning("Decoding depth maps is only supported with the 'pyav' backend.") + # We do not actually return uint8 here, but we avoid the 255 normalization step. + return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=True, is_depth=True) + if backend is None: backend = get_safe_default_video_backend() if backend == "torchcodec": @@ -93,6 +99,7 @@ def decode_video_frames_pyav( tolerance_s: float, log_loaded_timestamps: bool = False, return_uint8: bool = False, + is_depth: bool = False, ) -> torch.Tensor: """Loads frames associated to the requested timestamps of a video using PyAV. @@ -142,9 +149,13 @@ def decode_video_frames_pyav( current_ts = float(frame.pts * stream.time_base) if log_loaded_timestamps: logger.info(f"frame loaded at timestamp={current_ts:.4f}") - # Convert to CHW uint8 to match torchcodec's output layout. - arr = frame.to_ndarray(format="rgb24") # H, W, 3 - loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous()) + if is_depth: + arr = frame.to_ndarray(format=DEPTH_PIX_FMT) # (H, W) uint16 + loaded_frames.append(torch.from_numpy(arr).unsqueeze(0).contiguous()) + else: + arr = frame.to_ndarray(format="rgb24") # (H, W, 3) + # Convert to CHW uint8 to match torchcodec's output layout. + loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous()) loaded_ts.append(current_ts) if current_ts >= last_ts: break