mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 08:47:05 +00:00
feat(depth): wire DatasetReader to decode_depth_frames
This commit is contained in:
@@ -32,7 +32,13 @@ from .io_utils import (
|
||||
hf_transform_to_torch,
|
||||
load_nested_dataset,
|
||||
)
|
||||
from .video_utils import decode_video_frames
|
||||
from .video_utils import decode_depth_frames, decode_video_frames
|
||||
from .depth_utils import (
|
||||
DEFAULT_DEPTH_MIN,
|
||||
DEFAULT_DEPTH_MAX,
|
||||
DEFAULT_DEPTH_SHIFT,
|
||||
DEFAULT_DEPTH_USE_LOG,
|
||||
)
|
||||
|
||||
|
||||
class DatasetReader:
|
||||
@@ -237,17 +243,31 @@ class DatasetReader:
|
||||
"""
|
||||
ep = self._meta.episodes[ep_idx]
|
||||
|
||||
depth_keys = set(self._meta.depth_keys)
|
||||
|
||||
def _decode_single(vid_key: str, query_ts: list[float]) -> tuple[str, torch.Tensor]:
|
||||
from_timestamp = ep[f"videos/{vid_key}/from_timestamp"]
|
||||
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
|
||||
video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
|
||||
frames = decode_video_frames(
|
||||
video_path,
|
||||
shifted_query_ts,
|
||||
self._tolerance_s,
|
||||
self._video_backend,
|
||||
return_uint8=self._return_uint8,
|
||||
)
|
||||
if vid_key in depth_keys:
|
||||
feature_info = self._meta.features[vid_key].get("info") or {}
|
||||
frames = decode_depth_frames(
|
||||
video_path,
|
||||
shifted_query_ts,
|
||||
self._tolerance_s,
|
||||
depth_min=feature_info.get("video.depth_min", DEFAULT_DEPTH_MIN),
|
||||
depth_max=feature_info.get("video.depth_max", DEFAULT_DEPTH_MAX),
|
||||
shift=feature_info.get("video.shift", DEFAULT_DEPTH_SHIFT),
|
||||
use_log=feature_info.get("video.use_log", DEFAULT_DEPTH_USE_LOG),
|
||||
)
|
||||
else:
|
||||
frames = decode_video_frames(
|
||||
video_path,
|
||||
shifted_query_ts,
|
||||
self._tolerance_s,
|
||||
self._video_backend,
|
||||
return_uint8=self._return_uint8,
|
||||
)
|
||||
return vid_key, frames.squeeze(0)
|
||||
|
||||
items = list(query_timestamps.items())
|
||||
|
||||
@@ -1483,7 +1483,8 @@ def test_valid_video_codecs_constant():
|
||||
assert "auto" in VALID_VIDEO_CODECS
|
||||
assert "h264_videotoolbox" in VALID_VIDEO_CODECS
|
||||
assert "h264_nvenc" in VALID_VIDEO_CODECS
|
||||
assert len(VALID_VIDEO_CODECS) == 10
|
||||
assert "ffv1" in VALID_VIDEO_CODECS
|
||||
assert len(VALID_VIDEO_CODECS) == 11
|
||||
|
||||
|
||||
def test_delta_timestamps_with_episodes_filter(tmp_path, empty_lerobot_dataset_factory):
|
||||
|
||||
Reference in New Issue
Block a user