feat(units): adding constants for depth frames units (m and mm)

This commit is contained in:
CarolinePascal
2026-06-23 18:39:44 +02:00
parent 235356730e
commit eec82264ef
4 changed files with 48 additions and 25 deletions
+13 -7
View File
@@ -23,7 +23,13 @@ import PIL.Image
import torch
from lerobot.configs import DepthEncoderConfig
from lerobot.configs.video import DEFAULT_DEPTH_MAX, DEFAULT_DEPTH_MIN, DEPTH_QMAX
from lerobot.configs.video import (
DEFAULT_DEPTH_MAX,
DEFAULT_DEPTH_MIN,
DEPTH_METER_UNIT,
DEPTH_MILLIMETER_UNIT,
DEPTH_QMAX,
)
from lerobot.datasets.depth_utils import dequantize_depth, quantize_depth
from lerobot.datasets.image_writer import image_array_to_pil_image, write_image
from tests.fixtures.constants import (
@@ -51,7 +57,7 @@ class TestQuantizeDequantize:
"""Numerical contract of ``quantize_depth`` / ``dequantize_depth``."""
@pytest.mark.parametrize("use_log", [False, True])
@pytest.mark.parametrize("output_unit", ["m", "mm"])
@pytest.mark.parametrize("output_unit", [DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT])
@pytest.mark.parametrize("output_channel_last", [False, True])
def test_roundtrip(self, use_log, output_unit, output_channel_last):
"""quantize → dequantize recovers depth; layout and unit are honored."""
@@ -69,7 +75,7 @@ class TestQuantizeDequantize:
assert recovered.shape == expected_shape
recovered_m = recovered.astype(np.float32)
if output_unit == "mm":
if output_unit == DEPTH_MILLIMETER_UNIT:
recovered_m = recovered_m / 1000.0
recovered_2d = recovered_m[..., 0] if output_channel_last else recovered_m[0]
@@ -86,7 +92,7 @@ class TestQuantizeDequantize:
np.testing.assert_allclose(recovered_2d, depth, atol=tol)
@pytest.mark.parametrize("use_log", [False, True])
@pytest.mark.parametrize("output_unit", ["m", "mm"])
@pytest.mark.parametrize("output_unit", [DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT])
def test_numpy_torch_agree(self, use_log, output_unit):
"""Batched torch path produces the same values as the numpy path."""
batch_size = 3
@@ -101,7 +107,7 @@ class TestQuantizeDequantize:
assert out.shape == (batch_size, 1, H, W)
# ``m``: float32 noise (~10 µm in log mode, after ``exp``) — still 200× below the ~2 mm quant step.
# ``mm`` + tensor stays in float32 (no uint16 round-trip), so allow 1 mm slop.
atol = 1e-5 if output_unit == "m" else 1.0
atol = 1e-5 if output_unit == DEPTH_METER_UNIT else 1.0
np.testing.assert_allclose(out.cpu().numpy().astype(np.float64), ref.astype(np.float64), atol=atol)
@pytest.mark.parametrize(
@@ -117,7 +123,7 @@ class TestQuantizeDequantize:
def test_input_layouts_accepted(self, input_shape, output_shape):
"""All documented input layouts decode to the channel-first default."""
quantized = np.full(input_shape, DEPTH_QMAX // 2, dtype=np.uint16)
out = dequantize_depth(quantized, output_unit="m", output_tensor=False)
out = dequantize_depth(quantized, output_unit=DEPTH_METER_UNIT, output_tensor=False)
assert out.shape == output_shape
def test_pyav_frame_roundtrip(self):
@@ -126,7 +132,7 @@ class TestQuantizeDequantize:
frame = quantize_depth(depth, use_log=False, video_backend="pyav")
assert isinstance(frame, av.VideoFrame)
recovered = dequantize_depth(frame, use_log=False, output_unit="m", output_tensor=False)
recovered = dequantize_depth(frame, use_log=False, output_unit=DEPTH_METER_UNIT, output_tensor=False)
assert recovered.shape == (1, H, W)
tol = (DEFAULT_DEPTH_MAX - DEFAULT_DEPTH_MIN) / DEPTH_QMAX + 1e-3
np.testing.assert_allclose(recovered[0], depth, atol=tol)