mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-29 06:07:40 +00:00
fix(streaming): adding support for dequantization in streaming_dataset.py
This commit is contained in:
@@ -22,9 +22,11 @@ import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
from lerobot.configs import DEFAULT_DEPTH_UNIT, DepthEncoderConfig
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
|
||||
|
||||
from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
from .depth_utils import dequantize_depth
|
||||
from .feature_utils import get_delta_indices
|
||||
from .io_utils import item_to_torch
|
||||
from .utils import (
|
||||
@@ -35,6 +37,7 @@ from .utils import (
|
||||
)
|
||||
from .video_utils import (
|
||||
VideoDecoderCache,
|
||||
decode_video_frames,
|
||||
decode_video_frames_torchcodec,
|
||||
)
|
||||
|
||||
@@ -252,6 +255,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
rng: np.random.Generator | None = None,
|
||||
shuffle: bool = True,
|
||||
return_uint8: bool = False,
|
||||
depth_output_unit: str = DEFAULT_DEPTH_UNIT,
|
||||
):
|
||||
"""Initialize a StreamingLeRobotDataset.
|
||||
|
||||
@@ -272,6 +276,8 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
seed (int, optional): Reproducibility random seed.
|
||||
rng (np.random.Generator | None, optional): Random number generator.
|
||||
shuffle (bool, optional): Whether to shuffle the dataset across exhaustions. Defaults to True.
|
||||
depth_output_unit (str, optional): Physical unit depth maps are dequantized to ("m" or "mm").
|
||||
Defaults to "mm".
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
@@ -290,6 +296,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
self.streaming = streaming
|
||||
self.buffer_size = buffer_size
|
||||
self._return_uint8 = return_uint8
|
||||
self._depth_output_unit = depth_output_unit
|
||||
|
||||
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
|
||||
self.video_decoder_cache = None
|
||||
@@ -306,6 +313,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
# Check version
|
||||
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
|
||||
|
||||
self._depth_encoder_configs: dict[str, DepthEncoderConfig] = {
|
||||
vid_key: DepthEncoderConfig.from_video_info(self.meta.features[vid_key].get("info"))
|
||||
for vid_key in self.meta.depth_keys
|
||||
}
|
||||
|
||||
self.delta_timestamps = None
|
||||
self.delta_indices = None
|
||||
|
||||
@@ -554,13 +566,34 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
for video_key, query_ts in query_timestamps.items():
|
||||
root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root
|
||||
video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}"
|
||||
frames = decode_video_frames_torchcodec(
|
||||
video_path,
|
||||
query_ts,
|
||||
self.tolerance_s,
|
||||
decoder_cache=self.video_decoder_cache,
|
||||
return_uint8=self._return_uint8,
|
||||
)
|
||||
if video_key in self.meta.depth_keys:
|
||||
# Depth maps are 12-bit quantized and only decodable via pyav; dequantize back
|
||||
# to physical units to match the non-streaming reader.
|
||||
frames = decode_video_frames(
|
||||
video_path,
|
||||
query_ts,
|
||||
self.tolerance_s,
|
||||
backend="pyav",
|
||||
return_uint8=False,
|
||||
is_depth=True,
|
||||
)
|
||||
depth_encoder = self._depth_encoder_configs[video_key]
|
||||
frames = dequantize_depth(
|
||||
frames,
|
||||
depth_min=depth_encoder.depth_min,
|
||||
depth_max=depth_encoder.depth_max,
|
||||
shift=depth_encoder.shift,
|
||||
use_log=depth_encoder.use_log,
|
||||
output_unit=self._depth_output_unit,
|
||||
)
|
||||
else:
|
||||
frames = decode_video_frames_torchcodec(
|
||||
video_path,
|
||||
query_ts,
|
||||
self.tolerance_s,
|
||||
decoder_cache=self.video_decoder_cache,
|
||||
return_uint8=self._return_uint8,
|
||||
)
|
||||
|
||||
item[video_key] = frames.squeeze(0) if len(query_ts) == 1 else frames
|
||||
|
||||
|
||||
Reference in New Issue
Block a user