From 2744e265931839b475f5f2247a2b197a17013a91 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Mon, 27 Apr 2026 16:09:58 +0200 Subject: [PATCH] feat(depth): wire DatasetReader to decode_depth_frames --- src/lerobot/datasets/dataset_reader.py | 36 ++++++++++++++++++++------ tests/datasets/test_datasets.py | 3 ++- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/src/lerobot/datasets/dataset_reader.py b/src/lerobot/datasets/dataset_reader.py index bd1298590..bd5c4a8e9 100644 --- a/src/lerobot/datasets/dataset_reader.py +++ b/src/lerobot/datasets/dataset_reader.py @@ -32,7 +32,13 @@ from .io_utils import ( hf_transform_to_torch, load_nested_dataset, ) -from .video_utils import decode_video_frames +from .video_utils import decode_depth_frames, decode_video_frames +from .depth_utils import ( + DEFAULT_DEPTH_MIN, + DEFAULT_DEPTH_MAX, + DEFAULT_DEPTH_SHIFT, + DEFAULT_DEPTH_USE_LOG, +) class DatasetReader: @@ -237,17 +243,31 @@ class DatasetReader: """ ep = self._meta.episodes[ep_idx] + depth_keys = set(self._meta.depth_keys) + def _decode_single(vid_key: str, query_ts: list[float]) -> tuple[str, torch.Tensor]: from_timestamp = ep[f"videos/{vid_key}/from_timestamp"] shifted_query_ts = [from_timestamp + ts for ts in query_ts] video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key) - frames = decode_video_frames( - video_path, - shifted_query_ts, - self._tolerance_s, - self._video_backend, - return_uint8=self._return_uint8, - ) + if vid_key in depth_keys: + feature_info = self._meta.features[vid_key].get("info") or {} + frames = decode_depth_frames( + video_path, + shifted_query_ts, + self._tolerance_s, + depth_min=feature_info.get("video.depth_min", DEFAULT_DEPTH_MIN), + depth_max=feature_info.get("video.depth_max", DEFAULT_DEPTH_MAX), + shift=feature_info.get("video.shift", DEFAULT_DEPTH_SHIFT), + use_log=feature_info.get("video.use_log", DEFAULT_DEPTH_USE_LOG), + ) + else: + frames = decode_video_frames( + video_path, + shifted_query_ts, + self._tolerance_s, + self._video_backend, + return_uint8=self._return_uint8, + ) return vid_key, frames.squeeze(0) items = list(query_timestamps.items()) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 90a1aa5dc..a9af3221a 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -1483,7 +1483,8 @@ def test_valid_video_codecs_constant(): assert "auto" in VALID_VIDEO_CODECS assert "h264_videotoolbox" in VALID_VIDEO_CODECS assert "h264_nvenc" in VALID_VIDEO_CODECS - assert len(VALID_VIDEO_CODECS) == 10 + assert "ffv1" in VALID_VIDEO_CODECS + assert len(VALID_VIDEO_CODECS) == 11 def test_delta_timestamps_with_episodes_filter(tmp_path, empty_lerobot_dataset_factory):