fix(signle channel squeeze): fixing single channel squeezing

This commit is contained in:
CarolinePascal
2026-06-24 17:54:29 +02:00
parent 31aa0766a2
commit 5868ca6492
2 changed files with 18 additions and 8 deletions
+2 -2
View File
@@ -36,6 +36,7 @@ from lerobot.configs.video import (
DEPTH_QMAX,
)
from .image_writer import squeeze_single_channel
from .pyav_utils import write_u16_plane
_MM_PER_METRE = 1000.0
@@ -119,8 +120,7 @@ def quantize_depth(
depth = depth.detach().cpu().numpy()
# Squeeze single-channel dim: (H, W, 1) or (1, H, W) → (H, W)
if depth.ndim == 3 and (depth.shape[-1] == 1 or depth.shape[0] == 1):
depth = depth.squeeze()
depth = squeeze_single_channel(depth)
depth_f, resolved_unit = _depth_input_to_float32_and_unit(depth, input_unit=input_unit)
+16 -6
View File
@@ -41,6 +41,19 @@ def safe_stop_image_writer(func):
return wrapper
def squeeze_single_channel(array: np.ndarray) -> np.ndarray:
"""Drop a leading or trailing singleton channel dim: ``(1, H, W)`` / ``(H, W, 1)`` -> ``(H, W)``.
Unlike ``array.squeeze()``, this only removes the channel axis, never an ``H`` or ``W`` of size 1.
"""
if array.ndim == 3:
if array.shape[0] == 1:
return array[0]
if array.shape[-1] == 1:
return array[..., 0]
return array
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
"""Convert a NumPy array to a PIL Image, preserving precision for grayscale.
@@ -62,11 +75,7 @@ def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True)
# Squeeze 3D single-channel inputs to 2D so depth maps work whether the
# caller emits (H, W), (1, H, W), or (H, W, 1).
if image_array.ndim == 3:
if image_array.shape[0] == 1:
image_array = image_array[0]
elif image_array.shape[-1] == 1:
image_array = image_array[..., 0]
image_array = squeeze_single_channel(image_array)
if image_array.ndim == 2:
if image_array.dtype not in [np.uint16, np.float32]:
@@ -112,7 +121,8 @@ def save_kwargs_for_path(fpath: Path, compress_level: int) -> dict:
return {"compress_level": compress_level}
if suffix in (".tif", ".tiff"):
return {"compression": "raw"}
return {}
else:
raise ValueError(f"Unsupported image file extension: {suffix}")
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1):