feat(depth): wire DatasetReader to decode_depth_frames

This commit is contained in:
CarolinePascal
2026-05-19 23:46:28 +02:00
parent d39698da0f
commit e51d45dd2c
2 changed files with 33 additions and 4 deletions
+18
View File
@@ -23,6 +23,7 @@ import datasets
import torch
from .dataset_metadata import LeRobotDatasetMetadata
from .depth_utils import dequantize_depth
from .feature_utils import (
check_delta_timestamps,
get_delta_indices,
@@ -86,6 +87,18 @@ class DatasetReader:
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
if self._meta.depth_keys:
# TODO(CarolinePascal): make this decent, this is awful.
self._dequantize_depth_configs = {
vid_key: {
"depth_min": self._meta.features[vid_key]["info"]["video.depth_min"],
"depth_max": self._meta.features[vid_key]["info"]["video.depth_max"],
"shift": self._meta.features[vid_key]["info"]["video.shift"],
"use_log": self._meta.features[vid_key]["info"]["video.use_log"],
}
for vid_key in self._meta.depth_keys
}
def try_load(self) -> bool:
"""Attempt to load from local cache. Returns True if data is sufficient."""
try:
@@ -247,7 +260,12 @@ class DatasetReader:
self._tolerance_s,
self._video_backend,
return_uint8=self._return_uint8,
is_depth=vid_key in self._meta.depth_keys,
)
if vid_key in self._meta.depth_keys:
frames = dequantize_depth(
frames, **self._dequantize_depth_configs[vid_key], output_tensor=True
)
return vid_key, frames.squeeze(0)
items = list(query_timestamps.items())
+15 -4
View File
@@ -44,7 +44,7 @@ from lerobot.configs import (
)
from lerobot.utils.import_utils import get_safe_default_video_backend
from .depth_utils import quantize_depth
from .depth_utils import DEPTH_PIX_FMT, quantize_depth
logger = logging.getLogger(__name__)
@@ -55,6 +55,7 @@ def decode_video_frames(
tolerance_s: float,
backend: str | None = None,
return_uint8: bool = False,
is_depth: bool = False,
) -> torch.Tensor:
"""
Decodes video frames using the specified backend.
@@ -74,6 +75,11 @@ def decode_video_frames(
Currently supports torchcodec on cpu and pyav.
"""
if backend != "pyav" and is_depth:
logger.warning("Decoding depth maps is only supported with the 'pyav' backend.")
# We do not actually return uint8 here, but we avoid the 255 normalization step.
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=True, is_depth=True)
if backend is None:
backend = get_safe_default_video_backend()
if backend == "torchcodec":
@@ -93,6 +99,7 @@ def decode_video_frames_pyav(
tolerance_s: float,
log_loaded_timestamps: bool = False,
return_uint8: bool = False,
is_depth: bool = False,
) -> torch.Tensor:
"""Loads frames associated to the requested timestamps of a video using PyAV.
@@ -142,9 +149,13 @@ def decode_video_frames_pyav(
current_ts = float(frame.pts * stream.time_base)
if log_loaded_timestamps:
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
# Convert to CHW uint8 to match torchcodec's output layout.
arr = frame.to_ndarray(format="rgb24") # H, W, 3
loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous())
if is_depth:
arr = frame.to_ndarray(format=DEPTH_PIX_FMT) # (H, W) uint16
loaded_frames.append(torch.from_numpy(arr).unsqueeze(0).contiguous())
else:
arr = frame.to_ndarray(format="rgb24") # (H, W, 3)
# Convert to CHW uint8 to match torchcodec's output layout.
loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous())
loaded_ts.append(current_ts)
if current_ts >= last_ts:
break