mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
feat(depth): wire DatasetReader to decode_depth_frames
This commit is contained in:
@@ -23,6 +23,7 @@ import datasets
|
||||
import torch
|
||||
|
||||
from .dataset_metadata import LeRobotDatasetMetadata
|
||||
from .depth_utils import dequantize_depth
|
||||
from .feature_utils import (
|
||||
check_delta_timestamps,
|
||||
get_delta_indices,
|
||||
@@ -86,6 +87,18 @@ class DatasetReader:
|
||||
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
|
||||
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
|
||||
|
||||
if self._meta.depth_keys:
|
||||
# TODO(CarolinePascal): make this decent, this is awful.
|
||||
self._dequantize_depth_configs = {
|
||||
vid_key: {
|
||||
"depth_min": self._meta.features[vid_key]["info"]["video.depth_min"],
|
||||
"depth_max": self._meta.features[vid_key]["info"]["video.depth_max"],
|
||||
"shift": self._meta.features[vid_key]["info"]["video.shift"],
|
||||
"use_log": self._meta.features[vid_key]["info"]["video.use_log"],
|
||||
}
|
||||
for vid_key in self._meta.depth_keys
|
||||
}
|
||||
|
||||
def try_load(self) -> bool:
|
||||
"""Attempt to load from local cache. Returns True if data is sufficient."""
|
||||
try:
|
||||
@@ -247,7 +260,12 @@ class DatasetReader:
|
||||
self._tolerance_s,
|
||||
self._video_backend,
|
||||
return_uint8=self._return_uint8,
|
||||
is_depth=vid_key in self._meta.depth_keys,
|
||||
)
|
||||
if vid_key in self._meta.depth_keys:
|
||||
frames = dequantize_depth(
|
||||
frames, **self._dequantize_depth_configs[vid_key], output_tensor=True
|
||||
)
|
||||
return vid_key, frames.squeeze(0)
|
||||
|
||||
items = list(query_timestamps.items())
|
||||
|
||||
@@ -44,7 +44,7 @@ from lerobot.configs import (
|
||||
)
|
||||
from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||
|
||||
from .depth_utils import quantize_depth
|
||||
from .depth_utils import DEPTH_PIX_FMT, quantize_depth
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -55,6 +55,7 @@ def decode_video_frames(
|
||||
tolerance_s: float,
|
||||
backend: str | None = None,
|
||||
return_uint8: bool = False,
|
||||
is_depth: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Decodes video frames using the specified backend.
|
||||
@@ -74,6 +75,11 @@ def decode_video_frames(
|
||||
|
||||
Currently supports torchcodec on cpu and pyav.
|
||||
"""
|
||||
if backend != "pyav" and is_depth:
|
||||
logger.warning("Decoding depth maps is only supported with the 'pyav' backend.")
|
||||
# Depth frames are decoded as uint16, not uint8: return_uint8=True is passed here only to skip the 255 normalization step.
|
||||
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=True, is_depth=True)
|
||||
|
||||
if backend is None:
|
||||
backend = get_safe_default_video_backend()
|
||||
if backend == "torchcodec":
|
||||
@@ -93,6 +99,7 @@ def decode_video_frames_pyav(
|
||||
tolerance_s: float,
|
||||
log_loaded_timestamps: bool = False,
|
||||
return_uint8: bool = False,
|
||||
is_depth: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Loads frames associated to the requested timestamps of a video using PyAV.
|
||||
|
||||
@@ -142,9 +149,13 @@ def decode_video_frames_pyav(
|
||||
current_ts = float(frame.pts * stream.time_base)
|
||||
if log_loaded_timestamps:
|
||||
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
|
||||
# Convert to CHW uint8 to match torchcodec's output layout.
|
||||
arr = frame.to_ndarray(format="rgb24") # H, W, 3
|
||||
loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous())
|
||||
if is_depth:
|
||||
arr = frame.to_ndarray(format=DEPTH_PIX_FMT) # (H, W) uint16
|
||||
loaded_frames.append(torch.from_numpy(arr).unsqueeze(0).contiguous())
|
||||
else:
|
||||
arr = frame.to_ndarray(format="rgb24") # (H, W, 3)
|
||||
# Convert to CHW uint8 to match torchcodec's output layout.
|
||||
loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous())
|
||||
loaded_ts.append(current_ts)
|
||||
if current_ts >= last_ts:
|
||||
break
|
||||
|
||||
Reference in New Issue
Block a user