feat(depth shape): ensuring depth maps shape is always including the channel

This commit is contained in:
CarolinePascal
2026-05-21 14:25:42 +02:00
parent 3cf5e3c8cb
commit e87933302d
5 changed files with 30 additions and 26 deletions
@@ -332,8 +332,8 @@ class RealSenseCamera(Camera):
from the camera hardware via the RealSense pipeline.
Returns:
np.ndarray: The depth map as a NumPy array (height, width)
of type `np.uint16` (raw depth values in millimeters) and rotation.
np.ndarray: The depth map as a NumPy array (height, width, 1)
of type `np.uint16` (raw depth values in millimeters).
Raises:
DeviceNotConnectedError: If the camera is not connected.
@@ -486,6 +486,8 @@ class RealSenseCamera(Camera):
depth_frame_raw = frame.get_depth_frame()
depth_frame = np.asanyarray(depth_frame_raw.get_data())
processed_depth_frame = self._postprocess_image(depth_frame, depth_frame=True)
if processed_depth_frame.ndim == 2: # (H, W) -> (H, W, 1)
processed_depth_frame = processed_depth_frame[..., np.newaxis]
capture_time = time.perf_counter()
@@ -614,7 +616,7 @@ class RealSenseCamera(Camera):
"""Read the latest depth frame asynchronously, in metric meters.
Mirrors :meth:`async_read` but returns the depth stream rather than the
color stream. Output is ``np.uint16`` of shape ``(H, W)``.
color stream. Output is ``np.uint16`` of shape ``(H, W, 1)``.
Raises:
DeviceNotConnectedError: If the camera is not connected.
@@ -645,7 +647,7 @@ class RealSenseCamera(Camera):
"""Return the most recent depth frame in metric meters (peeking).
Non-blocking counterpart of :meth:`read_latest` for the depth stream.
Output is ``np.float32`` of shape ``(H, W)`` in meters.
Output is ``np.uint16`` of shape ``(H, W, 1)`` in millimeters.
Raises:
DeviceNotConnectedError: If the camera is not connected.
+8
View File
@@ -113,6 +113,10 @@ def quantize_depth(
if isinstance(depth, torch.Tensor):
depth = depth.detach().cpu().numpy()
# Squeeze single-channel dim: (H, W, 1) or (1, H, W) → (H, W)
if depth.ndim == 3 and (depth.shape[-1] == 1 or depth.shape[0] == 1):
depth = depth.squeeze()
depth_f, resolved_unit = _depth_input_to_float32_and_unit(depth, input_unit=input_unit)
# Convert depth_min, depth_max, and shift to the resolved input unit.
@@ -192,6 +196,10 @@ def dequantize_depth(
else:
depth_m = norm * (depth_max_m - depth_min_m) + depth_min_m
depth_m = np.clip(depth_m, depth_min_m, depth_max_m).astype(np.float32, copy=False)
# Add single-channel dim: (H, W) → (H, W, 1)
if depth_m.ndim == 2:
depth_m = depth_m[..., np.newaxis]
# Return depth as float32 meters.
if output_unit == "m":
+5 -13
View File
@@ -321,7 +321,7 @@ def validate_feature_image_or_video(
Args:
name (str): The name of the feature.
expected_shape (list[str]): The expected shape (C, H, W).
expected_shape (list[str]): The expected shape, e.g. (C, H, W) or (H, W, C).
value: The image data to validate.
Returns:
@@ -330,20 +330,12 @@ def validate_feature_image_or_video(
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
error_message = ""
if isinstance(value, np.ndarray):
actual_shape = tuple(value.shape)
expected = tuple(expected_shape)
if len(expected) == 2:
# Single-channel features (e.g. depth maps) — accept (H,W), (1,H,W), (H,W,1)
h, w = expected
valid = actual_shape in {(h, w), (1, h, w), (h, w, 1)}
if not valid:
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(h, w)}', '{(1, h, w)}', or '{(h, w, 1)}'.\n"
elif len(expected) == 3:
c, h, w = expected
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
actual_shape = value.shape
c, h, w = expected_shape
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
else:
error_message += f"The feature '{name}' has an unsupported expected_shape '{expected}'.\n"
error_message += f"The feature '{name}' has an unsupported expected_shape '{expected_shape}'.\n"
elif isinstance(value, PILImage.Image):
pass
else:
+2 -2
View File
@@ -573,7 +573,7 @@ class _CameraEncoderThread(threading.Thread):
# Ensure HWC (RGB or depth) uint8 (RGB only) numpy array
if isinstance(frame_data, np.ndarray):
if frame_data.ndim == 3 and frame_data.shape[0] == 3:
if frame_data.ndim == 3 and frame_data.shape[0] in (1, 3):
# CHW -> HWC
frame_data = frame_data.transpose(1, 2, 0)
if not self.is_depth and frame_data.dtype != np.uint8:
@@ -699,7 +699,7 @@ class StreamingVideoEncoder:
Args:
video_keys: List of video feature keys (e.g. ["observation.images.laptop"])
depth_video_keys: List of video feature keys that carry depth maps (e.g. ["observation.depth.laptop"])
depth_video_keys: List of video feature keys that carry depth maps (e.g. ["observation.images.laptop_depth"])
temp_dir: Base directory for temporary MP4 files
"""
if self._episode_active:
+9 -7
View File
@@ -55,10 +55,10 @@ IMAGE_FEATURES = {
DEPTH_FEATURES = {
**SIMPLE_FEATURES,
"observation.depth.laptop": {
"observation.images.laptop_depth": {
"dtype": "video",
"shape": (64, 96),
"names": ["height", "width"],
"shape": (64, 96, 1),
"names": ["height", "width", "channels"],
"info": {"video.is_depth_map": True},
},
}
@@ -69,11 +69,13 @@ def _make_dummy_stats(features: dict) -> dict:
stats = {}
for key, ft in features.items():
if ft["dtype"] in ("image", "video"):
channels = ft["shape"][-1]
stat_shape = (channels, 1, 1)
stats[key] = {
"max": np.ones((3, 1, 1), dtype=np.float32),
"mean": np.full((3, 1, 1), 0.5, dtype=np.float32),
"min": np.zeros((3, 1, 1), dtype=np.float32),
"std": np.full((3, 1, 1), 0.25, dtype=np.float32),
"max": np.ones(stat_shape, dtype=np.float32),
"mean": np.full(stat_shape, 0.5, dtype=np.float32),
"min": np.zeros(stat_shape, dtype=np.float32),
"std": np.full(stat_shape, 0.25, dtype=np.float32),
"count": np.array([5]),
}
elif ft["dtype"] in ("float32", "float64", "int64"):