feat(depth): wire DatasetReader to decode_depth_frames

This commit is contained in:
CarolinePascal
2026-04-27 16:09:58 +02:00
parent de64ad3f7e
commit 2744e26593
2 changed files with 30 additions and 9 deletions
+28 -8
View File
@@ -32,7 +32,13 @@ from .io_utils import (
hf_transform_to_torch,
load_nested_dataset,
)
from .video_utils import decode_video_frames
from .video_utils import decode_depth_frames, decode_video_frames
from .depth_utils import (
DEFAULT_DEPTH_MIN,
DEFAULT_DEPTH_MAX,
DEFAULT_DEPTH_SHIFT,
DEFAULT_DEPTH_USE_LOG,
)
class DatasetReader:
@@ -237,17 +243,31 @@ class DatasetReader:
"""
ep = self._meta.episodes[ep_idx]
depth_keys = set(self._meta.depth_keys)
def _decode_single(vid_key: str, query_ts: list[float]) -> tuple[str, torch.Tensor]:
from_timestamp = ep[f"videos/{vid_key}/from_timestamp"]
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
frames = decode_video_frames(
video_path,
shifted_query_ts,
self._tolerance_s,
self._video_backend,
return_uint8=self._return_uint8,
)
if vid_key in depth_keys:
feature_info = self._meta.features[vid_key].get("info") or {}
frames = decode_depth_frames(
video_path,
shifted_query_ts,
self._tolerance_s,
depth_min=feature_info.get("video.depth_min", DEFAULT_DEPTH_MIN),
depth_max=feature_info.get("video.depth_max", DEFAULT_DEPTH_MAX),
shift=feature_info.get("video.shift", DEFAULT_DEPTH_SHIFT),
use_log=feature_info.get("video.use_log", DEFAULT_DEPTH_USE_LOG),
)
else:
frames = decode_video_frames(
video_path,
shifted_query_ts,
self._tolerance_s,
self._video_backend,
return_uint8=self._return_uint8,
)
return vid_key, frames.squeeze(0)
items = list(query_timestamps.items())
+2 -1
View File
@@ -1483,7 +1483,8 @@ def test_valid_video_codecs_constant():
assert "auto" in VALID_VIDEO_CODECS
assert "h264_videotoolbox" in VALID_VIDEO_CODECS
assert "h264_nvenc" in VALID_VIDEO_CODECS
assert len(VALID_VIDEO_CODECS) == 10
assert "ffv1" in VALID_VIDEO_CODECS
assert len(VALID_VIDEO_CODECS) == 11
def test_delta_timestamps_with_episodes_filter(tmp_path, empty_lerobot_dataset_factory):