From eec82264ef34dc44dd18dfb3f413ef55c756d017 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Tue, 23 Jun 2026 18:39:44 +0200 Subject: [PATCH] feat(units): adding constants for depth frames units (m and mm) --- src/lerobot/configs/default.py | 8 +++--- src/lerobot/configs/video.py | 5 +++- src/lerobot/datasets/depth_utils.py | 40 +++++++++++++++++++---------- tests/datasets/test_depth.py | 20 ++++++++++----- 4 files changed, 48 insertions(+), 25 deletions(-) diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py index d40ec0b31..6b0273fa7 100644 --- a/src/lerobot/configs/default.py +++ b/src/lerobot/configs/default.py @@ -19,7 +19,7 @@ from dataclasses import dataclass, field from lerobot.transforms import ImageTransformsConfig from lerobot.utils.import_utils import get_safe_default_video_backend -from .video import DEFAULT_DEPTH_UNIT +from .video import DEFAULT_DEPTH_UNIT, DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT @dataclass @@ -46,8 +46,10 @@ class DatasetConfig: streaming: bool = False def __post_init__(self) -> None: - if self.depth_output_unit not in ("m", "mm"): - raise ValueError(f"depth_output_unit must be 'm' or 'mm', got {self.depth_output_unit!r}") + if self.depth_output_unit not in (DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT): + raise ValueError( + f"depth_output_unit must be '{DEPTH_METER_UNIT}' or '{DEPTH_MILLIMETER_UNIT}', got {self.depth_output_unit!r}" + ) if self.episodes is not None: if any(ep < 0 for ep in self.episodes): raise ValueError( diff --git a/src/lerobot/configs/video.py b/src/lerobot/configs/video.py index 7542420f4..c393ad036 100644 --- a/src/lerobot/configs/video.py +++ b/src/lerobot/configs/video.py @@ -62,7 +62,10 @@ DEFAULT_DEPTH_MAX: float = 10.0 DEFAULT_DEPTH_SHIFT: float = 3.5 DEFAULT_DEPTH_USE_LOG: bool = True DEFAULT_DEPTH_PIX_FMT: str = "gray12le" -DEFAULT_DEPTH_UNIT = "mm" + +DEPTH_METER_UNIT: str = "m" +DEPTH_MILLIMETER_UNIT: str = "mm" +DEFAULT_DEPTH_UNIT: str = DEPTH_MILLIMETER_UNIT # Depth-specific tuning fields persisted under ``features[*]["info"]`` as ``video.``. DEPTH_ENCODER_INFO_FIELD_NAMES: frozenset[str] = frozenset({"depth_min", "depth_max", "shift", "use_log"}) diff --git a/src/lerobot/datasets/depth_utils.py b/src/lerobot/datasets/depth_utils.py index fbfe1d980..e3ab32982 100644 --- a/src/lerobot/datasets/depth_utils.py +++ b/src/lerobot/datasets/depth_utils.py @@ -31,6 +31,8 @@ from lerobot.configs.video import ( DEFAULT_DEPTH_PIX_FMT, DEFAULT_DEPTH_SHIFT, DEFAULT_DEPTH_USE_LOG, + DEPTH_METER_UNIT, + DEPTH_MILLIMETER_UNIT, DEPTH_QMAX, ) @@ -51,11 +53,13 @@ def _validate_log_quant_params(depth_min: float, shift: float) -> None: def _depth_input_to_float32_and_unit( depth: NDArray[np.integer] | NDArray[np.floating], - input_unit: Literal["auto", "m", "mm"], -) -> tuple[NDArray[np.float32], Literal["m", "mm"]]: + input_unit: Literal["auto", DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT], +) -> tuple[NDArray[np.float32], Literal[DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT]]: """Convert depth to float32 in the chosen unit, and return the resolved unit.""" resolved_unit = ( - ("m" if np.issubdtype(depth.dtype, np.floating) else "mm") if input_unit == "auto" else input_unit + (DEPTH_METER_UNIT if np.issubdtype(depth.dtype, np.floating) else DEPTH_MILLIMETER_UNIT) + if input_unit == "auto" + else input_unit ) return depth.astype(np.float32, order="K"), resolved_unit @@ -68,7 +72,7 @@ def quantize_depth( use_log: bool = DEFAULT_DEPTH_USE_LOG, pix_fmt: str = DEFAULT_DEPTH_PIX_FMT, video_backend: str | None = "pyav", - input_unit: Literal["auto", "m", "mm"] = "auto", + input_unit: Literal["auto", DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT] = "auto", ) -> NDArray[np.uint16] | av.VideoFrame: """Quantize depth to 12-bit codes (``uint16``, values ``0…DEPTH_QMAX``). @@ -106,8 +110,10 @@ def quantize_depth( ValueError: If ``input_unit`` is not ``"auto"``, ``"mm"``, or ``"m"``. ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``. """ - if input_unit not in ("auto", "m", "mm"): - raise ValueError(f"input_unit must be 'auto', 'm', or 'mm', got {input_unit!r}") + if input_unit not in ("auto", DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT): + raise ValueError( + f"input_unit must be 'auto', '{DEPTH_METER_UNIT}', or '{DEPTH_MILLIMETER_UNIT}', got {input_unit!r}" + ) if isinstance(depth, torch.Tensor): depth = depth.detach().cpu().numpy() @@ -119,9 +125,13 @@ def quantize_depth( depth_f, resolved_unit = _depth_input_to_float32_and_unit(depth, input_unit=input_unit) # Convert depth_min, depth_max, and shift to the resolved input unit. - depth_min_u = np.float32(depth_min) if resolved_unit == "m" else np.float32(depth_min * _MM_PER_METRE) - depth_max_u = np.float32(depth_max) if resolved_unit == "m" else np.float32(depth_max * _MM_PER_METRE) - shift_u = np.float32(shift) if resolved_unit == "m" else np.float32(shift * _MM_PER_METRE) + depth_min_u = ( + np.float32(depth_min) if resolved_unit == DEPTH_METER_UNIT else np.float32(depth_min * _MM_PER_METRE) + ) + depth_max_u = ( + np.float32(depth_max) if resolved_unit == DEPTH_METER_UNIT else np.float32(depth_max * _MM_PER_METRE) + ) + shift_u = np.float32(shift) if resolved_unit == DEPTH_METER_UNIT else np.float32(shift * _MM_PER_METRE) # Normalization and quantization is performed in the resolved input unit. if use_log: @@ -149,7 +159,7 @@ def dequantize_depth( shift: float = DEFAULT_DEPTH_SHIFT, use_log: bool = DEFAULT_DEPTH_USE_LOG, pix_fmt: str = DEFAULT_DEPTH_PIX_FMT, - output_unit: Literal["m", "mm"] = "mm", + output_unit: Literal[DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT] = DEPTH_MILLIMETER_UNIT, output_tensor: bool = True, output_channel_last: bool = False, ) -> NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor: @@ -184,8 +194,10 @@ def dequantize_depth( ValueError: If ``output_unit`` is not ``"m"`` or ``"mm"``. ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``. """ - if output_unit not in ("m", "mm"): - raise ValueError(f"output_unit must be 'm' or 'mm', got {output_unit!r}") + if output_unit not in (DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT): + raise ValueError( + f"output_unit must be '{DEPTH_METER_UNIT}' or '{DEPTH_MILLIMETER_UNIT}', got {output_unit!r}" + ) if use_log: _validate_log_quant_params(depth_min, shift) @@ -219,7 +231,7 @@ def dequantize_depth( buf.clamp_(depth_min_m, depth_max_m) buf.unsqueeze_(-1) if output_channel_last else buf.unsqueeze_(-3) - if output_unit == "m": + if output_unit == DEPTH_METER_UNIT: return buf if output_tensor else buf.cpu().numpy() # mm path: round + clamp in float32, skipping the uint16 round-trip @@ -244,7 +256,7 @@ def dequantize_depth( np.clip(buf, depth_min_m, depth_max_m, out=buf) buf = np.expand_dims(buf, axis=-1) if output_channel_last else np.expand_dims(buf, axis=-3) - if output_unit == "m": + if output_unit == DEPTH_METER_UNIT: return torch.from_numpy(buf) if output_tensor else buf np.multiply(buf, _MM_PER_METRE, out=buf) diff --git a/tests/datasets/test_depth.py b/tests/datasets/test_depth.py index 1760b2e63..a075fa6b5 100644 --- a/tests/datasets/test_depth.py +++ b/tests/datasets/test_depth.py @@ -23,7 +23,13 @@ import PIL.Image import torch from lerobot.configs import DepthEncoderConfig -from lerobot.configs.video import DEFAULT_DEPTH_MAX, DEFAULT_DEPTH_MIN, DEPTH_QMAX +from lerobot.configs.video import ( + DEFAULT_DEPTH_MAX, + DEFAULT_DEPTH_MIN, + DEPTH_METER_UNIT, + DEPTH_MILLIMETER_UNIT, + DEPTH_QMAX, +) from lerobot.datasets.depth_utils import dequantize_depth, quantize_depth from lerobot.datasets.image_writer import image_array_to_pil_image, write_image from tests.fixtures.constants import ( @@ -51,7 +57,7 @@ class TestQuantizeDequantize: """Numerical contract of ``quantize_depth`` / ``dequantize_depth``.""" @pytest.mark.parametrize("use_log", [False, True]) - @pytest.mark.parametrize("output_unit", ["m", "mm"]) + @pytest.mark.parametrize("output_unit", [DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT]) @pytest.mark.parametrize("output_channel_last", [False, True]) def test_roundtrip(self, use_log, output_unit, output_channel_last): """quantize → dequantize recovers depth; layout and unit are honored.""" @@ -69,7 +75,7 @@ class TestQuantizeDequantize: assert recovered.shape == expected_shape recovered_m = recovered.astype(np.float32) - if output_unit == "mm": + if output_unit == DEPTH_MILLIMETER_UNIT: recovered_m = recovered_m / 1000.0 recovered_2d = recovered_m[..., 0] if output_channel_last else recovered_m[0] @@ -86,7 +92,7 @@ class TestQuantizeDequantize: np.testing.assert_allclose(recovered_2d, depth, atol=tol) @pytest.mark.parametrize("use_log", [False, True]) - @pytest.mark.parametrize("output_unit", ["m", "mm"]) + @pytest.mark.parametrize("output_unit", [DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT]) def test_numpy_torch_agree(self, use_log, output_unit): """Batched torch path produces the same values as the numpy path.""" batch_size = 3 @@ -101,7 +107,7 @@ class TestQuantizeDequantize: assert out.shape == (batch_size, 1, H, W) # ``m``: float32 noise (~10 µm in log mode, after ``exp``) — still 200× below the ~2 mm quant step. # ``mm`` + tensor stays in float32 (no uint16 round-trip), so allow 1 mm slop. - atol = 1e-5 if output_unit == "m" else 1.0 + atol = 1e-5 if output_unit == DEPTH_METER_UNIT else 1.0 np.testing.assert_allclose(out.cpu().numpy().astype(np.float64), ref.astype(np.float64), atol=atol) @pytest.mark.parametrize( @@ -117,7 +123,7 @@ class TestQuantizeDequantize: def test_input_layouts_accepted(self, input_shape, output_shape): """All documented input layouts decode to the channel-first default.""" quantized = np.full(input_shape, DEPTH_QMAX // 2, dtype=np.uint16) - out = dequantize_depth(quantized, output_unit="m", output_tensor=False) + out = dequantize_depth(quantized, output_unit=DEPTH_METER_UNIT, output_tensor=False) assert out.shape == output_shape def test_pyav_frame_roundtrip(self): @@ -126,7 +132,7 @@ class TestQuantizeDequantize: frame = quantize_depth(depth, use_log=False, video_backend="pyav") assert isinstance(frame, av.VideoFrame) - recovered = dequantize_depth(frame, use_log=False, output_unit="m", output_tensor=False) + recovered = dequantize_depth(frame, use_log=False, output_unit=DEPTH_METER_UNIT, output_tensor=False) assert recovered.shape == (1, H, W) tol = (DEFAULT_DEPTH_MAX - DEFAULT_DEPTH_MIN) / DEPTH_QMAX + 1e-3 np.testing.assert_allclose(recovered[0], depth, atol=tol)