"""Tests for the depth-integration feature.

Covers:
- ``depth_utils`` quantize/dequantize round-trips and backend agreement.
- Image-writer support for single-channel depth.
- Hardware-feature → depth flag routing.
- Feature-to-file-format routing through the dataset writer.

Depth metadata detection on ``LeRobotDatasetMetadata.depth_keys`` lives in
``test_dataset_metadata.py``. Depth video encoding/decoding lives in
``test_video_encoding.py``.
"""

from pathlib import Path

import av
import numpy as np
import PIL.Image
import pytest
import torch

from lerobot.configs import DepthEncoderConfig
from lerobot.configs.video import DEFAULT_DEPTH_MAX, DEFAULT_DEPTH_MIN, DEPTH_QMAX
from lerobot.datasets.depth_utils import dequantize_depth, quantize_depth
from lerobot.datasets.image_writer import image_array_to_pil_image, write_image
from tests.fixtures.constants import (
    DEFAULT_FPS,
    DUMMY_CAMERA_FEATURES,
    DUMMY_CAMERA_FEATURES_WITH_DEPTH,
    DUMMY_CHW,
    DUMMY_DEPTH_CAMERA_FEATURES,
    DUMMY_REPO_ID,
)
from tests.fixtures.dataset_factories import add_frames

# Only the spatial dimensions of the dummy camera shape are used in this module.
_, H, W = DUMMY_CHW


def _depth_metres_ramp() -> np.ndarray:
    """Return an ``(H, W)`` float32 depth map (in metres) spanning the default range.

    The values are evenly spaced from ``DEFAULT_DEPTH_MIN`` to ``DEFAULT_DEPTH_MAX``
    (both endpoints included) and laid out row by row.
    """
    flat_ramp = np.linspace(DEFAULT_DEPTH_MIN, DEFAULT_DEPTH_MAX, H * W, dtype=np.float32)
    return flat_ramp.reshape(H, W)


# ── 1.
# Quantize / dequantize round-trips ──────────────────────────────


class TestQuantizeDequantize:
    """Pins down the numerical behaviour of ``quantize_depth`` / ``dequantize_depth``."""

    @staticmethod
    def _linear_atol():
        """Linear-mode tolerance: one quantization step plus 1 mm of unit-conversion rounding."""
        return (DEFAULT_DEPTH_MAX - DEFAULT_DEPTH_MIN) / DEPTH_QMAX + 1e-3

    @pytest.mark.parametrize("use_log", [False, True])
    @pytest.mark.parametrize("output_unit", ["m", "mm"])
    @pytest.mark.parametrize("output_channel_last", [False, True])
    def test_roundtrip(self, use_log, output_unit, output_channel_last):
        """A quantize → dequantize cycle returns the input depth in the requested layout and unit."""
        truth = _depth_metres_ramp()
        codes = quantize_depth(truth, use_log=use_log, video_backend=None)
        decoded = dequantize_depth(
            codes,
            use_log=use_log,
            output_unit=output_unit,
            output_tensor=False,
            output_channel_last=output_channel_last,
        )

        # The singleton channel axis sits last or first depending on the flag.
        if output_channel_last:
            assert decoded.shape == (H, W, 1)
        else:
            assert decoded.shape == (1, H, W)

        # Bring everything back to metres and drop the channel axis before comparing.
        decoded_m = decoded.astype(np.float32)
        if output_unit == "mm":
            decoded_m = decoded_m / 1000.0
        plane = decoded_m[..., 0] if output_channel_last else decoded_m[0]

        if not use_log:
            # Linear mode: error is bounded by the quant step + 1 mm of unit-conversion rounding.
            np.testing.assert_allclose(plane, truth, atol=self._linear_atol())
            return

        # Log mode: the near range must be reconstructed more tightly than the far
        # range, which is the reason for using a log encoding at all.
        abs_err = np.abs(plane - truth)
        near_mask = truth < 1.0
        far_mask = truth > 8.0
        assert abs_err[near_mask].mean() < abs_err[far_mask].mean()

    @pytest.mark.parametrize("use_log", [False, True])
    @pytest.mark.parametrize("output_unit", ["m", "mm"])
    def test_numpy_torch_agree(self, use_log, output_unit):
        """The batched torch path yields the same values as the numpy path."""
        num_frames = 3
        single_frame = np.linspace(0, DEPTH_QMAX, H * W, dtype=np.uint16).reshape(H, W)
        numpy_batch = np.broadcast_to(
            single_frame[np.newaxis, np.newaxis, ...], (num_frames, 1, H, W)
        ).copy()
        # Widen to int32 first: torch.uint16 support is patchy.
        torch_batch = torch.from_numpy(numpy_batch.astype(np.int32))

        expected = dequantize_depth(
            numpy_batch, use_log=use_log, output_unit=output_unit, output_tensor=False
        )
        actual = dequantize_depth(
            torch_batch, use_log=use_log, output_unit=output_unit, output_tensor=True
        )

        assert isinstance(actual, torch.Tensor)
        assert actual.shape == (num_frames, 1, H, W)

        # ``m``: float32 noise (~10 µm in log mode, after ``exp``) — still 200× below
        # the ~2 mm quant step.
        # ``mm`` + tensor stays in float32 (no uint16 round-trip), so allow 1 mm slop.
        atol_by_unit = {"m": 1e-5, "mm": 1.0}
        np.testing.assert_allclose(
            actual.cpu().numpy().astype(np.float64),
            expected.astype(np.float64),
            atol=atol_by_unit[output_unit],
        )

    @pytest.mark.parametrize(
        "input_shape,output_shape",
        [
            ((H, W), (1, H, W)),
            ((1, H, W), (1, H, W)),
            ((H, W, 1), (1, H, W)),
            ((3, 1, H, W), (3, 1, H, W)),
            ((3, H, W, 1), (3, 1, H, W)),
        ],
    )
    def test_input_layouts_accepted(self, input_shape, output_shape):
        """Every documented input layout decodes to the channel-first default layout."""
        mid_code = DEPTH_QMAX // 2
        codes = np.full(input_shape, mid_code, dtype=np.uint16)
        decoded = dequantize_depth(codes, output_unit="m", output_tensor=False)
        assert decoded.shape == output_shape

    def test_pyav_frame_roundtrip(self):
        """Depth survives a quantize → ``av.VideoFrame`` → dequantize cycle."""
        truth = _depth_metres_ramp()
        video_frame = quantize_depth(truth, use_log=False, video_backend="pyav")
        assert isinstance(video_frame, av.VideoFrame)

        decoded = dequantize_depth(video_frame, use_log=False, output_unit="m", output_tensor=False)
        assert decoded.shape == (1, H, W)
        np.testing.assert_allclose(decoded[0], truth, atol=self._linear_atol())

    def test_invalid_log_params_raises(self):
        """Log mode rejects parameters where ``depth_min + shift`` is not positive."""
        expected_message = r"depth_min \+ shift must be positive"
        with pytest.raises(ValueError, match=expected_message):
            quantize_depth(
                _depth_metres_ramp(), depth_min=1.0, shift=-2.0, use_log=True, video_backend=None
            )


# ── 2.
# Image writer depth support ─────────────────────────────────────


class TestImageWriterDepth:
    """Depth-map handling in ``image_array_to_pil_image`` and ``write_image``."""

    @pytest.mark.parametrize("dtype,expected_mode", [(np.uint16, "I;16"), (np.float32, "F")])
    @pytest.mark.parametrize("shape", [(H, W), (H, W, 1), (1, H, W)])
    def test_pil_depth_modes_and_squeeze(self, dtype, expected_mode, shape):
        """Single-channel depth becomes a PIL image with the expected mode and ``(W, H)`` size."""
        blank_depth = np.zeros(shape, dtype=dtype)
        pil_image = image_array_to_pil_image(blank_depth)

        assert pil_image.mode == expected_mode
        # PIL reports size as (width, height), regardless of where the channel axis was.
        assert pil_image.size == (W, H)

    def test_write_image_tiff_roundtrip(self, tmp_path):
        """A uint16 depth map written to ``.tiff`` reads back unchanged."""
        original = np.arange(H * W, dtype=np.uint16).reshape(H, W)
        target = tmp_path / "depth.tiff"

        write_image(original, target)

        with PIL.Image.open(target) as handle:
            reloaded = np.array(handle)
        np.testing.assert_array_equal(reloaded, original)


# ── 3. Hardware-feature → depth flag ──────────────────────────────────


class TestHwToDatasetFeaturesDepth:
    """``hw_to_dataset_features`` marks single-channel cameras as depth maps."""

    @pytest.mark.parametrize("channels,is_depth", [(1, True), (3, False)])
    def test_depth_marker_by_channels(self, channels, is_depth):
        """The ``is_depth_map`` flag matches the expected value for 1- and 3-channel cameras."""
        from lerobot.utils.feature_utils import hw_to_dataset_features

        hw_features = {"cam": (480, 640, channels)}
        dataset_features = hw_to_dataset_features(hw_features, prefix="observation")

        camera_info = dataset_features["observation.images.cam"]["info"]
        assert camera_info["is_depth_map"] is is_depth

    def test_invalid_channel_count_raises(self):
        """A two-channel camera shape is rejected with a ``ValueError``."""
        from lerobot.utils.feature_utils import hw_to_dataset_features

        two_channel_camera = {"cam": (480, 640, 2)}
        with pytest.raises(ValueError, match="Expected a 3-tuple"):
            hw_to_dataset_features(two_channel_camera, prefix="observation")


# ── 4. Feature-to-file-format routing ────────────────────────────────
# Keys derived from DUMMY_CAMERA_FEATURES_WITH_DEPTH; pick one RGB and the depth camera.
RGB_KEY = next(iter(DUMMY_CAMERA_FEATURES))
DEPTH_KEY = next(iter(DUMMY_DEPTH_CAMERA_FEATURES))


class TestFeatureFileRouting:
    """Depth and RGB features must each be routed to the right file format."""

    NUM_FRAMES = 5

    def _create_dataset(self, tmp_path, features_factory, *, use_videos, **create_kwargs):
        """Create a dataset with RGB + depth cameras under ``tmp_path / "ds"``.

        ``use_videos`` is forwarded to both the feature factory and
        ``LeRobotDataset.create``; any extra keyword arguments go to ``create`` only.
        """
        from lerobot.datasets.lerobot_dataset import LeRobotDataset

        features = features_factory(
            camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH, use_videos=use_videos
        )
        return LeRobotDataset.create(
            repo_id=DUMMY_REPO_ID,
            fps=DEFAULT_FPS,
            features=features,
            root=tmp_path / "ds",
            use_videos=use_videos,
            **create_kwargs,
        )

    def test_image_mode_depth_tiff_rgb_png(self, tmp_path, features_factory):
        """Without video encoding: depth → .tiff, RGB → .png."""
        dataset = self._create_dataset(tmp_path, features_factory, use_videos=False)
        add_frames(dataset, num_frames=self.NUM_FRAMES)

        # Inspect the buffered frame paths before the episode is saved.
        episode_buffer = dataset.writer.episode_buffer
        depth_suffixes = [Path(frame_path).suffix for frame_path in episode_buffer[DEPTH_KEY]]
        rgb_suffixes = [Path(frame_path).suffix for frame_path in episode_buffer[RGB_KEY]]
        assert all(suffix == ".tiff" for suffix in depth_suffixes)
        assert all(suffix == ".png" for suffix in rgb_suffixes)

        dataset.save_episode()
        dataset.finalize()

    def test_video_mode_depth_uses_depth_encoder(self, tmp_path, features_factory):
        """With streaming video encoding: depth → DepthEncoderConfig, RGB does not."""
        dataset = self._create_dataset(
            tmp_path, features_factory, use_videos=True, streaming_encoding=True
        )
        add_frames(dataset, num_frames=self.NUM_FRAMES)

        streaming_encoder = dataset.writer._streaming_encoder
        assert streaming_encoder is not None

        # Only the depth camera's encoder thread should carry a DepthEncoderConfig.
        depth_config = streaming_encoder._threads[DEPTH_KEY].video_encoder
        rgb_config = streaming_encoder._threads[RGB_KEY].video_encoder
        assert isinstance(depth_config, DepthEncoderConfig)
        assert not isinstance(rgb_config, DepthEncoderConfig)

        dataset.save_episode()
        dataset.finalize()