diff --git a/src/lerobot/datasets/depth_utils.py b/src/lerobot/datasets/depth_utils.py index e3ab32982..f7a97b43d 100644 --- a/src/lerobot/datasets/depth_utils.py +++ b/src/lerobot/datasets/depth_utils.py @@ -36,6 +36,7 @@ from lerobot.configs.video import ( DEPTH_QMAX, ) +from .image_writer import squeeze_single_channel from .pyav_utils import write_u16_plane _MM_PER_METRE = 1000.0 @@ -119,8 +120,7 @@ def quantize_depth( depth = depth.detach().cpu().numpy() # Squeeze single-channel dim: (H, W, 1) or (1, H, W) → (H, W) - if depth.ndim == 3 and (depth.shape[-1] == 1 or depth.shape[0] == 1): - depth = depth.squeeze() + depth = squeeze_single_channel(depth) depth_f, resolved_unit = _depth_input_to_float32_and_unit(depth, input_unit=input_unit) diff --git a/src/lerobot/datasets/image_writer.py b/src/lerobot/datasets/image_writer.py index f88111493..41790b46a 100644 --- a/src/lerobot/datasets/image_writer.py +++ b/src/lerobot/datasets/image_writer.py @@ -41,6 +41,19 @@ def safe_stop_image_writer(func): return wrapper +def squeeze_single_channel(array: np.ndarray) -> np.ndarray: + """Drop a leading or trailing singleton channel dim: ``(1, H, W)`` / ``(H, W, 1)`` -> ``(H, W)``. + + Unlike ``array.squeeze()``, this only removes the channel axis, never an ``H`` or ``W`` of size 1. + """ + if array.ndim == 3: + if array.shape[0] == 1: + return array[0] + if array.shape[-1] == 1: + return array[..., 0] + return array + + def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image: """Convert a NumPy array to a PIL Image, preserving precision for grayscale. @@ -62,11 +75,7 @@ def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) # Squeeze 3D single-channel inputs to 2D so depth maps work whether the # caller emits (H, W), (1, H, W), or (H, W, 1). - if image_array.ndim == 3: - if image_array.shape[0] == 1: - image_array = image_array[0] - elif image_array.shape[-1] == 1: - image_array = image_array[..., 0] + image_array = squeeze_single_channel(image_array) if image_array.ndim == 2: if image_array.dtype not in [np.uint16, np.float32]: @@ -112,7 +121,8 @@ def save_kwargs_for_path(fpath: Path, compress_level: int) -> dict: return {"compress_level": compress_level} if suffix in (".tif", ".tiff"): return {"compression": "raw"} - return {} + else: + raise ValueError(f"Unsupported image file extension: {suffix}") def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1):