From ea424a7c7136cf1eec0f15ba721233d47f33147a Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Wed, 24 Jun 2026 19:11:30 +0200 Subject: [PATCH] fix(streaming): adding support for dequantization in streaming_dataset.py --- src/lerobot/datasets/streaming_dataset.py | 47 +++++++++++++++++++---- 1 file changed, 40 insertions(+), 7 deletions(-) diff --git a/src/lerobot/datasets/streaming_dataset.py b/src/lerobot/datasets/streaming_dataset.py index 3c1e4a73c..4c4ae59bf 100644 --- a/src/lerobot/datasets/streaming_dataset.py +++ b/src/lerobot/datasets/streaming_dataset.py @@ -22,9 +22,11 @@ import numpy as np import torch from datasets import load_dataset +from lerobot.configs import DEFAULT_DEPTH_UNIT, DepthEncoderConfig from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata +from .depth_utils import dequantize_depth from .feature_utils import get_delta_indices from .io_utils import item_to_torch from .utils import ( @@ -35,6 +37,7 @@ from .utils import ( ) from .video_utils import ( VideoDecoderCache, + decode_video_frames, decode_video_frames_torchcodec, ) @@ -252,6 +255,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): rng: np.random.Generator | None = None, shuffle: bool = True, return_uint8: bool = False, + depth_output_unit: str = DEFAULT_DEPTH_UNIT, ): """Initialize a StreamingLeRobotDataset. @@ -272,6 +276,8 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): seed (int, optional): Reproducibility random seed. rng (np.random.Generator | None, optional): Random number generator. shuffle (bool, optional): Whether to shuffle the dataset across exhaustions. Defaults to True. + depth_output_unit (str, optional): Physical unit depth maps are dequantized to ("m" or "mm"). + Defaults to "mm". """ super().__init__() self.repo_id = repo_id @@ -290,6 +296,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): self.streaming = streaming self.buffer_size = buffer_size self._return_uint8 = return_uint8 + self._depth_output_unit = depth_output_unit # We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown) self.video_decoder_cache = None @@ -306,6 +313,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): # Check version check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION) + self._depth_encoder_configs: dict[str, DepthEncoderConfig] = { + vid_key: DepthEncoderConfig.from_video_info(self.meta.features[vid_key].get("info")) + for vid_key in self.meta.depth_keys + } + self.delta_timestamps = None self.delta_indices = None @@ -554,13 +566,34 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): for video_key, query_ts in query_timestamps.items(): root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}" - frames = decode_video_frames_torchcodec( - video_path, - query_ts, - self.tolerance_s, - decoder_cache=self.video_decoder_cache, - return_uint8=self._return_uint8, - ) + if video_key in self.meta.depth_keys: + # Depth maps are 12-bit quantized and only decodable via pyav; dequantize back + # to physical units to match the non-streaming reader. + frames = decode_video_frames( + video_path, + query_ts, + self.tolerance_s, + backend="pyav", + return_uint8=False, + is_depth=True, + ) + depth_encoder = self._depth_encoder_configs[video_key] + frames = dequantize_depth( + frames, + depth_min=depth_encoder.depth_min, + depth_max=depth_encoder.depth_max, + shift=depth_encoder.shift, + use_log=depth_encoder.use_log, + output_unit=self._depth_output_unit, + ) + else: + frames = decode_video_frames_torchcodec( + video_path, + query_ts, + self.tolerance_s, + decoder_cache=self.video_decoder_cache, + return_uint8=self._return_uint8, + ) item[video_key] = frames.squeeze(0) if len(query_ts) == 1 else frames