feat(output unit): adding support for output unit specification at dataset reading/training time

Co-authored-by: Wensi (Vince) Ai <59036629+wensi-ai@users.noreply.github.com>
This commit is contained in:
CarolinePascal
2026-06-12 18:34:46 +02:00
parent 1a5fbb216f
commit 979f7b1187
4 changed files with 20 additions and 1 deletions
+6 -1
View File
@@ -35,12 +35,17 @@ class DatasetConfig:
revision: str | None = None
use_imagenet_stats: bool = True
video_backend: str = field(default_factory=get_safe_default_video_backend)
# When True, video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
# When True, RGB video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
return_uint8: bool = False
# Physical unit that depth maps are dequantized to at load time: "mm" (millimetres) or "m" (metres).
# Has no effect on datasets without depth cameras.
depth_output_unit: str = "mm"
streaming: bool = False
def __post_init__(self) -> None:
if self.depth_output_unit not in ("m", "mm"):
raise ValueError(f"depth_output_unit must be 'm' or 'mm', got {self.depth_output_unit!r}")
if self.episodes is not None:
if any(ep < 0 for ep in self.episodes):
raise ValueError(
+7
View File
@@ -54,6 +54,7 @@ class DatasetReader:
delta_timestamps: dict[str, list[float]] | None,
image_transforms: Callable | None,
return_uint8: bool = False,
depth_output_unit: str = "mm",
):
"""Initialize the reader with metadata, filtering, and transform config.
@@ -71,6 +72,10 @@ class DatasetReader:
relative timestamp offsets for temporal context windows.
image_transforms: Optional torchvision v2 transform applied to
visual features.
return_uint8: If True, return RGB video frames as raw uint8 tensors
instead of normalized float32.
depth_output_unit: Physical unit that depth maps are dequantized to
(``"m"`` or ``"mm"``). Defaults to ``"mm"``.
"""
self._meta = meta
self.root = root
@@ -79,6 +84,7 @@ class DatasetReader:
self._video_backend = video_backend
self._image_transforms = image_transforms
self._return_uint8 = return_uint8
self._depth_output_unit = depth_output_unit
self.hf_dataset: datasets.Dataset | None = None
self._absolute_to_relative_idx: dict[int, int] | None = None
@@ -266,6 +272,7 @@ class DatasetReader:
depth_max=depth_encoder.depth_max,
shift=depth_encoder.shift,
use_log=depth_encoder.use_log,
output_unit=self._depth_output_unit,
)
return vid_key, frames.squeeze(0)
+1
View File
@@ -96,6 +96,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
revision=cfg.dataset.revision,
video_backend=cfg.dataset.video_backend,
return_uint8=True,
depth_output_unit=cfg.dataset.depth_output_unit,
tolerance_s=cfg.tolerance_s,
)
else:
+6
View File
@@ -58,6 +58,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
download_videos: bool = True,
video_backend: str | None = None,
return_uint8: bool = False,
depth_output_unit: str = "mm",
batch_encoding_size: int = 1,
camera_encoder: VideoEncoderConfig | None = None,
depth_encoder: DepthEncoderConfig | None = None,
@@ -212,6 +213,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.revision = revision if revision else CODEBASE_VERSION
self._video_backend = video_backend if video_backend else get_safe_default_video_backend()
self._return_uint8 = return_uint8
self._depth_output_unit = depth_output_unit
self._batch_encoding_size = batch_encoding_size
self._encoder_threads = encoder_threads
@@ -252,6 +254,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
return_uint8=self._return_uint8,
depth_output_unit=self._depth_output_unit,
)
# Load actual data
@@ -321,6 +324,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
delta_timestamps=self.delta_timestamps,
image_transforms=self.image_transforms,
return_uint8=self._return_uint8,
depth_output_unit=self._depth_output_unit,
)
return self.reader
@@ -722,6 +726,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.episodes = None
obj._video_backend = video_backend if video_backend is not None else get_safe_default_video_backend()
obj._return_uint8 = False
obj._depth_output_unit = "mm"
obj._batch_encoding_size = batch_encoding_size
obj._encoder_threads = encoder_threads
@@ -820,6 +825,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.episodes = None
obj._video_backend = video_backend if video_backend else get_safe_default_video_backend()
obj._return_uint8 = False
obj._depth_output_unit = "mm"
obj._batch_encoding_size = batch_encoding_size
if obj._requested_root is not None: