mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-01 07:07:08 +00:00
feat(stats units): rescaling stats when loading a dataset so that the stats are given in the requested unit
This commit is contained in:
@@ -26,12 +26,13 @@ import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig
|
||||
from lerobot.configs import DEPTH_METER_UNIT, VideoEncoderConfig
|
||||
from lerobot.utils.constants import DEFAULT_FEATURES, HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE
|
||||
from lerobot.utils.feature_utils import _validate_feature_names
|
||||
from lerobot.utils.utils import flatten_dict
|
||||
|
||||
from .compute_stats import aggregate_stats
|
||||
from .depth_utils import MM_PER_METRE
|
||||
from .feature_utils import create_empty_dataset_info
|
||||
from .io_utils import (
|
||||
get_file_size_in_mb,
|
||||
@@ -358,6 +359,26 @@ class LeRobotDatasetMetadata:
|
||||
|
||||
return [key for key, ft in self.features.items() if _is_depth(ft)]
|
||||
|
||||
def rescale_depth_stats(self, output_unit: str) -> None:
|
||||
"""Rescale depth feature stats in place from their recorded unit to ``output_unit``.
|
||||
|
||||
Depth stats are stored in the unit the frames were recorded in
|
||||
(``features[key]["info"]["depth_unit"]``), while frames are returned in
|
||||
``output_unit`` on read. This converts the unit-bearing stat entries so
|
||||
stats match the frames consumers see.
|
||||
"""
|
||||
if self.stats is None:
|
||||
return
|
||||
for key in self.depth_keys:
|
||||
stored_unit = (self.features[key].get("info") or {}).get("depth_unit")
|
||||
if stored_unit is None or stored_unit == output_unit or key not in self.stats:
|
||||
continue
|
||||
factor = MM_PER_METRE if stored_unit == DEPTH_METER_UNIT else 1.0 / MM_PER_METRE
|
||||
self.stats[key] = {
|
||||
stat: value if stat == "count" else value * factor
|
||||
for stat, value in self.stats[key].items()
|
||||
}
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
|
||||
@@ -224,6 +224,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
self.root = self.meta.root
|
||||
self.revision = self.meta.revision
|
||||
self.meta.rescale_depth_stats(self._depth_output_unit)
|
||||
|
||||
if episodes is not None and any(
|
||||
episode >= self.meta.total_episodes or episode < 0 for episode in episodes
|
||||
|
||||
@@ -310,6 +310,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
)
|
||||
self.root = self.meta.root
|
||||
self.revision = self.meta.revision
|
||||
self.meta.rescale_depth_stats(self._depth_output_unit)
|
||||
# Check version
|
||||
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user