chore(infer_depth_unit): moving the depth unit inference utility in a more accessible location

This commit is contained in:
CarolinePascal
2026-07-01 19:38:30 +02:00
parent afa189fc72
commit bb066435bf
5 changed files with 16 additions and 10 deletions
+2
View File
@@ -43,6 +43,7 @@ from .video import (
VideoEncoderConfig,
depth_encoder_defaults,
encoder_config_from_video_info,
infer_depth_unit,
rgb_encoder_defaults,
)
@@ -72,6 +73,7 @@ __all__ = [
"depth_encoder_defaults",
# Factories
"encoder_config_from_video_info",
"infer_depth_unit",
# Constants
"DEFAULT_DEPTH_UNIT",
"DEPTH_METER_UNIT",
+11
View File
@@ -22,6 +22,8 @@ import logging
from dataclasses import dataclass, field
from typing import Any, ClassVar, Self
import numpy as np
from lerobot.utils.import_utils import require_package
logger = logging.getLogger(__name__)
@@ -65,6 +67,15 @@ DEPTH_METER_UNIT: str = "m"
DEPTH_MILLIMETER_UNIT: str = "mm"
DEFAULT_DEPTH_UNIT: str = DEPTH_MILLIMETER_UNIT
def infer_depth_unit(dtype: np.dtype | type) -> str:
"""Infer the physical unit of raw depth frames from their dtype.
Floating-point frames are assumed to be in metres, integer frames in millimetres.
"""
return DEPTH_METER_UNIT if np.issubdtype(np.dtype(dtype), np.floating) else DEPTH_MILLIMETER_UNIT
# Depth-specific tuning fields persisted under ``features[*]["info"]`` as ``video.<name>``.
DEPTH_ENCODER_INFO_FIELD_NAMES: frozenset[str] = frozenset({"depth_min", "depth_max", "shift", "use_log"})
+1 -1
View File
@@ -36,12 +36,12 @@ from lerobot.configs import (
RGBEncoderConfig,
VideoEncoderConfig,
depth_encoder_defaults,
infer_depth_unit,
rgb_encoder_defaults,
)
from .compute_stats import compute_episode_stats
from .dataset_metadata import LeRobotDatasetMetadata
from .depth_utils import infer_depth_unit
from .feature_utils import (
get_hf_features_from_features,
validate_episode_buffer,
+1 -8
View File
@@ -34,6 +34,7 @@ from lerobot.configs.video import (
DEPTH_METER_UNIT,
DEPTH_MILLIMETER_UNIT,
DEPTH_QMAX,
infer_depth_unit,
)
from .image_writer import squeeze_single_channel
@@ -43,14 +44,6 @@ MM_PER_METRE = 1000.0
_UINT16_MAX = 65535
def infer_depth_unit(dtype: np.dtype | type) -> str:
"""Infer the physical unit of raw depth frames from their dtype.
Floating-point frames are assumed to be in metres, integer frames in millimetres.
"""
return DEPTH_METER_UNIT if np.issubdtype(np.dtype(dtype), np.floating) else DEPTH_MILLIMETER_UNIT
def _validate_log_quant_params(depth_min: float, shift: float) -> None:
"""Ensure ``log(depth_min + shift)`` is finite."""
if depth_min + shift <= 0:
+1 -1
View File
@@ -26,8 +26,8 @@ import pytest
import torch
from datasets import Dataset
from lerobot.configs.video import infer_depth_unit
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.datasets.depth_utils import infer_depth_unit
from lerobot.datasets.feature_utils import get_hf_features_from_features
from lerobot.datasets.io_utils import flatten_dict, hf_transform_to_torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset