diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py index b809e71d9..2f23b213a 100644 --- a/src/lerobot/configs/default.py +++ b/src/lerobot/configs/default.py @@ -35,12 +35,17 @@ class DatasetConfig: revision: str | None = None use_imagenet_stats: bool = True video_backend: str = field(default_factory=get_safe_default_video_backend) - # When True, video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0). + # When True, RGB video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0). # This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion. return_uint8: bool = False + # Physical unit depth maps are dequantized to at load time: "mm" (millimetres) or "m" (metres). + # Has no effect on datasets without depth cameras. + depth_output_unit: str = "mm" streaming: bool = False def __post_init__(self) -> None: + if self.depth_output_unit not in ("m", "mm"): + raise ValueError(f"depth_output_unit must be 'm' or 'mm', got {self.depth_output_unit!r}") if self.episodes is not None: if any(ep < 0 for ep in self.episodes): raise ValueError( diff --git a/src/lerobot/datasets/dataset_reader.py b/src/lerobot/datasets/dataset_reader.py index 927cd9f8c..fe9a51eb7 100644 --- a/src/lerobot/datasets/dataset_reader.py +++ b/src/lerobot/datasets/dataset_reader.py @@ -54,6 +54,7 @@ class DatasetReader: delta_timestamps: dict[str, list[float]] | None, image_transforms: Callable | None, return_uint8: bool = False, + depth_output_unit: str = "mm", ): """Initialize the reader with metadata, filtering, and transform config. @@ -71,6 +72,10 @@ class DatasetReader: relative timestamp offsets for temporal context windows. image_transforms: Optional torchvision v2 transform applied to visual features. + return_uint8: If True, return RGB video frames as raw uint8 tensors + instead of normalized float32. + depth_output_unit: Physical unit depth maps are dequantized to + (``"m"`` or ``"mm"``). Defaults to ``"mm"``. """ self._meta = meta self.root = root @@ -79,6 +84,7 @@ class DatasetReader: self._video_backend = video_backend self._image_transforms = image_transforms self._return_uint8 = return_uint8 + self._depth_output_unit = depth_output_unit self.hf_dataset: datasets.Dataset | None = None self._absolute_to_relative_idx: dict[int, int] | None = None @@ -266,6 +272,7 @@ class DatasetReader: depth_max=depth_encoder.depth_max, shift=depth_encoder.shift, use_log=depth_encoder.use_log, + output_unit=self._depth_output_unit, ) return vid_key, frames.squeeze(0) diff --git a/src/lerobot/datasets/factory.py b/src/lerobot/datasets/factory.py index cbbe83dc8..9f1fe6530 100644 --- a/src/lerobot/datasets/factory.py +++ b/src/lerobot/datasets/factory.py @@ -96,6 +96,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas revision=cfg.dataset.revision, video_backend=cfg.dataset.video_backend, return_uint8=True, + depth_output_unit=cfg.dataset.depth_output_unit, tolerance_s=cfg.tolerance_s, ) else: diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index b6ddcee83..6e757ecd2 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -58,6 +58,7 @@ class LeRobotDataset(torch.utils.data.Dataset): download_videos: bool = True, video_backend: str | None = None, return_uint8: bool = False, + depth_output_unit: str = "mm", batch_encoding_size: int = 1, camera_encoder: VideoEncoderConfig | None = None, depth_encoder: DepthEncoderConfig | None = None, @@ -212,6 +213,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.revision = revision if revision else CODEBASE_VERSION self._video_backend = video_backend if video_backend else get_safe_default_video_backend() self._return_uint8 = return_uint8 + self._depth_output_unit = depth_output_unit self._batch_encoding_size = batch_encoding_size self._encoder_threads = encoder_threads @@ -252,6 +254,7 @@ class LeRobotDataset(torch.utils.data.Dataset): delta_timestamps=delta_timestamps, image_transforms=image_transforms, return_uint8=self._return_uint8, + depth_output_unit=self._depth_output_unit, ) # Load actual data @@ -321,6 +324,7 @@ class LeRobotDataset(torch.utils.data.Dataset): delta_timestamps=self.delta_timestamps, image_transforms=self.image_transforms, return_uint8=self._return_uint8, + depth_output_unit=self._depth_output_unit, ) return self.reader @@ -722,6 +726,7 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.episodes = None obj._video_backend = video_backend if video_backend is not None else get_safe_default_video_backend() obj._return_uint8 = False + obj._depth_output_unit = "mm" obj._batch_encoding_size = batch_encoding_size obj._encoder_threads = encoder_threads @@ -820,6 +825,7 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.episodes = None obj._video_backend = video_backend if video_backend else get_safe_default_video_backend() obj._return_uint8 = False + obj._depth_output_unit = "mm" obj._batch_encoding_size = batch_encoding_size if obj._requested_root is not None: