feat(stats units): rescaling stats when loading a dataset so that the stats are given in the requested unit

This commit is contained in:
CarolinePascal
2026-06-30 22:11:14 +02:00
parent 43a12b82d4
commit 2d2f93607f
3 changed files with 24 additions and 1 deletions
+22 -1
View File
@@ -26,12 +26,13 @@ import pyarrow as pa
import pyarrow.parquet as pq
from huggingface_hub import snapshot_download
from lerobot.configs import VideoEncoderConfig
from lerobot.configs import DEPTH_METER_UNIT, VideoEncoderConfig
from lerobot.utils.constants import DEFAULT_FEATURES, HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE
from lerobot.utils.feature_utils import _validate_feature_names
from lerobot.utils.utils import flatten_dict
from .compute_stats import aggregate_stats
from .depth_utils import MM_PER_METRE
from .feature_utils import create_empty_dataset_info
from .io_utils import (
get_file_size_in_mb,
@@ -358,6 +359,26 @@ class LeRobotDatasetMetadata:
return [key for key, ft in self.features.items() if _is_depth(ft)]
def rescale_depth_stats(self, output_unit: str) -> None:
"""Rescale depth feature stats in place from their recorded unit to ``output_unit``.
Depth stats are stored in the unit the frames were recorded in
(``features[key]["info"]["depth_unit"]``), while frames are returned in
``output_unit`` on read. This converts the unit-bearing stat entries so
stats match the frames consumers see.
"""
if self.stats is None:
return
for key in self.depth_keys:
stored_unit = (self.features[key].get("info") or {}).get("depth_unit")
if stored_unit is None or stored_unit == output_unit or key not in self.stats:
continue
factor = MM_PER_METRE if stored_unit == DEPTH_METER_UNIT else 1.0 / MM_PER_METRE
self.stats[key] = {
stat: value if stat == "count" else value * factor
for stat, value in self.stats[key].items()
}
@property
def camera_keys(self) -> list[str]:
"""Keys to access visual modalities (regardless of their storage method)."""
+1
View File
@@ -224,6 +224,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
self.root = self.meta.root
self.revision = self.meta.revision
self.meta.rescale_depth_stats(self._depth_output_unit)
if episodes is not None and any(
episode >= self.meta.total_episodes or episode < 0 for episode in episodes
@@ -310,6 +310,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
)
self.root = self.meta.root
self.revision = self.meta.revision
self.meta.rescale_depth_stats(self._depth_output_unit)
# Check version
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)