mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-28 21:57:27 +00:00
feat(units): adding constants for depth frames units (m and mm)
This commit is contained in:
@@ -23,7 +23,13 @@ import PIL.Image
|
||||
import torch
|
||||
|
||||
from lerobot.configs import DepthEncoderConfig
|
||||
from lerobot.configs.video import DEFAULT_DEPTH_MAX, DEFAULT_DEPTH_MIN, DEPTH_QMAX
|
||||
from lerobot.configs.video import (
|
||||
DEFAULT_DEPTH_MAX,
|
||||
DEFAULT_DEPTH_MIN,
|
||||
DEPTH_METER_UNIT,
|
||||
DEPTH_MILLIMETER_UNIT,
|
||||
DEPTH_QMAX,
|
||||
)
|
||||
from lerobot.datasets.depth_utils import dequantize_depth, quantize_depth
|
||||
from lerobot.datasets.image_writer import image_array_to_pil_image, write_image
|
||||
from tests.fixtures.constants import (
|
||||
@@ -51,7 +57,7 @@ class TestQuantizeDequantize:
|
||||
"""Numerical contract of ``quantize_depth`` / ``dequantize_depth``."""
|
||||
|
||||
@pytest.mark.parametrize("use_log", [False, True])
|
||||
@pytest.mark.parametrize("output_unit", ["m", "mm"])
|
||||
@pytest.mark.parametrize("output_unit", [DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT])
|
||||
@pytest.mark.parametrize("output_channel_last", [False, True])
|
||||
def test_roundtrip(self, use_log, output_unit, output_channel_last):
|
||||
"""quantize → dequantize recovers depth; layout and unit are honored."""
|
||||
@@ -69,7 +75,7 @@ class TestQuantizeDequantize:
|
||||
assert recovered.shape == expected_shape
|
||||
|
||||
recovered_m = recovered.astype(np.float32)
|
||||
if output_unit == "mm":
|
||||
if output_unit == DEPTH_MILLIMETER_UNIT:
|
||||
recovered_m = recovered_m / 1000.0
|
||||
recovered_2d = recovered_m[..., 0] if output_channel_last else recovered_m[0]
|
||||
|
||||
@@ -86,7 +92,7 @@ class TestQuantizeDequantize:
|
||||
np.testing.assert_allclose(recovered_2d, depth, atol=tol)
|
||||
|
||||
@pytest.mark.parametrize("use_log", [False, True])
|
||||
@pytest.mark.parametrize("output_unit", ["m", "mm"])
|
||||
@pytest.mark.parametrize("output_unit", [DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT])
|
||||
def test_numpy_torch_agree(self, use_log, output_unit):
|
||||
"""Batched torch path produces the same values as the numpy path."""
|
||||
batch_size = 3
|
||||
@@ -101,7 +107,7 @@ class TestQuantizeDequantize:
|
||||
assert out.shape == (batch_size, 1, H, W)
|
||||
# ``m``: float32 noise (~10 µm in log mode, after ``exp``) — still 200× below the ~2 mm quant step.
|
||||
# ``mm`` + tensor stays in float32 (no uint16 round-trip), so allow 1 mm slop.
|
||||
atol = 1e-5 if output_unit == "m" else 1.0
|
||||
atol = 1e-5 if output_unit == DEPTH_METER_UNIT else 1.0
|
||||
np.testing.assert_allclose(out.cpu().numpy().astype(np.float64), ref.astype(np.float64), atol=atol)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -117,7 +123,7 @@ class TestQuantizeDequantize:
|
||||
def test_input_layouts_accepted(self, input_shape, output_shape):
|
||||
"""All documented input layouts decode to the channel-first default."""
|
||||
quantized = np.full(input_shape, DEPTH_QMAX // 2, dtype=np.uint16)
|
||||
out = dequantize_depth(quantized, output_unit="m", output_tensor=False)
|
||||
out = dequantize_depth(quantized, output_unit=DEPTH_METER_UNIT, output_tensor=False)
|
||||
assert out.shape == output_shape
|
||||
|
||||
def test_pyav_frame_roundtrip(self):
|
||||
@@ -126,7 +132,7 @@ class TestQuantizeDequantize:
|
||||
frame = quantize_depth(depth, use_log=False, video_backend="pyav")
|
||||
assert isinstance(frame, av.VideoFrame)
|
||||
|
||||
recovered = dequantize_depth(frame, use_log=False, output_unit="m", output_tensor=False)
|
||||
recovered = dequantize_depth(frame, use_log=False, output_unit=DEPTH_METER_UNIT, output_tensor=False)
|
||||
assert recovered.shape == (1, H, W)
|
||||
tol = (DEFAULT_DEPTH_MAX - DEFAULT_DEPTH_MIN) / DEPTH_QMAX + 1e-3
|
||||
np.testing.assert_allclose(recovered[0], depth, atol=tol)
|
||||
|
||||
Reference in New Issue
Block a user