diff --git a/src/lerobot/datasets/dataset_metadata.py b/src/lerobot/datasets/dataset_metadata.py
index ea329668c..48c1e463a 100644
--- a/src/lerobot/datasets/dataset_metadata.py
+++ b/src/lerobot/datasets/dataset_metadata.py
@@ -26,12 +26,13 @@ import pyarrow as pa
 import pyarrow.parquet as pq
 from huggingface_hub import snapshot_download
 
-from lerobot.configs import VideoEncoderConfig
+from lerobot.configs import DEPTH_METER_UNIT, VideoEncoderConfig
 from lerobot.utils.constants import DEFAULT_FEATURES, HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE
 from lerobot.utils.feature_utils import _validate_feature_names
 from lerobot.utils.utils import flatten_dict
 
 from .compute_stats import aggregate_stats
+from .depth_utils import MM_PER_METRE
 from .feature_utils import create_empty_dataset_info
 from .io_utils import (
     get_file_size_in_mb,
@@ -358,6 +359,49 @@ class LeRobotDatasetMetadata:
 
         return [key for key, ft in self.features.items() if _is_depth(ft)]
 
+    def rescale_depth_stats(self, output_unit: str) -> None:
+        """Rescale depth feature stats in place so they are expressed in ``output_unit``.
+
+        Depth stats are stored in the unit the frames were recorded in
+        (``features[key]["info"]["depth_unit"]``), while frames are returned in
+        ``output_unit`` on read. This converts the unit-bearing stat entries so
+        stats match the frames consumers see. The ``"count"`` entry is unit-less
+        and is left untouched.
+
+        The call is idempotent: the unit each rescaled entry is currently
+        expressed in is remembered on the instance, so calling this again
+        (with the same or the other unit) converts from that unit instead of
+        applying the factor to already-converted values.
+
+        Args:
+            output_unit: Unit the depth frames are returned in on read.
+        """
+        if self.stats is None:
+            return
+        # key -> (stats dict produced by a previous call, unit it is expressed in).
+        # ``features`` keeps describing the recorded unit, so the current unit of
+        # the in-memory stats has to be tracked separately.
+        rescaled = getattr(self, "_rescaled_depth_stats", None)
+        if rescaled is None:
+            rescaled = self._rescaled_depth_stats = {}
+        for key in self.depth_keys:
+            if key not in self.stats:
+                continue
+            current_unit = (self.features[key].get("info") or {}).get("depth_unit")
+            previous = rescaled.get(key)
+            # Only trust the remembered unit while ``self.stats[key]`` is still the
+            # dict this method produced; replaced stats are assumed to be in the recorded unit.
+            if previous is not None and previous[0] is self.stats[key]:
+                current_unit = previous[1]
+            if current_unit is None or current_unit == output_unit:
+                continue
+            factor = MM_PER_METRE if current_unit == DEPTH_METER_UNIT else 1.0 / MM_PER_METRE
+            converted = {
+                stat: value if stat == "count" else value * factor
+                for stat, value in self.stats[key].items()
+            }
+            self.stats[key] = converted
+            rescaled[key] = (converted, output_unit)
+
     @property
     def camera_keys(self) -> list[str]:
         """Keys to access visual modalities (regardless of their storage method)."""
diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py
index f600f1804..672b6958b 100644
--- a/src/lerobot/datasets/lerobot_dataset.py
+++ b/src/lerobot/datasets/lerobot_dataset.py
@@ -224,6 +224,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         )
         self.root = self.meta.root
         self.revision = self.meta.revision
+        self.meta.rescale_depth_stats(self._depth_output_unit)
 
         if episodes is not None and any(
             episode >= self.meta.total_episodes or episode < 0 for episode in episodes
diff --git a/src/lerobot/datasets/streaming_dataset.py b/src/lerobot/datasets/streaming_dataset.py
index 4c4ae59bf..7d63a618b 100644
--- a/src/lerobot/datasets/streaming_dataset.py
+++ b/src/lerobot/datasets/streaming_dataset.py
@@ -310,6 +310,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
         )
         self.root = self.meta.root
         self.revision = self.meta.revision
+        self.meta.rescale_depth_stats(self._depth_output_unit)
 
         # Check version
         check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)