diff --git a/src/lerobot/datasets/dataset_reader.py b/src/lerobot/datasets/dataset_reader.py index e8e07301e..f4e1f6a31 100644 --- a/src/lerobot/datasets/dataset_reader.py +++ b/src/lerobot/datasets/dataset_reader.py @@ -22,10 +22,14 @@ from pathlib import Path import datasets import torch -from lerobot.configs import DEFAULT_DEPTH_UNIT, DepthEncoderConfig +from lerobot.configs import ( + DEFAULT_DEPTH_UNIT, + DEPTH_METER_UNIT, + DepthEncoderConfig, +) from .dataset_metadata import LeRobotDatasetMetadata -from .depth_utils import dequantize_depth +from .depth_utils import MM_PER_METRE, dequantize_depth from .feature_utils import ( check_delta_timestamps, get_delta_indices, @@ -102,6 +106,13 @@ class DatasetReader: for vid_key in self._meta.depth_keys } + # Get the input unit of each depth feature stored as raw images. + self._image_depth_units: dict[str, str | None] = { + key: (self._meta.features[key].get("info") or {}).get("depth_unit") + for key in self._meta.depth_keys + if key in self._meta.image_keys + } + def set_image_transforms(self, image_transforms: Callable | None) -> None: """Replace the transform applied to visual observations.""" if image_transforms is not None and not callable(image_transforms): @@ -329,6 +340,13 @@ class DatasetReader: continue item[cam] = self._image_transforms(item[cam]) + # Convert depth features to the output unit. + for key, stored_unit in self._image_depth_units.items(): + if key in item and stored_unit is not None and stored_unit != self._depth_output_unit: + item[key] = ( + item[key] * MM_PER_METRE if stored_unit == DEPTH_METER_UNIT else item[key] / MM_PER_METRE + ) + # Add task as a string task_idx = item["task_index"].item() item["task"] = self._meta.tasks.iloc[task_idx].name