feat(units): adding constants for depth frames units (m and mm)

This commit is contained in:
CarolinePascal
2026-06-23 18:39:44 +02:00
parent 235356730e
commit eec82264ef
4 changed files with 48 additions and 25 deletions
+5 -3
View File
@@ -19,7 +19,7 @@ from dataclasses import dataclass, field
from lerobot.transforms import ImageTransformsConfig
from lerobot.utils.import_utils import get_safe_default_video_backend
from .video import DEFAULT_DEPTH_UNIT
from .video import DEFAULT_DEPTH_UNIT, DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT
@dataclass
@@ -46,8 +46,10 @@ class DatasetConfig:
streaming: bool = False
def __post_init__(self) -> None:
if self.depth_output_unit not in ("m", "mm"):
raise ValueError(f"depth_output_unit must be 'm' or 'mm', got {self.depth_output_unit!r}")
if self.depth_output_unit not in (DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT):
raise ValueError(
f"depth_output_unit must be '{DEPTH_METER_UNIT}' or '{DEPTH_MILLIMETER_UNIT}', got {self.depth_output_unit!r}"
)
if self.episodes is not None:
if any(ep < 0 for ep in self.episodes):
raise ValueError(
+4 -1
View File
@@ -62,7 +62,10 @@ DEFAULT_DEPTH_MAX: float = 10.0
DEFAULT_DEPTH_SHIFT: float = 3.5
DEFAULT_DEPTH_USE_LOG: bool = True
DEFAULT_DEPTH_PIX_FMT: str = "gray12le"
DEFAULT_DEPTH_UNIT = "mm"
DEPTH_METER_UNIT: str = "m"
DEPTH_MILLIMETER_UNIT: str = "mm"
DEFAULT_DEPTH_UNIT: str = DEPTH_MILLIMETER_UNIT
# Depth-specific tuning fields persisted under ``features[*]["info"]`` as ``video.<name>``.
DEPTH_ENCODER_INFO_FIELD_NAMES: frozenset[str] = frozenset({"depth_min", "depth_max", "shift", "use_log"})
+26 -14
View File
@@ -31,6 +31,8 @@ from lerobot.configs.video import (
DEFAULT_DEPTH_PIX_FMT,
DEFAULT_DEPTH_SHIFT,
DEFAULT_DEPTH_USE_LOG,
DEPTH_METER_UNIT,
DEPTH_MILLIMETER_UNIT,
DEPTH_QMAX,
)
@@ -51,11 +53,13 @@ def _validate_log_quant_params(depth_min: float, shift: float) -> None:
def _depth_input_to_float32_and_unit(
depth: NDArray[np.integer] | NDArray[np.floating],
input_unit: Literal["auto", "m", "mm"],
) -> tuple[NDArray[np.float32], Literal["m", "mm"]]:
input_unit: Literal["auto", DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT],
) -> tuple[NDArray[np.float32], Literal[DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT]]:
"""Convert depth to float32 in the chosen unit, and return the resolved unit."""
resolved_unit = (
("m" if np.issubdtype(depth.dtype, np.floating) else "mm") if input_unit == "auto" else input_unit
(DEPTH_METER_UNIT if np.issubdtype(depth.dtype, np.floating) else DEPTH_MILLIMETER_UNIT)
if input_unit == "auto"
else input_unit
)
return depth.astype(np.float32, order="K"), resolved_unit
@@ -68,7 +72,7 @@ def quantize_depth(
use_log: bool = DEFAULT_DEPTH_USE_LOG,
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
video_backend: str | None = "pyav",
input_unit: Literal["auto", "m", "mm"] = "auto",
input_unit: Literal["auto", DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT] = "auto",
) -> NDArray[np.uint16] | av.VideoFrame:
"""Quantize depth to 12-bit codes (``uint16``, values ``0…DEPTH_QMAX``).
@@ -106,8 +110,10 @@ def quantize_depth(
ValueError: If ``input_unit`` is not ``"auto"``, ``"mm"``, or ``"m"``.
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
"""
if input_unit not in ("auto", "m", "mm"):
raise ValueError(f"input_unit must be 'auto', 'm', or 'mm', got {input_unit!r}")
if input_unit not in ("auto", DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT):
raise ValueError(
f"input_unit must be 'auto', '{DEPTH_METER_UNIT}', or '{DEPTH_MILLIMETER_UNIT}', got {input_unit!r}"
)
if isinstance(depth, torch.Tensor):
depth = depth.detach().cpu().numpy()
@@ -119,9 +125,13 @@ def quantize_depth(
depth_f, resolved_unit = _depth_input_to_float32_and_unit(depth, input_unit=input_unit)
# Convert depth_min, depth_max, and shift to the resolved input unit.
depth_min_u = np.float32(depth_min) if resolved_unit == "m" else np.float32(depth_min * _MM_PER_METRE)
depth_max_u = np.float32(depth_max) if resolved_unit == "m" else np.float32(depth_max * _MM_PER_METRE)
shift_u = np.float32(shift) if resolved_unit == "m" else np.float32(shift * _MM_PER_METRE)
depth_min_u = (
np.float32(depth_min) if resolved_unit == DEPTH_METER_UNIT else np.float32(depth_min * _MM_PER_METRE)
)
depth_max_u = (
np.float32(depth_max) if resolved_unit == DEPTH_METER_UNIT else np.float32(depth_max * _MM_PER_METRE)
)
shift_u = np.float32(shift) if resolved_unit == DEPTH_METER_UNIT else np.float32(shift * _MM_PER_METRE)
# Normalization and quantization is performed in the resolved input unit.
if use_log:
@@ -149,7 +159,7 @@ def dequantize_depth(
shift: float = DEFAULT_DEPTH_SHIFT,
use_log: bool = DEFAULT_DEPTH_USE_LOG,
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
output_unit: Literal["m", "mm"] = "mm",
output_unit: Literal[DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT] = DEPTH_MILLIMETER_UNIT,
output_tensor: bool = True,
output_channel_last: bool = False,
) -> NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor:
@@ -184,8 +194,10 @@ def dequantize_depth(
ValueError: If ``output_unit`` is not ``"m"`` or ``"mm"``.
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
"""
if output_unit not in ("m", "mm"):
raise ValueError(f"output_unit must be 'm' or 'mm', got {output_unit!r}")
if output_unit not in (DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT):
raise ValueError(
f"output_unit must be '{DEPTH_METER_UNIT}' or '{DEPTH_MILLIMETER_UNIT}', got {output_unit!r}"
)
if use_log:
_validate_log_quant_params(depth_min, shift)
@@ -219,7 +231,7 @@ def dequantize_depth(
buf.clamp_(depth_min_m, depth_max_m)
buf.unsqueeze_(-1) if output_channel_last else buf.unsqueeze_(-3)
if output_unit == "m":
if output_unit == DEPTH_METER_UNIT:
return buf if output_tensor else buf.cpu().numpy()
# mm path: round + clamp in float32, skipping the uint16 round-trip
@@ -244,7 +256,7 @@ def dequantize_depth(
np.clip(buf, depth_min_m, depth_max_m, out=buf)
buf = np.expand_dims(buf, axis=-1) if output_channel_last else np.expand_dims(buf, axis=-3)
if output_unit == "m":
if output_unit == DEPTH_METER_UNIT:
return torch.from_numpy(buf) if output_tensor else buf
np.multiply(buf, _MM_PER_METRE, out=buf)
+13 -7
View File
@@ -23,7 +23,13 @@ import PIL.Image
import torch
from lerobot.configs import DepthEncoderConfig
from lerobot.configs.video import DEFAULT_DEPTH_MAX, DEFAULT_DEPTH_MIN, DEPTH_QMAX
from lerobot.configs.video import (
DEFAULT_DEPTH_MAX,
DEFAULT_DEPTH_MIN,
DEPTH_METER_UNIT,
DEPTH_MILLIMETER_UNIT,
DEPTH_QMAX,
)
from lerobot.datasets.depth_utils import dequantize_depth, quantize_depth
from lerobot.datasets.image_writer import image_array_to_pil_image, write_image
from tests.fixtures.constants import (
@@ -51,7 +57,7 @@ class TestQuantizeDequantize:
"""Numerical contract of ``quantize_depth`` / ``dequantize_depth``."""
@pytest.mark.parametrize("use_log", [False, True])
@pytest.mark.parametrize("output_unit", ["m", "mm"])
@pytest.mark.parametrize("output_unit", [DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT])
@pytest.mark.parametrize("output_channel_last", [False, True])
def test_roundtrip(self, use_log, output_unit, output_channel_last):
"""quantize → dequantize recovers depth; layout and unit are honored."""
@@ -69,7 +75,7 @@ class TestQuantizeDequantize:
assert recovered.shape == expected_shape
recovered_m = recovered.astype(np.float32)
if output_unit == "mm":
if output_unit == DEPTH_MILLIMETER_UNIT:
recovered_m = recovered_m / 1000.0
recovered_2d = recovered_m[..., 0] if output_channel_last else recovered_m[0]
@@ -86,7 +92,7 @@ class TestQuantizeDequantize:
np.testing.assert_allclose(recovered_2d, depth, atol=tol)
@pytest.mark.parametrize("use_log", [False, True])
@pytest.mark.parametrize("output_unit", ["m", "mm"])
@pytest.mark.parametrize("output_unit", [DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT])
def test_numpy_torch_agree(self, use_log, output_unit):
"""Batched torch path produces the same values as the numpy path."""
batch_size = 3
@@ -101,7 +107,7 @@ class TestQuantizeDequantize:
assert out.shape == (batch_size, 1, H, W)
# ``m``: float32 noise (~10 µm in log mode, after ``exp``) — still 200× below the ~2 mm quant step.
# ``mm`` + tensor stays in float32 (no uint16 round-trip), so allow 1 mm slop.
atol = 1e-5 if output_unit == "m" else 1.0
atol = 1e-5 if output_unit == DEPTH_METER_UNIT else 1.0
np.testing.assert_allclose(out.cpu().numpy().astype(np.float64), ref.astype(np.float64), atol=atol)
@pytest.mark.parametrize(
@@ -117,7 +123,7 @@ class TestQuantizeDequantize:
def test_input_layouts_accepted(self, input_shape, output_shape):
"""All documented input layouts decode to the channel-first default."""
quantized = np.full(input_shape, DEPTH_QMAX // 2, dtype=np.uint16)
out = dequantize_depth(quantized, output_unit="m", output_tensor=False)
out = dequantize_depth(quantized, output_unit=DEPTH_METER_UNIT, output_tensor=False)
assert out.shape == output_shape
def test_pyav_frame_roundtrip(self):
@@ -126,7 +132,7 @@ class TestQuantizeDequantize:
frame = quantize_depth(depth, use_log=False, video_backend="pyav")
assert isinstance(frame, av.VideoFrame)
recovered = dequantize_depth(frame, use_log=False, output_unit="m", output_tensor=False)
recovered = dequantize_depth(frame, use_log=False, output_unit=DEPTH_METER_UNIT, output_tensor=False)
assert recovered.shape == (1, H, W)
tol = (DEFAULT_DEPTH_MAX - DEFAULT_DEPTH_MIN) / DEPTH_QMAX + 1e-3
np.testing.assert_allclose(recovered[0], depth, atol=tol)