feat(depth): extend quantization tools to better fit the encoding/decoding pipeline

This commit is contained in:
CarolinePascal
2026-05-19 17:10:47 +02:00
parent b960524d93
commit 0cc5162078
2 changed files with 68 additions and 37 deletions
+57 -37
View File
@@ -20,6 +20,7 @@ Depth encoding/decoding helpers for :class:`VideoEncoderConfig`.
import math import math
from typing import Literal from typing import Literal
import av
import numpy as np import numpy as np
import torch import torch
from numpy.typing import NDArray from numpy.typing import NDArray
@@ -32,9 +33,14 @@ from lerobot.configs.video import (
DEPTH_QMAX, DEPTH_QMAX,
) )
from .pyav_utils import write_u16_plane
_MM_PER_METRE = 1000.0 _MM_PER_METRE = 1000.0
_UINT16_MAX = 65535 _UINT16_MAX = 65535
# Pixel format supported by the depth encode/decode helpers.
DEPTH_PIX_FMT: str = "gray12le"
def _validate_log_quant_params(depth_min: float, shift: float) -> None: def _validate_log_quant_params(depth_min: float, shift: float) -> None:
"""Ensure ``log(depth_min + shift)`` is finite.""" """Ensure ``log(depth_min + shift)`` is finite."""
@@ -46,33 +52,25 @@ def _validate_log_quant_params(depth_min: float, shift: float) -> None:
def _depth_input_to_float32_and_unit( def _depth_input_to_float32_and_unit(
depth: NDArray[np.uint16] | NDArray[np.floating] | torch.Tensor, depth: NDArray[np.integer] | NDArray[np.floating],
input_unit: Literal["auto", "m", "mm"], input_unit: Literal["auto", "m", "mm"],
) -> tuple[NDArray[np.float32], Literal["m", "mm"]]: ) -> tuple[NDArray[np.float32], Literal["m", "mm"]]:
"""Depth as float32 in the chosen unit, plus the resolved unit.""" """Convert depth to float32 in the chosen unit, and return the resolved unit."""
if isinstance(depth, torch.Tensor): resolved_unit = (
t = depth.detach().cpu() ("m" if np.issubdtype(depth.dtype, np.floating) else "mm") if input_unit == "auto" else input_unit
arr = t.numpy() )
is_floating = t.is_floating_point() return depth.astype(np.float32, order="K"), resolved_unit
else:
arr = np.asarray(depth)
is_floating = np.issubdtype(arr.dtype, np.floating)
resolved_unit = ("m" if is_floating else "mm") if input_unit == "auto" else input_unit
# Convert to float32 to keep typing consistency
return np.asarray(arr, dtype=np.float32, order="K"), resolved_unit
def quantize_depth( def quantize_depth(
depth: NDArray[np.uint16] | NDArray[np.floating] | torch.Tensor, depth: NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor,
depth_min: float = DEFAULT_DEPTH_MIN, depth_min: float = DEFAULT_DEPTH_MIN,
depth_max: float = DEFAULT_DEPTH_MAX, depth_max: float = DEFAULT_DEPTH_MAX,
shift: float = DEFAULT_DEPTH_SHIFT, shift: float = DEFAULT_DEPTH_SHIFT,
use_log: bool = DEFAULT_DEPTH_USE_LOG, use_log: bool = DEFAULT_DEPTH_USE_LOG,
*, video_backend: str | None = "pyav",
input_unit: Literal["auto", "m", "mm"] = "auto", input_unit: Literal["auto", "m", "mm"] = "auto",
) -> NDArray[np.uint16]: ) -> NDArray[np.uint16] | av.VideoFrame:
"""Quantize depth to 12-bit codes (``uint16``, values ``0…DEPTH_QMAX``). """Quantize depth to 12-bit codes (``uint16``, values ``0…DEPTH_QMAX``).
Depth maps are packed into 12-bit integer frames so they fit in standard Depth maps are packed into 12-bit integer frames so they fit in standard
@@ -98,6 +96,7 @@ def quantize_depth(
depth_max: Depth (metres) at quantum :data:`DEPTH_QMAX`. depth_max: Depth (metres) at quantum :data:`DEPTH_QMAX`.
shift: Depth shift (metres); used in log mode. Must satisfy ``depth_min + shift > 0``. shift: Depth shift (metres); used in log mode. Must satisfy ``depth_min + shift > 0``.
use_log: If ``True`` (default), quantize in log space. use_log: If ``True`` (default), quantize in log space.
video_backend: Video backend to use for encoding. Defaults to "pyav".
input_unit: Input unit policy (``"auto"``, ``"mm"``, ``"m"``). input_unit: Input unit policy (``"auto"``, ``"mm"``, ``"m"``).
Returns: Returns:
@@ -111,11 +110,17 @@ def quantize_depth(
if input_unit not in ("auto", "m", "mm"): if input_unit not in ("auto", "m", "mm"):
raise ValueError(f"input_unit must be 'auto', 'm', or 'mm', got {input_unit!r}") raise ValueError(f"input_unit must be 'auto', 'm', or 'mm', got {input_unit!r}")
if isinstance(depth, torch.Tensor):
depth = depth.detach().cpu().numpy()
depth_f, resolved_unit = _depth_input_to_float32_and_unit(depth, input_unit=input_unit) depth_f, resolved_unit = _depth_input_to_float32_and_unit(depth, input_unit=input_unit)
# Convert depth_min, depth_max, and shift to the resolved input unit.
depth_min_u = np.float32(depth_min) if resolved_unit == "m" else np.float32(depth_min * _MM_PER_METRE) depth_min_u = np.float32(depth_min) if resolved_unit == "m" else np.float32(depth_min * _MM_PER_METRE)
depth_max_u = np.float32(depth_max) if resolved_unit == "m" else np.float32(depth_max * _MM_PER_METRE) depth_max_u = np.float32(depth_max) if resolved_unit == "m" else np.float32(depth_max * _MM_PER_METRE)
shift_u = np.float32(shift) if resolved_unit == "m" else np.float32(shift * _MM_PER_METRE) shift_u = np.float32(shift) if resolved_unit == "m" else np.float32(shift * _MM_PER_METRE)
# Normalization and quantization is performed in the resolved input unit.
if use_log: if use_log:
_validate_log_quant_params(depth_min, shift) _validate_log_quant_params(depth_min, shift)
log_min = math.log(float(depth_min_u + shift_u)) log_min = math.log(float(depth_min_u + shift_u))
@@ -124,19 +129,25 @@ def quantize_depth(
else: else:
norm = (depth_f - depth_min_u) / (depth_max_u - depth_min_u) norm = (depth_f - depth_min_u) / (depth_max_u - depth_min_u)
out = np.rint(norm * DEPTH_QMAX).clip(0, DEPTH_QMAX) quantized = np.rint(norm * DEPTH_QMAX).clip(0, DEPTH_QMAX).astype(np.uint16, copy=False)
return out.astype(np.uint16, copy=False)
if video_backend == "pyav":
frame = av.VideoFrame.from_ndarray(quantized, format=DEPTH_PIX_FMT)
write_u16_plane(frame.planes[0], quantized)
return frame
else:
return quantized
def dequantize_depth( def dequantize_depth(
quantized: NDArray[np.uint16] | torch.Tensor, quantized: NDArray[np.uint16] | av.VideoFrame,
depth_min: float = DEFAULT_DEPTH_MIN, depth_min: float = DEFAULT_DEPTH_MIN,
depth_max: float = DEFAULT_DEPTH_MAX, depth_max: float = DEFAULT_DEPTH_MAX,
shift: float = DEFAULT_DEPTH_SHIFT, shift: float = DEFAULT_DEPTH_SHIFT,
use_log: bool = DEFAULT_DEPTH_USE_LOG, use_log: bool = DEFAULT_DEPTH_USE_LOG,
*,
output_unit: Literal["m", "mm"] = "mm", output_unit: Literal["m", "mm"] = "mm",
) -> NDArray[np.uint16] | NDArray[np.float32]: output_tensor: bool = False,
) -> NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor:
"""Inverse of :func:`quantize_depth`. """Inverse of :func:`quantize_depth`.
Tuning arguments **must match** :func:`quantize_depth`. Tuning arguments **must match** :func:`quantize_depth`.
@@ -151,6 +162,7 @@ def dequantize_depth(
output_unit: ``\"mm\"`` returns ``uint16`` millimetres (``rint``, clip output_unit: ``\"mm\"`` returns ``uint16`` millimetres (``rint``, clip
``[0, 65535]``). ``\"m\"`` returns ``float32`` metres in ``[0, 65535]``). ``\"m\"`` returns ``float32`` metres in
``[depth_min, depth_max]``. ``[depth_min, depth_max]``.
output_tensor: If True, return a torch.Tensor instead of a numpy array.
Returns: Returns:
Depth map in the requested unit and dtype. Depth map in the requested unit and dtype.
@@ -162,25 +174,33 @@ def dequantize_depth(
if output_unit not in ("m", "mm"): if output_unit not in ("m", "mm"):
raise ValueError(f"output_unit must be 'm' or 'mm', got {output_unit!r}") raise ValueError(f"output_unit must be 'm' or 'mm', got {output_unit!r}")
if isinstance(quantized, torch.Tensor): if isinstance(quantized, av.VideoFrame):
quantized = quantized.detach().cpu().numpy() quantized = quantized.to_ndarray(format=DEPTH_PIX_FMT)
q = np.asarray(quantized, dtype=np.uint16, order="K")
norm = q.astype(np.float32, copy=False) / DEPTH_QMAX
depth_min_mm = np.float32(depth_min * _MM_PER_METRE) norm = np.asarray(quantized, dtype=np.float32, order="K") / DEPTH_QMAX
depth_max_mm = np.float32(depth_max * _MM_PER_METRE)
shift_mm = np.float32(shift * _MM_PER_METRE)
depth_min_m = np.float32(depth_min)
depth_max_m = np.float32(depth_max)
shift_m = np.float32(shift)
# The de-normalization and de-quantization is performed in meters (convenience choice).
if use_log: if use_log:
_validate_log_quant_params(depth_min, shift) _validate_log_quant_params(depth_min, shift)
log_min = math.log(float(depth_min_mm + shift_mm)) log_min = math.log(float(depth_min_m + shift_m))
log_max = math.log(float(depth_max_mm + shift_mm)) log_max = math.log(float(depth_max_m + shift_m))
depth_mm = np.exp(norm * (log_max - log_min) + log_min) - shift_mm depth_m = np.exp(norm * (log_max - log_min) + log_min) - shift_m
else: else:
depth_mm = norm * (depth_max_mm - depth_min_mm) + depth_min_mm depth_m = norm * (depth_max_m - depth_min_m) + depth_min_m
depth_m = np.clip(depth_m, depth_min_m, depth_max_m).astype(np.float32, copy=False)
depth_mm = np.clip(depth_mm, depth_min_mm, depth_max_mm).astype(np.float32, copy=False) # Return depth as float32 meters.
if output_unit == "m": if output_unit == "m":
return (depth_mm / np.float32(_MM_PER_METRE)).astype(np.float32, copy=False) return torch.from_numpy(depth_m) if output_tensor else depth_m
mm = np.rint(depth_mm).clip(0, _UINT16_MAX)
return mm.astype(np.uint16, copy=False) # Return depth as uint16 millimeters.
mm = np.rint(depth_m * _MM_PER_METRE).clip(0, _UINT16_MAX).astype(np.uint16, copy=False)
if output_tensor:
# torch.uint16 support is very limited, we convert to float32 instead.
return torch.from_numpy(mm.astype(np.float32))
else:
return mm
+11
View File
@@ -24,6 +24,7 @@ import logging
from typing import Any from typing import Any
import av import av
import numpy as np
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -31,6 +32,16 @@ FFMPEG_NUMERIC_OPTION_TYPES = ("INT", "INT64", "UINT64", "FLOAT", "DOUBLE")
FFMPEG_INTEGER_OPTION_TYPES = ("INT", "INT64", "UINT64") FFMPEG_INTEGER_OPTION_TYPES = ("INT", "INT64", "UINT64")
def write_u16_plane(plane: av.video.plane.VideoPlane, src: np.ndarray, fill_value: int | None = None) -> None:
"""Copy ``src`` into a uint16 plane respecting FFmpeg line padding."""
height, width = src.shape
stride_u16 = plane.line_size // np.dtype(np.uint16).itemsize
dst = np.frombuffer(plane, dtype=np.uint16).reshape(height, stride_u16)
if fill_value is not None:
dst.fill(fill_value)
dst[:, :width] = src
@functools.cache @functools.cache
def get_codec(vcodec: str) -> av.codec.Codec | None: def get_codec(vcodec: str) -> av.codec.Codec | None:
"""PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable.""" """PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""