From bb066435bf0956c4fa90fea638a88346093efd57 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Wed, 1 Jul 2026 19:38:30 +0200 Subject: [PATCH] chore(infer_depth_unit): moving the depth unit inference utility to a more accessible location --- src/lerobot/configs/__init__.py | 2 ++ src/lerobot/configs/video.py | 11 +++++++++++ src/lerobot/datasets/dataset_writer.py | 2 +- src/lerobot/datasets/depth_utils.py | 9 +-------- tests/fixtures/dataset_factories.py | 2 +- 5 files changed, 16 insertions(+), 10 deletions(-) diff --git a/src/lerobot/configs/__init__.py b/src/lerobot/configs/__init__.py index 20f81fb18..c32e3368b 100644 --- a/src/lerobot/configs/__init__.py +++ b/src/lerobot/configs/__init__.py @@ -43,6 +43,7 @@ from .video import ( VideoEncoderConfig, depth_encoder_defaults, encoder_config_from_video_info, + infer_depth_unit, rgb_encoder_defaults, ) @@ -72,6 +73,7 @@ __all__ = [ "depth_encoder_defaults", # Factories "encoder_config_from_video_info", + "infer_depth_unit", # Constants "DEFAULT_DEPTH_UNIT", "DEPTH_METER_UNIT", diff --git a/src/lerobot/configs/video.py b/src/lerobot/configs/video.py index 3ea834508..7b76e0449 100644 --- a/src/lerobot/configs/video.py +++ b/src/lerobot/configs/video.py @@ -22,6 +22,8 @@ import logging from dataclasses import dataclass, field from typing import Any, ClassVar, Self +import numpy as np + from lerobot.utils.import_utils import require_package logger = logging.getLogger(__name__) @@ -65,6 +67,15 @@ DEPTH_METER_UNIT: str = "m" DEPTH_MILLIMETER_UNIT: str = "mm" DEFAULT_DEPTH_UNIT: str = DEPTH_MILLIMETER_UNIT + +def infer_depth_unit(dtype: np.dtype | type) -> str: + """Infer the physical unit of raw depth frames from their dtype. + + Floating-point frames are assumed to be in metres, integer frames in millimetres. + """ + return DEPTH_METER_UNIT if np.issubdtype(np.dtype(dtype), np.floating) else DEPTH_MILLIMETER_UNIT + + # Depth-specific tuning fields persisted under ``features[*]["info"]`` as ``video.``. 
DEPTH_ENCODER_INFO_FIELD_NAMES: frozenset[str] = frozenset({"depth_min", "depth_max", "shift", "use_log"}) diff --git a/src/lerobot/datasets/dataset_writer.py b/src/lerobot/datasets/dataset_writer.py index f8bf0eddb..a6049312f 100644 --- a/src/lerobot/datasets/dataset_writer.py +++ b/src/lerobot/datasets/dataset_writer.py @@ -36,12 +36,12 @@ from lerobot.configs import ( RGBEncoderConfig, VideoEncoderConfig, depth_encoder_defaults, + infer_depth_unit, rgb_encoder_defaults, ) from .compute_stats import compute_episode_stats from .dataset_metadata import LeRobotDatasetMetadata -from .depth_utils import infer_depth_unit from .feature_utils import ( get_hf_features_from_features, validate_episode_buffer, diff --git a/src/lerobot/datasets/depth_utils.py b/src/lerobot/datasets/depth_utils.py index 04aa9a54b..a4e187eb4 100644 --- a/src/lerobot/datasets/depth_utils.py +++ b/src/lerobot/datasets/depth_utils.py @@ -34,6 +34,7 @@ from lerobot.configs.video import ( DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT, DEPTH_QMAX, + infer_depth_unit, ) from .image_writer import squeeze_single_channel @@ -43,14 +44,6 @@ MM_PER_METRE = 1000.0 _UINT16_MAX = 65535 -def infer_depth_unit(dtype: np.dtype | type) -> str: - """Infer the physical unit of raw depth frames from their dtype. - - Floating-point frames are assumed to be in metres, integer frames in millimetres. 
- """ - return DEPTH_METER_UNIT if np.issubdtype(np.dtype(dtype), np.floating) else DEPTH_MILLIMETER_UNIT - - def _validate_log_quant_params(depth_min: float, shift: float) -> None: """Ensure ``log(depth_min + shift)`` is finite.""" if depth_min + shift <= 0: diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py index 4a55a362a..5c0b0f524 100644 --- a/tests/fixtures/dataset_factories.py +++ b/tests/fixtures/dataset_factories.py @@ -26,8 +26,8 @@ import pytest import torch from datasets import Dataset +from lerobot.configs.video import infer_depth_unit from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata -from lerobot.datasets.depth_utils import infer_depth_unit from lerobot.datasets.feature_utils import get_hf_features_from_features from lerobot.datasets.io_utils import flatten_dict, hf_transform_to_torch from lerobot.datasets.lerobot_dataset import LeRobotDataset