mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
feat(depth shape): ensuring depth maps shape is always including the channel
This commit is contained in:
@@ -332,8 +332,8 @@ class RealSenseCamera(Camera):
|
||||
from the camera hardware via the RealSense pipeline.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The depth map as a NumPy array (height, width)
|
||||
of type `np.uint16` (raw depth values in millimeters) and rotation.
|
||||
np.ndarray: The depth map as a NumPy array (height, width, 1)
|
||||
of type `np.uint16` (raw depth values in millimeters).
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
@@ -486,6 +486,8 @@ class RealSenseCamera(Camera):
|
||||
depth_frame_raw = frame.get_depth_frame()
|
||||
depth_frame = np.asanyarray(depth_frame_raw.get_data())
|
||||
processed_depth_frame = self._postprocess_image(depth_frame, depth_frame=True)
|
||||
if processed_depth_frame.ndim == 2: # (H, W) -> (H, W, 1)
|
||||
processed_depth_frame = processed_depth_frame[..., np.newaxis]
|
||||
|
||||
capture_time = time.perf_counter()
|
||||
|
||||
@@ -614,7 +616,7 @@ class RealSenseCamera(Camera):
|
||||
"""Read the latest depth frame asynchronously, in metric meters.
|
||||
|
||||
Mirrors :meth:`async_read` but returns the depth stream rather than the
|
||||
color stream. Output is ``np.uint16`` of shape ``(H, W)``.
|
||||
color stream. Output is ``np.uint16`` of shape ``(H, W, 1)``.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
@@ -645,7 +647,7 @@ class RealSenseCamera(Camera):
|
||||
"""Return the most recent depth frame in metric meters (peeking).
|
||||
|
||||
Non-blocking counterpart of :meth:`read_latest` for the depth stream.
|
||||
Output is ``np.float32`` of shape ``(H, W)`` in meters.
|
||||
Output is ``np.uint16`` of shape ``(H, W, 1)`` in millimeters.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
|
||||
@@ -113,6 +113,10 @@ def quantize_depth(
|
||||
if isinstance(depth, torch.Tensor):
|
||||
depth = depth.detach().cpu().numpy()
|
||||
|
||||
# Squeeze single-channel dim: (H, W, 1) or (1, H, W) → (H, W)
|
||||
if depth.ndim == 3 and (depth.shape[-1] == 1 or depth.shape[0] == 1):
|
||||
depth = depth.squeeze()
|
||||
|
||||
depth_f, resolved_unit = _depth_input_to_float32_and_unit(depth, input_unit=input_unit)
|
||||
|
||||
# Convert depth_min, depth_max, and shift to the resolved input unit.
|
||||
@@ -192,6 +196,10 @@ def dequantize_depth(
|
||||
else:
|
||||
depth_m = norm * (depth_max_m - depth_min_m) + depth_min_m
|
||||
depth_m = np.clip(depth_m, depth_min_m, depth_max_m).astype(np.float32, copy=False)
|
||||
|
||||
# Add single-channel dim: (H, W) → (H, W, 1)
|
||||
if depth_m.ndim == 2:
|
||||
depth_m = depth_m[..., np.newaxis]
|
||||
|
||||
# Return depth as float32 meters.
|
||||
if output_unit == "m":
|
||||
|
||||
@@ -321,7 +321,7 @@ def validate_feature_image_or_video(
|
||||
|
||||
Args:
|
||||
name (str): The name of the feature.
|
||||
expected_shape (list[str]): The expected shape (C, H, W).
|
||||
expected_shape (list[str]): The expected shape, e.g. (C, H, W) or (H, W, C).
|
||||
value: The image data to validate.
|
||||
|
||||
Returns:
|
||||
@@ -330,20 +330,12 @@ def validate_feature_image_or_video(
|
||||
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
actual_shape = tuple(value.shape)
|
||||
expected = tuple(expected_shape)
|
||||
if len(expected) == 2:
|
||||
# Single-channel features (e.g. depth maps) — accept (H,W), (1,H,W), (H,W,1)
|
||||
h, w = expected
|
||||
valid = actual_shape in {(h, w), (1, h, w), (h, w, 1)}
|
||||
if not valid:
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(h, w)}', '{(1, h, w)}', or '{(h, w, 1)}'.\n"
|
||||
elif len(expected) == 3:
|
||||
c, h, w = expected
|
||||
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
|
||||
actual_shape = value.shape
|
||||
c, h, w = expected_shape
|
||||
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
|
||||
else:
|
||||
error_message += f"The feature '{name}' has an unsupported expected_shape '{expected}'.\n"
|
||||
error_message += f"The feature '{name}' has an unsupported expected_shape '{expected_shape}'.\n"
|
||||
elif isinstance(value, PILImage.Image):
|
||||
pass
|
||||
else:
|
||||
|
||||
@@ -573,7 +573,7 @@ class _CameraEncoderThread(threading.Thread):
|
||||
|
||||
# Ensure HWC (RGB or depth) uint8 (RGB only) numpy array
|
||||
if isinstance(frame_data, np.ndarray):
|
||||
if frame_data.ndim == 3 and frame_data.shape[0] == 3:
|
||||
if frame_data.ndim == 3 and frame_data.shape[0] in (1, 3):
|
||||
# CHW -> HWC
|
||||
frame_data = frame_data.transpose(1, 2, 0)
|
||||
if not self.is_depth and frame_data.dtype != np.uint8:
|
||||
@@ -699,7 +699,7 @@ class StreamingVideoEncoder:
|
||||
|
||||
Args:
|
||||
video_keys: List of video feature keys (e.g. ["observation.images.laptop"])
|
||||
depth_video_keys: List of video feature keys that carry depth maps (e.g. ["observation.depth.laptop"])
|
||||
depth_video_keys: List of video feature keys that carry depth maps (e.g. ["observation.images.laptop_depth"])
|
||||
temp_dir: Base directory for temporary MP4 files
|
||||
"""
|
||||
if self._episode_active:
|
||||
|
||||
@@ -55,10 +55,10 @@ IMAGE_FEATURES = {
|
||||
|
||||
DEPTH_FEATURES = {
|
||||
**SIMPLE_FEATURES,
|
||||
"observation.depth.laptop": {
|
||||
"observation.images.laptop_depth": {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96),
|
||||
"names": ["height", "width"],
|
||||
"shape": (64, 96, 1),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": {"video.is_depth_map": True},
|
||||
},
|
||||
}
|
||||
@@ -69,11 +69,13 @@ def _make_dummy_stats(features: dict) -> dict:
|
||||
stats = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] in ("image", "video"):
|
||||
channels = ft["shape"][-1]
|
||||
stat_shape = (channels, 1, 1)
|
||||
stats[key] = {
|
||||
"max": np.ones((3, 1, 1), dtype=np.float32),
|
||||
"mean": np.full((3, 1, 1), 0.5, dtype=np.float32),
|
||||
"min": np.zeros((3, 1, 1), dtype=np.float32),
|
||||
"std": np.full((3, 1, 1), 0.25, dtype=np.float32),
|
||||
"max": np.ones(stat_shape, dtype=np.float32),
|
||||
"mean": np.full(stat_shape, 0.5, dtype=np.float32),
|
||||
"min": np.zeros(stat_shape, dtype=np.float32),
|
||||
"std": np.full(stat_shape, 0.25, dtype=np.float32),
|
||||
"count": np.array([5]),
|
||||
}
|
||||
elif ft["dtype"] in ("float32", "float64", "int64"):
|
||||
|
||||
Reference in New Issue
Block a user