Compare commits

...

11 Commits

Author SHA1 Message Date
CarolinePascal 37866f8014 test(rerun): fixing rerun tests 2026-07-01 20:23:56 +02:00
CarolinePascal 0584866f85 fix(streaming dataset): extending support for depth units to streaming datasets 2026-07-01 20:09:09 +02:00
CarolinePascal 15d94e6108 feat(unit getter): adding a proper output_depth_unit getter to LeRobotDataset for cleaner integration 2026-07-01 20:00:20 +02:00
CarolinePascal 0d52d371be feat(rerun unit): adding correct depth unit display for rerun (foxglove does not support units yet) 2026-07-01 19:39:43 +02:00
CarolinePascal bb066435bf chore(infer_depth_unit): moving the depth unit inference utility in a more accessible location 2026-07-01 19:38:30 +02:00
CarolinePascal afa189fc72 feat(warning): adding a warning when depth unit is not specified in the dataset 2026-07-01 19:16:00 +02:00
CarolinePascal bcc71bf73b chore(format): formating code 2026-07-01 19:16:00 +02:00
CarolinePascal b20c85b85c tests(unit): adapting and extending depth tests to units manipulations 2026-07-01 19:16:00 +02:00
CarolinePascal 0f32152aa5 feat(stats units): rescaling stats when loading a dataset so that the stats are given in the requested unit 2026-07-01 19:16:00 +02:00
CarolinePascal a844eca500 feat(raw frame unit): adapting dataset reader so that raw depth frames are scaled according to the requested unit 2026-07-01 19:16:00 +02:00
CarolinePascal 006ca66a66 fix(depth unit): storing raw depth units in the dataset metadata for correct depth statistics and depth raw frames handling. The unit is stored as a string ("m","mm") under "depth_unit" at the same level as "is_depth_map". Unit is inferred from the depth frame type. 2026-07-01 19:15:59 +02:00
14 changed files with 229 additions and 19 deletions
+6
View File
@@ -34,6 +34,8 @@ from .types import (
) )
from .video import ( from .video import (
DEFAULT_DEPTH_UNIT, DEFAULT_DEPTH_UNIT,
DEPTH_METER_UNIT,
DEPTH_MILLIMETER_UNIT,
VALID_VIDEO_CODECS, VALID_VIDEO_CODECS,
VIDEO_ENCODER_INFO_KEYS, VIDEO_ENCODER_INFO_KEYS,
DepthEncoderConfig, DepthEncoderConfig,
@@ -41,6 +43,7 @@ from .video import (
VideoEncoderConfig, VideoEncoderConfig,
depth_encoder_defaults, depth_encoder_defaults,
encoder_config_from_video_info, encoder_config_from_video_info,
infer_depth_unit,
rgb_encoder_defaults, rgb_encoder_defaults,
) )
@@ -70,8 +73,11 @@ __all__ = [
"depth_encoder_defaults", "depth_encoder_defaults",
# Factories # Factories
"encoder_config_from_video_info", "encoder_config_from_video_info",
"infer_depth_unit",
# Constants # Constants
"DEFAULT_DEPTH_UNIT", "DEFAULT_DEPTH_UNIT",
"DEPTH_METER_UNIT",
"DEPTH_MILLIMETER_UNIT",
"VALID_VIDEO_CODECS", "VALID_VIDEO_CODECS",
"VIDEO_ENCODER_INFO_KEYS", "VIDEO_ENCODER_INFO_KEYS",
] ]
+11
View File
@@ -22,6 +22,8 @@ import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, ClassVar, Self from typing import Any, ClassVar, Self
import numpy as np
from lerobot.utils.import_utils import require_package from lerobot.utils.import_utils import require_package
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -65,6 +67,15 @@ DEPTH_METER_UNIT: str = "m"
DEPTH_MILLIMETER_UNIT: str = "mm" DEPTH_MILLIMETER_UNIT: str = "mm"
DEFAULT_DEPTH_UNIT: str = DEPTH_MILLIMETER_UNIT DEFAULT_DEPTH_UNIT: str = DEPTH_MILLIMETER_UNIT
def infer_depth_unit(dtype: np.dtype | type) -> str:
"""Infer the physical unit of raw depth frames from their dtype.
Floating-point frames are assumed to be in metres, integer frames in millimetres.
"""
return DEPTH_METER_UNIT if np.issubdtype(np.dtype(dtype), np.floating) else DEPTH_MILLIMETER_UNIT
# Depth-specific tuning fields persisted under ``features[*]["info"]`` as ``video.<name>``. # Depth-specific tuning fields persisted under ``features[*]["info"]`` as ``video.<name>``.
DEPTH_ENCODER_INFO_FIELD_NAMES: frozenset[str] = frozenset({"depth_min", "depth_max", "shift", "use_log"}) DEPTH_ENCODER_INFO_FIELD_NAMES: frozenset[str] = frozenset({"depth_min", "depth_max", "shift", "use_log"})
+1 -1
View File
@@ -509,7 +509,7 @@ def compute_episode_stats(
For 'image'/'video' features, stats are computed per channel and kept with a For 'image'/'video' features, stats are computed per channel and kept with a
leading channel axis (e.g. shape (3, 1, 1) for RGB). RGB stats are divided by leading channel axis (e.g. shape (3, 1, 1) for RGB). RGB stats are divided by
255 to land in [0, 1]; depth maps (features flagged with ``is_depth_map``) skip 255 to land in [0, 1]; depth maps (features flagged with ``is_depth_map``) skip
this rescaling and remain in their stored units. this rescaling and remain in their stored units (stored in ``depth_unit``).
""" """
if quantile_list is None: if quantile_list is None:
quantile_list = DEFAULT_QUANTILES quantile_list = DEFAULT_QUANTILES
+31 -1
View File
@@ -26,12 +26,13 @@ import pyarrow as pa
import pyarrow.parquet as pq import pyarrow.parquet as pq
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from lerobot.configs import VideoEncoderConfig from lerobot.configs import DEPTH_METER_UNIT, VideoEncoderConfig
from lerobot.utils.constants import DEFAULT_FEATURES, HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE from lerobot.utils.constants import DEFAULT_FEATURES, HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE
from lerobot.utils.feature_utils import _validate_feature_names from lerobot.utils.feature_utils import _validate_feature_names
from lerobot.utils.utils import flatten_dict from lerobot.utils.utils import flatten_dict
from .compute_stats import aggregate_stats from .compute_stats import aggregate_stats
from .depth_utils import MM_PER_METRE
from .feature_utils import create_empty_dataset_info from .feature_utils import create_empty_dataset_info
from .io_utils import ( from .io_utils import (
get_file_size_in_mb, get_file_size_in_mb,
@@ -358,6 +359,35 @@ class LeRobotDatasetMetadata:
return [key for key, ft in self.features.items() if _is_depth(ft)] return [key for key, ft in self.features.items() if _is_depth(ft)]
def rescale_depth_stats(self, output_unit: str) -> None:
"""Rescale depth feature stats in place from their recorded unit to ``output_unit``.
Depth stats are stored in the unit the frames were recorded in
(``features[key]["info"]["depth_unit"]``), while frames are returned in
``output_unit`` on read. This converts the unit-bearing stat entries so
stats match the frames consumers see.
"""
missing_unit_keys = [
key for key in self.depth_keys if (self.features[key].get("info") or {}).get("depth_unit") is None
]
if missing_unit_keys:
logging.warning(
f"Depth feature(s) {missing_unit_keys} have no recorded 'depth_unit' in their info. "
f"Depth maps and stats for these keys will be returned AS IS, with no unit conversion "
f"to the requested output unit {output_unit!r}. Re-record the dataset or set 'depth_unit' "
f"in the feature info (meta/info.json) to enable conversion."
)
if self.stats is None:
return
for key in self.depth_keys:
stored_unit = (self.features[key].get("info") or {}).get("depth_unit")
if stored_unit is None or stored_unit == output_unit or key not in self.stats:
continue
factor = MM_PER_METRE if stored_unit == DEPTH_METER_UNIT else 1.0 / MM_PER_METRE
self.stats[key] = {
stat: value if stat == "count" else value * factor for stat, value in self.stats[key].items()
}
@property @property
def camera_keys(self) -> list[str]: def camera_keys(self) -> list[str]:
"""Keys to access visual modalities (regardless of their storage method).""" """Keys to access visual modalities (regardless of their storage method)."""
+20 -2
View File
@@ -22,10 +22,14 @@ from pathlib import Path
import datasets import datasets
import torch import torch
from lerobot.configs import DEFAULT_DEPTH_UNIT, DepthEncoderConfig from lerobot.configs import (
DEFAULT_DEPTH_UNIT,
DEPTH_METER_UNIT,
DepthEncoderConfig,
)
from .dataset_metadata import LeRobotDatasetMetadata from .dataset_metadata import LeRobotDatasetMetadata
from .depth_utils import dequantize_depth from .depth_utils import MM_PER_METRE, dequantize_depth
from .feature_utils import ( from .feature_utils import (
check_delta_timestamps, check_delta_timestamps,
get_delta_indices, get_delta_indices,
@@ -102,6 +106,13 @@ class DatasetReader:
for vid_key in self._meta.depth_keys for vid_key in self._meta.depth_keys
} }
# Get the input unit of each depth feature stored as raw images.
self._image_depth_units: dict[str, str | None] = {
key: (self._meta.features[key].get("info") or {}).get("depth_unit")
for key in self._meta.depth_keys
if key in self._meta.image_keys
}
def set_image_transforms(self, image_transforms: Callable | None) -> None: def set_image_transforms(self, image_transforms: Callable | None) -> None:
"""Replace the transform applied to visual observations.""" """Replace the transform applied to visual observations."""
if image_transforms is not None and not callable(image_transforms): if image_transforms is not None and not callable(image_transforms):
@@ -329,6 +340,13 @@ class DatasetReader:
continue continue
item[cam] = self._image_transforms(item[cam]) item[cam] = self._image_transforms(item[cam])
# Convert depth features to the output unit.
for key, stored_unit in self._image_depth_units.items():
if key in item and stored_unit is not None and stored_unit != self._depth_output_unit:
item[key] = (
item[key] * MM_PER_METRE if stored_unit == DEPTH_METER_UNIT else item[key] / MM_PER_METRE
)
# Add task as a string # Add task as a string
task_idx = item["task_index"].item() task_idx = item["task_index"].item()
item["task"] = self._meta.tasks.iloc[task_idx].name item["task"] = self._meta.tasks.iloc[task_idx].name
+10
View File
@@ -36,6 +36,7 @@ from lerobot.configs import (
RGBEncoderConfig, RGBEncoderConfig,
VideoEncoderConfig, VideoEncoderConfig,
depth_encoder_defaults, depth_encoder_defaults,
infer_depth_unit,
rgb_encoder_defaults, rgb_encoder_defaults,
) )
@@ -209,6 +210,15 @@ class DatasetWriter:
self.episode_buffer["timestamp"].append(timestamp) self.episode_buffer["timestamp"].append(timestamp)
self.episode_buffer["task"].append(frame.pop("task")) self.episode_buffer["task"].append(frame.pop("task"))
# Record each depth feature's input unit once, inferred from the first frame's dtype.
if frame_index == 0:
for depth_key in self._meta.depth_keys:
if depth_key not in frame:
continue
info = self._meta.features[depth_key].setdefault("info", {})
if info.get("depth_unit") is None:
info["depth_unit"] = infer_depth_unit(np.asarray(frame[depth_key]).dtype)
# Start streaming encoder on first frame of episode # Start streaming encoder on first frame of episode
if frame_index == 0 and self._streaming_encoder is not None: if frame_index == 0 and self._streaming_encoder is not None:
self._streaming_encoder.start_episode( self._streaming_encoder.start_episode(
+8 -11
View File
@@ -34,12 +34,13 @@ from lerobot.configs.video import (
DEPTH_METER_UNIT, DEPTH_METER_UNIT,
DEPTH_MILLIMETER_UNIT, DEPTH_MILLIMETER_UNIT,
DEPTH_QMAX, DEPTH_QMAX,
infer_depth_unit,
) )
from .image_writer import squeeze_single_channel from .image_writer import squeeze_single_channel
from .pyav_utils import write_u16_plane from .pyav_utils import write_u16_plane
_MM_PER_METRE = 1000.0 MM_PER_METRE = 1000.0
_UINT16_MAX = 65535 _UINT16_MAX = 65535
@@ -57,11 +58,7 @@ def _depth_input_to_float32_and_unit(
input_unit: Literal["auto", DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT], input_unit: Literal["auto", DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT],
) -> tuple[NDArray[np.float32], Literal[DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT]]: ) -> tuple[NDArray[np.float32], Literal[DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT]]:
"""Convert depth to float32 in the chosen unit, and return the resolved unit.""" """Convert depth to float32 in the chosen unit, and return the resolved unit."""
resolved_unit = ( resolved_unit = infer_depth_unit(depth.dtype) if input_unit == "auto" else input_unit
(DEPTH_METER_UNIT if np.issubdtype(depth.dtype, np.floating) else DEPTH_MILLIMETER_UNIT)
if input_unit == "auto"
else input_unit
)
return depth.astype(np.float32, order="K"), resolved_unit return depth.astype(np.float32, order="K"), resolved_unit
@@ -126,12 +123,12 @@ def quantize_depth(
# Convert depth_min, depth_max, and shift to the resolved input unit. # Convert depth_min, depth_max, and shift to the resolved input unit.
depth_min_u = ( depth_min_u = (
np.float32(depth_min) if resolved_unit == DEPTH_METER_UNIT else np.float32(depth_min * _MM_PER_METRE) np.float32(depth_min) if resolved_unit == DEPTH_METER_UNIT else np.float32(depth_min * MM_PER_METRE)
) )
depth_max_u = ( depth_max_u = (
np.float32(depth_max) if resolved_unit == DEPTH_METER_UNIT else np.float32(depth_max * _MM_PER_METRE) np.float32(depth_max) if resolved_unit == DEPTH_METER_UNIT else np.float32(depth_max * MM_PER_METRE)
) )
shift_u = np.float32(shift) if resolved_unit == DEPTH_METER_UNIT else np.float32(shift * _MM_PER_METRE) shift_u = np.float32(shift) if resolved_unit == DEPTH_METER_UNIT else np.float32(shift * MM_PER_METRE)
# Normalization and quantization is performed in the resolved input unit. # Normalization and quantization is performed in the resolved input unit.
if use_log: if use_log:
@@ -236,7 +233,7 @@ def dequantize_depth(
# mm path: round + clamp in float32, skipping the uint16 round-trip # mm path: round + clamp in float32, skipping the uint16 round-trip
# when returning a tensor (torch.uint16 is poorly supported). # when returning a tensor (torch.uint16 is poorly supported).
buf.mul_(_MM_PER_METRE).round_().clamp_(0.0, _UINT16_MAX) buf.mul_(MM_PER_METRE).round_().clamp_(0.0, _UINT16_MAX)
if output_tensor: if output_tensor:
return buf return buf
return buf.cpu().numpy().astype(np.uint16, copy=False) return buf.cpu().numpy().astype(np.uint16, copy=False)
@@ -259,7 +256,7 @@ def dequantize_depth(
if output_unit == DEPTH_METER_UNIT: if output_unit == DEPTH_METER_UNIT:
return torch.from_numpy(buf) if output_tensor else buf return torch.from_numpy(buf) if output_tensor else buf
np.multiply(buf, _MM_PER_METRE, out=buf) np.multiply(buf, MM_PER_METRE, out=buf)
np.rint(buf, out=buf) np.rint(buf, out=buf)
np.clip(buf, 0.0, _UINT16_MAX, out=buf) np.clip(buf, 0.0, _UINT16_MAX, out=buf)
if output_tensor: if output_tensor:
+6
View File
@@ -224,6 +224,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
) )
self.root = self.meta.root self.root = self.meta.root
self.revision = self.meta.revision self.revision = self.meta.revision
self.meta.rescale_depth_stats(self._depth_output_unit)
if episodes is not None and any( if episodes is not None and any(
episode >= self.meta.total_episodes or episode < 0 for episode in episodes episode >= self.meta.total_episodes or episode < 0 for episode in episodes
@@ -350,6 +351,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""Frames per second used during data collection.""" """Frames per second used during data collection."""
return self.meta.fps return self.meta.fps
@property
def depth_output_unit(self) -> str:
"""Physical unit (``"m"`` or ``"mm"``) depth maps and statistics are returned in on read."""
return self._depth_output_unit
@property @property
def num_frames(self) -> int: def num_frames(self) -> int:
"""Number of frames in selected episodes.""" """Number of frames in selected episodes."""
+24 -2
View File
@@ -22,11 +22,11 @@ import numpy as np
import torch import torch
from datasets import load_dataset from datasets import load_dataset
from lerobot.configs import DEFAULT_DEPTH_UNIT, DepthEncoderConfig from lerobot.configs import DEFAULT_DEPTH_UNIT, DEPTH_METER_UNIT, DepthEncoderConfig
from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
from .depth_utils import dequantize_depth from .depth_utils import MM_PER_METRE, dequantize_depth
from .feature_utils import get_delta_indices from .feature_utils import get_delta_indices
from .io_utils import item_to_torch from .io_utils import item_to_torch
from .utils import ( from .utils import (
@@ -310,6 +310,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
) )
self.root = self.meta.root self.root = self.meta.root
self.revision = self.meta.revision self.revision = self.meta.revision
self.meta.rescale_depth_stats(self._depth_output_unit)
# Check version # Check version
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION) check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
@@ -318,6 +319,13 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
for vid_key in self.meta.depth_keys for vid_key in self.meta.depth_keys
} }
# Input unit of each depth feature stored as raw images (dequantized separately from videos).
self._image_depth_units: dict[str, str | None] = {
key: (self.meta.features[key].get("info") or {}).get("depth_unit")
for key in self.meta.depth_keys
if key in self.meta.image_keys
}
self.delta_timestamps = None self.delta_timestamps = None
self.delta_indices = None self.delta_indices = None
@@ -348,6 +356,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
def fps(self): def fps(self):
return self.meta.fps return self.meta.fps
@property
def depth_output_unit(self) -> str:
"""Physical unit (``"m"`` or ``"mm"``) depth maps are returned in on read."""
return self._depth_output_unit
@staticmethod @staticmethod
def _iter_random_indices( def _iter_random_indices(
rng: np.random.Generator, buffer_size: int, random_batch_size=100 rng: np.random.Generator, buffer_size: int, random_batch_size=100
@@ -530,6 +543,15 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
for update in updates: for update in updates:
result.update(update) result.update(update)
# Convert raw-image depth features to the output unit (video depth is already converted).
for key, stored_unit in self._image_depth_units.items():
if key in result and stored_unit is not None and stored_unit != self._depth_output_unit:
result[key] = (
result[key] * MM_PER_METRE
if stored_unit == DEPTH_METER_UNIT
else result[key] / MM_PER_METRE
)
result["task"] = self.meta.tasks.iloc[item["task_index"]].name result["task"] = self.meta.tasks.iloc[item["task_index"]].name
yield result yield result
@@ -84,6 +84,7 @@ import torch
import torch.utils.data import torch.utils.data
import tqdm import tqdm
from lerobot.configs import DEPTH_MILLIMETER_UNIT
from lerobot.datasets import LeRobotDataset from lerobot.datasets import LeRobotDataset
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD, SUCCESS from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD, SUCCESS
from lerobot.utils.utils import init_logging from lerobot.utils.utils import init_logging
@@ -228,6 +229,9 @@ def visualize_dataset(
logging.info("Logging to Rerun") logging.info("Logging to Rerun")
# Depth frames and stats are dequantized to the dataset's depth_output_unit on load.
depth_meter = 1000.0 if dataset.depth_output_unit == DEPTH_MILLIMETER_UNIT else 1.0
# Use the dataset's q01/q99 depth statistics for robust depth range bounds # Use the dataset's q01/q99 depth statistics for robust depth range bounds
depth_ranges = {} depth_ranges = {}
for key in dataset.meta.depth_keys: for key in dataset.meta.depth_keys:
@@ -254,6 +258,7 @@ def visualize_dataset(
depth = to_hwc_float32_numpy(batch[key][i]) depth = to_hwc_float32_numpy(batch[key][i])
depth_entity = rr.DepthImage( depth_entity = rr.DepthImage(
depth, depth,
meter=depth_meter,
colormap=rr.components.Colormap.Viridis, colormap=rr.components.Colormap.Viridis,
depth_range=depth_ranges.get(key), depth_range=depth_ranges.get(key),
) )
+8 -1
View File
@@ -24,6 +24,7 @@ import os
import numpy as np import numpy as np
from lerobot.configs import DEPTH_MILLIMETER_UNIT, infer_depth_unit
from lerobot.types import RobotAction, RobotObservation from lerobot.types import RobotAction, RobotObservation
from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR
@@ -161,7 +162,13 @@ def log_rerun_data(
observation_paths.add(key) observation_paths.add(key)
else: else:
if arr.shape[-1] == 1: if arr.shape[-1] == 1:
img_entity = rr.DepthImage(arr, colormap=rr.components.Colormap.Viridis) # At record time, the depth unit is inferred from the frame type.
depth_unit = infer_depth_unit(arr.dtype)
img_entity = rr.DepthImage(
arr,
meter=1000.0 if depth_unit == DEPTH_MILLIMETER_UNIT else 1.0,
colormap=rr.components.Colormap.Viridis,
)
else: else:
img_entity = rr.Image(arr).compress() if compress_images else rr.Image(arr) img_entity = rr.Image(arr).compress() if compress_images else rr.Image(arr)
rr.log(key, entity=img_entity, static=True) rr.log(key, entity=img_entity, static=True)
+89
View File
@@ -32,6 +32,7 @@ from lerobot.configs.video import (
) )
from lerobot.datasets.depth_utils import dequantize_depth, quantize_depth from lerobot.datasets.depth_utils import dequantize_depth, quantize_depth
from lerobot.datasets.image_writer import image_array_to_pil_image, write_image from lerobot.datasets.image_writer import image_array_to_pil_image, write_image
from lerobot.utils.constants import DEFAULT_FEATURES
from tests.fixtures.constants import ( from tests.fixtures.constants import (
DEFAULT_FPS, DEFAULT_FPS,
DUMMY_CAMERA_FEATURES, DUMMY_CAMERA_FEATURES,
@@ -245,3 +246,91 @@ class TestFeatureFileRouting:
dataset.save_episode() dataset.save_episode()
dataset.finalize() dataset.finalize()
class TestDepthUnitMetadata:
"""The depth unit is inferred once from dtype, stored in ``info``, and drives stats + reads."""
NUM_FRAMES = 4
def _record(self, root, features_factory, depth_dtype, value, use_videos):
from lerobot.datasets.lerobot_dataset import LeRobotDataset
features = features_factory(camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH, use_videos=use_videos)
dataset = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID,
fps=DEFAULT_FPS,
features=features,
root=root,
use_videos=use_videos,
streaming_encoding=use_videos,
)
for _ in range(self.NUM_FRAMES):
frame: dict = {"task": "test"}
for key, ft in dataset.meta.features.items():
if key in DEFAULT_FEATURES:
continue
if key in dataset.meta.depth_keys:
frame[key] = np.full(ft["shape"], value, dtype=depth_dtype)
elif key in dataset.meta.camera_keys:
frame[key] = np.random.randint(0, 256, ft["shape"], dtype=np.uint8)
else:
frame[key] = np.zeros(ft["shape"], dtype=np.float32)
dataset.add_frame(frame)
return dataset
@pytest.mark.parametrize("use_videos", [False, True])
@pytest.mark.parametrize(
("depth_dtype", "value", "expected_unit"),
[(np.float32, 2.0, DEPTH_METER_UNIT), (np.uint16, 2000, DEPTH_MILLIMETER_UNIT)],
)
def test_recorded_unit_inferred_persisted_and_kept_in_stats(
self, tmp_path, features_factory, use_videos, depth_dtype, value, expected_unit
):
"""Unit is inferred from the first frame's dtype, drives stats (raw, never canonicalized), and survives a reload."""
from lerobot.datasets.lerobot_dataset import LeRobotDataset
dataset = self._record(tmp_path / "ds", features_factory, depth_dtype, value, use_videos)
assert dataset.meta.features[DEPTH_KEY]["info"]["depth_unit"] == expected_unit
dataset.save_episode()
mean = float(np.asarray(dataset.meta.stats[DEPTH_KEY]["mean"]).reshape(-1)[0])
np.testing.assert_allclose(mean, value, rtol=0.05)
dataset.finalize()
reloaded = LeRobotDataset(repo_id=DUMMY_REPO_ID, root=tmp_path / "ds")
assert reloaded.meta.features[DEPTH_KEY]["info"]["depth_unit"] == expected_unit
@pytest.mark.parametrize("use_videos", [False, True])
@pytest.mark.parametrize(
("output_unit", "expected"),
[(DEPTH_MILLIMETER_UNIT, 2000.0), (DEPTH_METER_UNIT, 2.0)],
)
def test_read_honors_output_unit_for_frames_and_stats(
self, tmp_path, features_factory, use_videos, output_unit, expected
):
"""Reloading with a ``depth_output_unit`` converts metre frames (image mode) and rescales stats while preserving count."""
from lerobot.datasets.lerobot_dataset import LeRobotDataset
dataset = self._record(tmp_path / "ds", features_factory, np.float32, 2.0, use_videos=use_videos)
dataset.save_episode()
count = float(np.asarray(dataset.meta.stats[DEPTH_KEY]["count"]).reshape(-1)[0])
dataset.finalize()
read_dataset = LeRobotDataset(
repo_id=DUMMY_REPO_ID, root=tmp_path / "ds", depth_output_unit=output_unit
)
stats = read_dataset.meta.stats[DEPTH_KEY]
np.testing.assert_allclose(float(np.asarray(stats["mean"]).reshape(-1)[0]), expected, rtol=0.05)
np.testing.assert_allclose(float(np.asarray(stats["count"]).reshape(-1)[0]), count)
if not use_videos:
depth = read_dataset[0][DEPTH_KEY]
assert torch.allclose(depth, torch.full_like(depth, expected))
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
stream_dataset = StreamingLeRobotDataset(
repo_id=DUMMY_REPO_ID, root=tmp_path / "ds", depth_output_unit=output_unit
)
stream_depth = next(iter(stream_dataset))[DEPTH_KEY]
assert torch.allclose(stream_depth, torch.full_like(stream_depth, expected))
+8
View File
@@ -26,6 +26,7 @@ import pytest
import torch import torch
from datasets import Dataset from datasets import Dataset
from lerobot.configs.video import infer_depth_unit
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import get_hf_features_from_features from lerobot.datasets.feature_utils import get_hf_features_from_features
from lerobot.datasets.io_utils import flatten_dict, hf_transform_to_torch from lerobot.datasets.io_utils import flatten_dict, hf_transform_to_torch
@@ -535,6 +536,13 @@ def lerobot_dataset_factory(
chunks_size=chunks_size, chunks_size=chunks_size,
**info_kwargs, **info_kwargs,
) )
# This synthetic path skips add_frame, so record the depth unit the writer would
# have stored (dummy depth is uint16) to keep ``depth_unit`` present in info.json.
# Reassign a fresh info dict to avoid mutating the shared feature constants.
for ft in info.features.values():
ft_info = ft.get("info")
if ft_info is not None and ft_info.get("is_depth_map") and "depth_unit" not in ft_info:
ft["info"] = {**ft_info, "depth_unit": infer_depth_unit(np.uint16)}
if stats is None: if stats is None:
stats = stats_factory(features=info.features) stats = stats_factory(features=info.features)
if tasks is None: if tasks is None:
+2 -1
View File
@@ -50,8 +50,9 @@ def mock_rerun(monkeypatch):
return self return self
class DummyDepthImage: class DummyDepthImage:
def __init__(self, arr, colormap=None): def __init__(self, arr, meter=None, colormap=None):
self.arr = arr self.arr = arr
self.meter = meter
self.colormap = colormap self.colormap = colormap
def dummy_log(key, obj=None, **kwargs): def dummy_log(key, obj=None, **kwargs):