fix(streaming): adding support for dequantization in streaming_dataset.py

This commit is contained in:
CarolinePascal
2026-06-24 19:11:30 +02:00
parent 368830fe8e
commit ea424a7c71
+40 -7
View File
@@ -22,9 +22,11 @@ import numpy as np
import torch
from datasets import load_dataset
from lerobot.configs import DEFAULT_DEPTH_UNIT, DepthEncoderConfig
from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
from .depth_utils import dequantize_depth
from .feature_utils import get_delta_indices
from .io_utils import item_to_torch
from .utils import (
@@ -35,6 +37,7 @@ from .utils import (
)
from .video_utils import (
VideoDecoderCache,
decode_video_frames,
decode_video_frames_torchcodec,
)
@@ -252,6 +255,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
rng: np.random.Generator | None = None,
shuffle: bool = True,
return_uint8: bool = False,
depth_output_unit: str = DEFAULT_DEPTH_UNIT,
):
"""Initialize a StreamingLeRobotDataset.
@@ -272,6 +276,8 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
seed (int, optional): Reproducibility random seed.
rng (np.random.Generator | None, optional): Random number generator.
shuffle (bool, optional): Whether to shuffle the dataset across exhaustions. Defaults to True.
depth_output_unit (str, optional): Physical unit depth maps are dequantized to ("m" or "mm").
Defaults to "mm".
"""
super().__init__()
self.repo_id = repo_id
@@ -290,6 +296,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
self.streaming = streaming
self.buffer_size = buffer_size
self._return_uint8 = return_uint8
self._depth_output_unit = depth_output_unit
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
self.video_decoder_cache = None
@@ -306,6 +313,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
# Check version
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
self._depth_encoder_configs: dict[str, DepthEncoderConfig] = {
vid_key: DepthEncoderConfig.from_video_info(self.meta.features[vid_key].get("info"))
for vid_key in self.meta.depth_keys
}
self.delta_timestamps = None
self.delta_indices = None
@@ -554,13 +566,34 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
for video_key, query_ts in query_timestamps.items():
root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root
video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}"
frames = decode_video_frames_torchcodec(
video_path,
query_ts,
self.tolerance_s,
decoder_cache=self.video_decoder_cache,
return_uint8=self._return_uint8,
)
if video_key in self.meta.depth_keys:
# Depth maps are 12-bit quantized and only decodable via pyav; dequantize back
# to physical units to match the non-streaming reader.
frames = decode_video_frames(
video_path,
query_ts,
self.tolerance_s,
backend="pyav",
return_uint8=False,
is_depth=True,
)
depth_encoder = self._depth_encoder_configs[video_key]
frames = dequantize_depth(
frames,
depth_min=depth_encoder.depth_min,
depth_max=depth_encoder.depth_max,
shift=depth_encoder.shift,
use_log=depth_encoder.use_log,
output_unit=self._depth_output_unit,
)
else:
frames = decode_video_frames_torchcodec(
video_path,
query_ts,
self.tolerance_s,
decoder_cache=self.video_decoder_cache,
return_uint8=self._return_uint8,
)
item[video_key] = frames.squeeze(0) if len(query_ts) == 1 else frames