diff --git a/src/lerobot/datasets/depth_utils.py b/src/lerobot/datasets/depth_utils.py index 0fd9e624b..cbded58de 100644 --- a/src/lerobot/datasets/depth_utils.py +++ b/src/lerobot/datasets/depth_utils.py @@ -20,6 +20,7 @@ Depth encoding/decoding helpers for :class:`VideoEncoderConfig`. import math from typing import Literal +import av import numpy as np import torch from numpy.typing import NDArray @@ -32,9 +33,14 @@ from lerobot.configs.video import ( DEPTH_QMAX, ) +from .pyav_utils import write_u16_plane + _MM_PER_METRE = 1000.0 _UINT16_MAX = 65535 +# Pixel format supported by the depth encode/decode helpers. +DEPTH_PIX_FMT: str = "gray12le" + def _validate_log_quant_params(depth_min: float, shift: float) -> None: """Ensure ``log(depth_min + shift)`` is finite.""" @@ -46,33 +52,25 @@ def _validate_log_quant_params(depth_min: float, shift: float) -> None: def _depth_input_to_float32_and_unit( - depth: NDArray[np.uint16] | NDArray[np.floating] | torch.Tensor, + depth: NDArray[np.integer] | NDArray[np.floating], input_unit: Literal["auto", "m", "mm"], ) -> tuple[NDArray[np.float32], Literal["m", "mm"]]: - """Depth as float32 in the chosen unit, plus the resolved unit.""" - if isinstance(depth, torch.Tensor): - t = depth.detach().cpu() - arr = t.numpy() - is_floating = t.is_floating_point() - else: - arr = np.asarray(depth) - is_floating = np.issubdtype(arr.dtype, np.floating) - - resolved_unit = ("m" if is_floating else "mm") if input_unit == "auto" else input_unit - - # Convert to float32 to keep typing consistency - return np.asarray(arr, dtype=np.float32, order="K"), resolved_unit + """Convert depth to float32 in the chosen unit, and return the resolved unit.""" + resolved_unit = ( + ("m" if np.issubdtype(depth.dtype, np.floating) else "mm") if input_unit == "auto" else input_unit + ) + return depth.astype(np.float32, order="K"), resolved_unit def quantize_depth( - depth: NDArray[np.uint16] | NDArray[np.floating] | torch.Tensor, + depth: NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor, depth_min: float = DEFAULT_DEPTH_MIN, depth_max: float = DEFAULT_DEPTH_MAX, shift: float = DEFAULT_DEPTH_SHIFT, use_log: bool = DEFAULT_DEPTH_USE_LOG, - *, + video_backend: str | None = "pyav", input_unit: Literal["auto", "m", "mm"] = "auto", -) -> NDArray[np.uint16]: +) -> NDArray[np.uint16] | av.VideoFrame: """Quantize depth to 12-bit codes (``uint16``, values ``0…DEPTH_QMAX``). Depth maps are packed into 12-bit integer frames so they fit in standard @@ -98,6 +96,7 @@ def quantize_depth( depth_max: Depth (metres) at quantum :data:`DEPTH_QMAX`. shift: Depth shift (metres); used in log mode. Must satisfy ``depth_min + shift > 0``. use_log: If ``True`` (default), quantize in log space. + video_backend: Video backend to use for encoding. Defaults to "pyav". input_unit: Input unit policy (``"auto"``, ``"mm"``, ``"m"``). Returns: @@ -111,11 +110,17 @@ def quantize_depth( if input_unit not in ("auto", "m", "mm"): raise ValueError(f"input_unit must be 'auto', 'm', or 'mm', got {input_unit!r}") + if isinstance(depth, torch.Tensor): + depth = depth.detach().cpu().numpy() + depth_f, resolved_unit = _depth_input_to_float32_and_unit(depth, input_unit=input_unit) + + # Convert depth_min, depth_max, and shift to the resolved input unit. depth_min_u = np.float32(depth_min) if resolved_unit == "m" else np.float32(depth_min * _MM_PER_METRE) depth_max_u = np.float32(depth_max) if resolved_unit == "m" else np.float32(depth_max * _MM_PER_METRE) shift_u = np.float32(shift) if resolved_unit == "m" else np.float32(shift * _MM_PER_METRE) + # Normalization and quantization is performed in the resolved input unit. if use_log: _validate_log_quant_params(depth_min, shift) log_min = math.log(float(depth_min_u + shift_u)) @@ -124,19 +129,25 @@ def quantize_depth( else: norm = (depth_f - depth_min_u) / (depth_max_u - depth_min_u) - out = np.rint(norm * DEPTH_QMAX).clip(0, DEPTH_QMAX) - return out.astype(np.uint16, copy=False) + quantized = np.rint(norm * DEPTH_QMAX).clip(0, DEPTH_QMAX).astype(np.uint16, copy=False) + + if video_backend == "pyav": + frame = av.VideoFrame.from_ndarray(quantized, format=DEPTH_PIX_FMT) + write_u16_plane(frame.planes[0], quantized) + return frame + else: + return quantized def dequantize_depth( - quantized: NDArray[np.uint16] | torch.Tensor, + quantized: NDArray[np.uint16] | av.VideoFrame, depth_min: float = DEFAULT_DEPTH_MIN, depth_max: float = DEFAULT_DEPTH_MAX, shift: float = DEFAULT_DEPTH_SHIFT, use_log: bool = DEFAULT_DEPTH_USE_LOG, - *, output_unit: Literal["m", "mm"] = "mm", -) -> NDArray[np.uint16] | NDArray[np.float32]: + output_tensor: bool = False, +) -> NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor: """Inverse of :func:`quantize_depth`. Tuning arguments **must match** :func:`quantize_depth`. @@ -151,6 +162,7 @@ def dequantize_depth( output_unit: ``\"mm\"`` returns ``uint16`` millimetres (``rint``, clip ``[0, 65535]``). ``\"m\"`` returns ``float32`` metres in ``[depth_min, depth_max]``. + output_tensor: If True, return a torch.Tensor instead of a numpy array. Returns: Depth map in the requested unit and dtype. @@ -162,25 +174,33 @@ def dequantize_depth( if output_unit not in ("m", "mm"): raise ValueError(f"output_unit must be 'm' or 'mm', got {output_unit!r}") - if isinstance(quantized, torch.Tensor): - quantized = quantized.detach().cpu().numpy() - q = np.asarray(quantized, dtype=np.uint16, order="K") - norm = q.astype(np.float32, copy=False) / DEPTH_QMAX + if isinstance(quantized, av.VideoFrame): + quantized = quantized.to_ndarray(format=DEPTH_PIX_FMT) - depth_min_mm = np.float32(depth_min * _MM_PER_METRE) - depth_max_mm = np.float32(depth_max * _MM_PER_METRE) - shift_mm = np.float32(shift * _MM_PER_METRE) + norm = np.asarray(quantized, dtype=np.float32, order="K") / DEPTH_QMAX + depth_min_m = np.float32(depth_min) + depth_max_m = np.float32(depth_max) + shift_m = np.float32(shift) + + # The de-normalization and de-quantization is performed in meters (convenience choice). if use_log: _validate_log_quant_params(depth_min, shift) - log_min = math.log(float(depth_min_mm + shift_mm)) - log_max = math.log(float(depth_max_mm + shift_mm)) - depth_mm = np.exp(norm * (log_max - log_min) + log_min) - shift_mm + log_min = math.log(float(depth_min_m + shift_m)) + log_max = math.log(float(depth_max_m + shift_m)) + depth_m = np.exp(norm * (log_max - log_min) + log_min) - shift_m else: - depth_mm = norm * (depth_max_mm - depth_min_mm) + depth_min_mm + depth_m = norm * (depth_max_m - depth_min_m) + depth_min_m + depth_m = np.clip(depth_m, depth_min_m, depth_max_m).astype(np.float32, copy=False) - depth_mm = np.clip(depth_mm, depth_min_mm, depth_max_mm).astype(np.float32, copy=False) + # Return depth as float32 meters. if output_unit == "m": - return (depth_mm / np.float32(_MM_PER_METRE)).astype(np.float32, copy=False) - mm = np.rint(depth_mm).clip(0, _UINT16_MAX) - return mm.astype(np.uint16, copy=False) + return torch.from_numpy(depth_m) if output_tensor else depth_m + + # Return depth as uint16 millimeters. + mm = np.rint(depth_m * _MM_PER_METRE).clip(0, _UINT16_MAX).astype(np.uint16, copy=False) + if output_tensor: + # torch.uint16 support is very limited, we convert to float32 instead. + return torch.from_numpy(mm.astype(np.float32)) + else: + return mm diff --git a/src/lerobot/datasets/pyav_utils.py b/src/lerobot/datasets/pyav_utils.py index d291f8b40..1fbbe5f89 100644 --- a/src/lerobot/datasets/pyav_utils.py +++ b/src/lerobot/datasets/pyav_utils.py @@ -24,6 +24,7 @@ import logging from typing import Any import av +import numpy as np logger = logging.getLogger(__name__) @@ -31,6 +32,16 @@ FFMPEG_NUMERIC_OPTION_TYPES = ("INT", "INT64", "UINT64", "FLOAT", "DOUBLE") FFMPEG_INTEGER_OPTION_TYPES = ("INT", "INT64", "UINT64") +def write_u16_plane(plane: av.video.plane.VideoPlane, src: np.ndarray, fill_value: int | None = None) -> None: + """Copy ``src`` into a uint16 plane respecting FFmpeg line padding.""" + height, width = src.shape + stride_u16 = plane.line_size // np.dtype(np.uint16).itemsize + dst = np.frombuffer(plane, dtype=np.uint16).reshape(height, stride_u16) + if fill_value is not None: + dst.fill(fill_value) + dst[:, :width] = src + + @functools.cache def get_codec(vcodec: str) -> av.codec.Codec | None: """PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""