mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-25 04:07:02 +00:00
fix(signle channel squeeze): fixing single channel squeezing
This commit is contained in:
@@ -36,6 +36,7 @@ from lerobot.configs.video import (
|
||||
DEPTH_QMAX,
|
||||
)
|
||||
|
||||
from .image_writer import squeeze_single_channel
|
||||
from .pyav_utils import write_u16_plane
|
||||
|
||||
_MM_PER_METRE = 1000.0
|
||||
@@ -119,8 +120,7 @@ def quantize_depth(
|
||||
depth = depth.detach().cpu().numpy()
|
||||
|
||||
# Squeeze single-channel dim: (H, W, 1) or (1, H, W) → (H, W)
|
||||
if depth.ndim == 3 and (depth.shape[-1] == 1 or depth.shape[0] == 1):
|
||||
depth = depth.squeeze()
|
||||
depth = squeeze_single_channel(depth)
|
||||
|
||||
depth_f, resolved_unit = _depth_input_to_float32_and_unit(depth, input_unit=input_unit)
|
||||
|
||||
|
||||
@@ -41,6 +41,19 @@ def safe_stop_image_writer(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def squeeze_single_channel(array: np.ndarray) -> np.ndarray:
|
||||
"""Drop a leading or trailing singleton channel dim: ``(1, H, W)`` / ``(H, W, 1)`` -> ``(H, W)``.
|
||||
|
||||
Unlike ``array.squeeze()``, this only removes the channel axis, never an ``H`` or ``W`` of size 1.
|
||||
"""
|
||||
if array.ndim == 3:
|
||||
if array.shape[0] == 1:
|
||||
return array[0]
|
||||
if array.shape[-1] == 1:
|
||||
return array[..., 0]
|
||||
return array
|
||||
|
||||
|
||||
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
|
||||
"""Convert a NumPy array to a PIL Image, preserving precision for grayscale.
|
||||
|
||||
@@ -62,11 +75,7 @@ def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True)
|
||||
|
||||
# Squeeze 3D single-channel inputs to 2D so depth maps work whether the
|
||||
# caller emits (H, W), (1, H, W), or (H, W, 1).
|
||||
if image_array.ndim == 3:
|
||||
if image_array.shape[0] == 1:
|
||||
image_array = image_array[0]
|
||||
elif image_array.shape[-1] == 1:
|
||||
image_array = image_array[..., 0]
|
||||
image_array = squeeze_single_channel(image_array)
|
||||
|
||||
if image_array.ndim == 2:
|
||||
if image_array.dtype not in [np.uint16, np.float32]:
|
||||
@@ -112,7 +121,8 @@ def save_kwargs_for_path(fpath: Path, compress_level: int) -> dict:
|
||||
return {"compress_level": compress_level}
|
||||
if suffix in (".tif", ".tiff"):
|
||||
return {"compression": "raw"}
|
||||
return {}
|
||||
else:
|
||||
raise ValueError(f"Unsupported image file extension: {suffix}")
|
||||
|
||||
|
||||
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1):
|
||||
|
||||
Reference in New Issue
Block a user