From e87933302d40a62c4f1871f6814593c97be68768 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Thu, 21 May 2026 14:25:42 +0200 Subject: [PATCH] feat(depth shape): ensuring depth maps shape is always including the channel --- .../cameras/realsense/camera_realsense.py | 10 ++++++---- src/lerobot/datasets/depth_utils.py | 8 ++++++++ src/lerobot/datasets/feature_utils.py | 18 +++++------------- src/lerobot/datasets/video_utils.py | 4 ++-- tests/datasets/test_dataset_metadata.py | 16 +++++++++------- 5 files changed, 30 insertions(+), 26 deletions(-) diff --git a/src/lerobot/cameras/realsense/camera_realsense.py b/src/lerobot/cameras/realsense/camera_realsense.py index 19af51bc9..4f0c7a592 100644 --- a/src/lerobot/cameras/realsense/camera_realsense.py +++ b/src/lerobot/cameras/realsense/camera_realsense.py @@ -332,8 +332,8 @@ class RealSenseCamera(Camera): from the camera hardware via the RealSense pipeline. Returns: - np.ndarray: The depth map as a NumPy array (height, width) - of type `np.uint16` (raw depth values in millimeters) and rotation. + np.ndarray: The depth map as a NumPy array (height, width, 1) + of type `np.uint16` (raw depth values in millimeters). Raises: DeviceNotConnectedError: If the camera is not connected. @@ -486,6 +486,8 @@ class RealSenseCamera(Camera): depth_frame_raw = frame.get_depth_frame() depth_frame = np.asanyarray(depth_frame_raw.get_data()) processed_depth_frame = self._postprocess_image(depth_frame, depth_frame=True) + if processed_depth_frame.ndim == 2: # (H, W) -> (H, W, 1) + processed_depth_frame = processed_depth_frame[..., np.newaxis] capture_time = time.perf_counter() @@ -614,7 +616,7 @@ class RealSenseCamera(Camera): """Read the latest depth frame asynchronously, in metric meters. Mirrors :meth:`async_read` but returns the depth stream rather than the - color stream. Output is ``np.uint16`` of shape ``(H, W)``. + color stream. Output is ``np.uint16`` of shape ``(H, W, 1)``. 
Raises: DeviceNotConnectedError: If the camera is not connected. @@ -645,7 +647,7 @@ class RealSenseCamera(Camera): - """Return the most recent depth frame in metric meters (peeking). + """Return the most recent depth frame in millimeters (peeking). Non-blocking counterpart of :meth:`read_latest` for the depth stream. - Output is ``np.float32`` of shape ``(H, W)`` in meters. + Output is ``np.uint16`` of shape ``(H, W, 1)`` in millimeters. Raises: DeviceNotConnectedError: If the camera is not connected. diff --git a/src/lerobot/datasets/depth_utils.py b/src/lerobot/datasets/depth_utils.py index cbded58de..e7db76398 100644 --- a/src/lerobot/datasets/depth_utils.py +++ b/src/lerobot/datasets/depth_utils.py @@ -113,6 +113,10 @@ def quantize_depth( if isinstance(depth, torch.Tensor): depth = depth.detach().cpu().numpy() + # Squeeze single-channel dim: (H, W, 1) or (1, H, W) → (H, W) + if depth.ndim == 3 and (depth.shape[-1] == 1 or depth.shape[0] == 1): + depth = depth.squeeze(axis=-1 if depth.shape[-1] == 1 else 0) + depth_f, resolved_unit = _depth_input_to_float32_and_unit(depth, input_unit=input_unit) # Convert depth_min, depth_max, and shift to the resolved input unit. @@ -192,6 +196,10 @@ def dequantize_depth( else: depth_m = norm * (depth_max_m - depth_min_m) + depth_min_m depth_m = np.clip(depth_m, depth_min_m, depth_max_m).astype(np.float32, copy=False) + + # Add single-channel dim: (H, W) → (H, W, 1) + if depth_m.ndim == 2: + depth_m = depth_m[..., np.newaxis] # Return depth as float32 meters. if output_unit == "m": diff --git a/src/lerobot/datasets/feature_utils.py b/src/lerobot/datasets/feature_utils.py index 06467ac3a..9df670913 100644 --- a/src/lerobot/datasets/feature_utils.py +++ b/src/lerobot/datasets/feature_utils.py @@ -321,7 +321,7 @@ def validate_feature_image_or_video( Args: name (str): The name of the feature. - expected_shape (list[str]): The expected shape (C, H, W). + expected_shape (list[str]): The expected shape, e.g. (C, H, W) or (H, W, C). value: The image data to validate. 
Returns: @@ -330,20 +330,10 @@ def validate_feature_image_or_video( # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads. error_message = "" if isinstance(value, np.ndarray): - actual_shape = tuple(value.shape) - expected = tuple(expected_shape) - if len(expected) == 2: - # Single-channel features (e.g. depth maps) — accept (H,W), (1,H,W), (H,W,1) - h, w = expected - valid = actual_shape in {(h, w), (1, h, w), (h, w, 1)} - if not valid: - error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(h, w)}', '{(1, h, w)}', or '{(h, w, 1)}'.\n" - elif len(expected) == 3: - c, h, w = expected - if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)): + actual_shape = value.shape + c, h, w = expected_shape + if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)): error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n" - else: - error_message += f"The feature '{name}' has an unsupported expected_shape '{expected}'.\n" elif isinstance(value, PILImage.Image): pass else: diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index 6330ae447..1a1233cbc 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -573,7 +573,7 @@ class _CameraEncoderThread(threading.Thread): # Ensure HWC (RGB or depth) uint8 (RGB only) numpy array if isinstance(frame_data, np.ndarray): - if frame_data.ndim == 3 and frame_data.shape[0] == 3: + if frame_data.ndim == 3 and frame_data.shape[0] in (1, 3): # CHW -> HWC frame_data = frame_data.transpose(1, 2, 0) if not self.is_depth and frame_data.dtype != np.uint8: @@ -699,7 +699,7 @@ class StreamingVideoEncoder: Args: video_keys: List of video feature 
keys (e.g. ["observation.images.laptop"]) - depth_video_keys: List of video feature keys that carry depth maps (e.g. ["observation.depth.laptop"]) + depth_video_keys: List of video feature keys that carry depth maps (e.g. ["observation.images.laptop_depth"]) temp_dir: Base directory for temporary MP4 files """ if self._episode_active: diff --git a/tests/datasets/test_dataset_metadata.py b/tests/datasets/test_dataset_metadata.py index 9f417af5e..6d8b6f06f 100644 --- a/tests/datasets/test_dataset_metadata.py +++ b/tests/datasets/test_dataset_metadata.py @@ -55,10 +55,10 @@ IMAGE_FEATURES = { DEPTH_FEATURES = { **SIMPLE_FEATURES, - "observation.depth.laptop": { + "observation.images.laptop_depth": { "dtype": "video", - "shape": (64, 96), - "names": ["height", "width"], + "shape": (64, 96, 1), + "names": ["height", "width", "channels"], "info": {"video.is_depth_map": True}, }, } @@ -69,11 +69,13 @@ def _make_dummy_stats(features: dict) -> dict: stats = {} for key, ft in features.items(): if ft["dtype"] in ("image", "video"): + channels = ft["shape"][-1] + stat_shape = (channels, 1, 1) stats[key] = { - "max": np.ones((3, 1, 1), dtype=np.float32), - "mean": np.full((3, 1, 1), 0.5, dtype=np.float32), - "min": np.zeros((3, 1, 1), dtype=np.float32), - "std": np.full((3, 1, 1), 0.25, dtype=np.float32), + "max": np.ones(stat_shape, dtype=np.float32), + "mean": np.full(stat_shape, 0.5, dtype=np.float32), + "min": np.zeros(stat_shape, dtype=np.float32), + "std": np.full(stat_shape, 0.25, dtype=np.float32), "count": np.array([5]), } elif ft["dtype"] in ("float32", "float64", "int64"):