mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
feat(depth): extend quantization tools to better fit the encoding/decoding pipeline
This commit is contained in:
@@ -20,6 +20,7 @@ Depth encoding/decoding helpers for :class:`VideoEncoderConfig`.
|
||||
import math
|
||||
from typing import Literal
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import torch
|
||||
from numpy.typing import NDArray
|
||||
@@ -32,9 +33,14 @@ from lerobot.configs.video import (
|
||||
DEPTH_QMAX,
|
||||
)
|
||||
|
||||
from .pyav_utils import write_u16_plane
|
||||
|
||||
_MM_PER_METRE = 1000.0
|
||||
_UINT16_MAX = 65535
|
||||
|
||||
# Pixel format supported by the depth encode/decode helpers.
|
||||
DEPTH_PIX_FMT: str = "gray12le"
|
||||
|
||||
|
||||
def _validate_log_quant_params(depth_min: float, shift: float) -> None:
|
||||
"""Ensure ``log(depth_min + shift)`` is finite."""
|
||||
@@ -46,33 +52,25 @@ def _validate_log_quant_params(depth_min: float, shift: float) -> None:
|
||||
|
||||
|
||||
def _depth_input_to_float32_and_unit(
|
||||
depth: NDArray[np.uint16] | NDArray[np.floating] | torch.Tensor,
|
||||
depth: NDArray[np.integer] | NDArray[np.floating],
|
||||
input_unit: Literal["auto", "m", "mm"],
|
||||
) -> tuple[NDArray[np.float32], Literal["m", "mm"]]:
|
||||
"""Depth as float32 in the chosen unit, plus the resolved unit."""
|
||||
if isinstance(depth, torch.Tensor):
|
||||
t = depth.detach().cpu()
|
||||
arr = t.numpy()
|
||||
is_floating = t.is_floating_point()
|
||||
else:
|
||||
arr = np.asarray(depth)
|
||||
is_floating = np.issubdtype(arr.dtype, np.floating)
|
||||
|
||||
resolved_unit = ("m" if is_floating else "mm") if input_unit == "auto" else input_unit
|
||||
|
||||
# Convert to float32 to keep typing consistency
|
||||
return np.asarray(arr, dtype=np.float32, order="K"), resolved_unit
|
||||
"""Convert depth to float32 in the chosen unit, and return the resolved unit."""
|
||||
resolved_unit = (
|
||||
("m" if np.issubdtype(depth.dtype, np.floating) else "mm") if input_unit == "auto" else input_unit
|
||||
)
|
||||
return depth.astype(np.float32, order="K"), resolved_unit
|
||||
|
||||
|
||||
def quantize_depth(
|
||||
depth: NDArray[np.uint16] | NDArray[np.floating] | torch.Tensor,
|
||||
depth: NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor,
|
||||
depth_min: float = DEFAULT_DEPTH_MIN,
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
*,
|
||||
video_backend: str | None = "pyav",
|
||||
input_unit: Literal["auto", "m", "mm"] = "auto",
|
||||
) -> NDArray[np.uint16]:
|
||||
) -> NDArray[np.uint16] | av.VideoFrame:
|
||||
"""Quantize depth to 12-bit codes (``uint16``, values ``0…DEPTH_QMAX``).
|
||||
|
||||
Depth maps are packed into 12-bit integer frames so they fit in standard
|
||||
@@ -98,6 +96,7 @@ def quantize_depth(
|
||||
depth_max: Depth (metres) at quantum :data:`DEPTH_QMAX`.
|
||||
shift: Depth shift (metres); used in log mode. Must satisfy ``depth_min + shift > 0``.
|
||||
use_log: If ``True`` (default), quantize in log space.
|
||||
video_backend: Video backend to use for encoding. Defaults to "pyav".
|
||||
input_unit: Input unit policy (``"auto"``, ``"mm"``, ``"m"``).
|
||||
|
||||
Returns:
|
||||
@@ -111,11 +110,17 @@ def quantize_depth(
|
||||
if input_unit not in ("auto", "m", "mm"):
|
||||
raise ValueError(f"input_unit must be 'auto', 'm', or 'mm', got {input_unit!r}")
|
||||
|
||||
if isinstance(depth, torch.Tensor):
|
||||
depth = depth.detach().cpu().numpy()
|
||||
|
||||
depth_f, resolved_unit = _depth_input_to_float32_and_unit(depth, input_unit=input_unit)
|
||||
|
||||
# Convert depth_min, depth_max, and shift to the resolved input unit.
|
||||
depth_min_u = np.float32(depth_min) if resolved_unit == "m" else np.float32(depth_min * _MM_PER_METRE)
|
||||
depth_max_u = np.float32(depth_max) if resolved_unit == "m" else np.float32(depth_max * _MM_PER_METRE)
|
||||
shift_u = np.float32(shift) if resolved_unit == "m" else np.float32(shift * _MM_PER_METRE)
|
||||
|
||||
# Normalization and quantization is performed in the resolved input unit.
|
||||
if use_log:
|
||||
_validate_log_quant_params(depth_min, shift)
|
||||
log_min = math.log(float(depth_min_u + shift_u))
|
||||
@@ -124,19 +129,25 @@ def quantize_depth(
|
||||
else:
|
||||
norm = (depth_f - depth_min_u) / (depth_max_u - depth_min_u)
|
||||
|
||||
out = np.rint(norm * DEPTH_QMAX).clip(0, DEPTH_QMAX)
|
||||
return out.astype(np.uint16, copy=False)
|
||||
quantized = np.rint(norm * DEPTH_QMAX).clip(0, DEPTH_QMAX).astype(np.uint16, copy=False)
|
||||
|
||||
if video_backend == "pyav":
|
||||
frame = av.VideoFrame.from_ndarray(quantized, format=DEPTH_PIX_FMT)
|
||||
write_u16_plane(frame.planes[0], quantized)
|
||||
return frame
|
||||
else:
|
||||
return quantized
|
||||
|
||||
|
||||
def dequantize_depth(
|
||||
quantized: NDArray[np.uint16] | torch.Tensor,
|
||||
quantized: NDArray[np.uint16] | av.VideoFrame,
|
||||
depth_min: float = DEFAULT_DEPTH_MIN,
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
*,
|
||||
output_unit: Literal["m", "mm"] = "mm",
|
||||
) -> NDArray[np.uint16] | NDArray[np.float32]:
|
||||
output_tensor: bool = False,
|
||||
) -> NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor:
|
||||
"""Inverse of :func:`quantize_depth`.
|
||||
|
||||
Tuning arguments **must match** :func:`quantize_depth`.
|
||||
@@ -151,6 +162,7 @@ def dequantize_depth(
|
||||
output_unit: ``\"mm\"`` returns ``uint16`` millimetres (``rint``, clip
|
||||
``[0, 65535]``). ``\"m\"`` returns ``float32`` metres in
|
||||
``[depth_min, depth_max]``.
|
||||
output_tensor: If True, return a torch.Tensor instead of a numpy array.
|
||||
|
||||
Returns:
|
||||
Depth map in the requested unit and dtype.
|
||||
@@ -162,25 +174,33 @@ def dequantize_depth(
|
||||
if output_unit not in ("m", "mm"):
|
||||
raise ValueError(f"output_unit must be 'm' or 'mm', got {output_unit!r}")
|
||||
|
||||
if isinstance(quantized, torch.Tensor):
|
||||
quantized = quantized.detach().cpu().numpy()
|
||||
q = np.asarray(quantized, dtype=np.uint16, order="K")
|
||||
norm = q.astype(np.float32, copy=False) / DEPTH_QMAX
|
||||
if isinstance(quantized, av.VideoFrame):
|
||||
quantized = quantized.to_ndarray(format=DEPTH_PIX_FMT)
|
||||
|
||||
depth_min_mm = np.float32(depth_min * _MM_PER_METRE)
|
||||
depth_max_mm = np.float32(depth_max * _MM_PER_METRE)
|
||||
shift_mm = np.float32(shift * _MM_PER_METRE)
|
||||
norm = np.asarray(quantized, dtype=np.float32, order="K") / DEPTH_QMAX
|
||||
|
||||
depth_min_m = np.float32(depth_min)
|
||||
depth_max_m = np.float32(depth_max)
|
||||
shift_m = np.float32(shift)
|
||||
|
||||
# The de-normalization and de-quantization is performed in meters (convenience choice).
|
||||
if use_log:
|
||||
_validate_log_quant_params(depth_min, shift)
|
||||
log_min = math.log(float(depth_min_mm + shift_mm))
|
||||
log_max = math.log(float(depth_max_mm + shift_mm))
|
||||
depth_mm = np.exp(norm * (log_max - log_min) + log_min) - shift_mm
|
||||
log_min = math.log(float(depth_min_m + shift_m))
|
||||
log_max = math.log(float(depth_max_m + shift_m))
|
||||
depth_m = np.exp(norm * (log_max - log_min) + log_min) - shift_m
|
||||
else:
|
||||
depth_mm = norm * (depth_max_mm - depth_min_mm) + depth_min_mm
|
||||
depth_m = norm * (depth_max_m - depth_min_m) + depth_min_m
|
||||
depth_m = np.clip(depth_m, depth_min_m, depth_max_m).astype(np.float32, copy=False)
|
||||
|
||||
depth_mm = np.clip(depth_mm, depth_min_mm, depth_max_mm).astype(np.float32, copy=False)
|
||||
# Return depth as float32 meters.
|
||||
if output_unit == "m":
|
||||
return (depth_mm / np.float32(_MM_PER_METRE)).astype(np.float32, copy=False)
|
||||
mm = np.rint(depth_mm).clip(0, _UINT16_MAX)
|
||||
return mm.astype(np.uint16, copy=False)
|
||||
return torch.from_numpy(depth_m) if output_tensor else depth_m
|
||||
|
||||
# Return depth as uint16 millimeters.
|
||||
mm = np.rint(depth_m * _MM_PER_METRE).clip(0, _UINT16_MAX).astype(np.uint16, copy=False)
|
||||
if output_tensor:
|
||||
# torch.uint16 support is very limited, we convert to float32 instead.
|
||||
return torch.from_numpy(mm.astype(np.float32))
|
||||
else:
|
||||
return mm
|
||||
|
||||
@@ -24,6 +24,7 @@ import logging
|
||||
from typing import Any
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,6 +32,16 @@ FFMPEG_NUMERIC_OPTION_TYPES = ("INT", "INT64", "UINT64", "FLOAT", "DOUBLE")
|
||||
FFMPEG_INTEGER_OPTION_TYPES = ("INT", "INT64", "UINT64")
|
||||
|
||||
|
||||
def write_u16_plane(plane: av.video.plane.VideoPlane, src: np.ndarray, fill_value: int | None = None) -> None:
|
||||
"""Copy ``src`` into a uint16 plane respecting FFmpeg line padding."""
|
||||
height, width = src.shape
|
||||
stride_u16 = plane.line_size // np.dtype(np.uint16).itemsize
|
||||
dst = np.frombuffer(plane, dtype=np.uint16).reshape(height, stride_u16)
|
||||
if fill_value is not None:
|
||||
dst.fill(fill_value)
|
||||
dst[:, :width] = src
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_codec(vcodec: str) -> av.codec.Codec | None:
|
||||
"""PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""
|
||||
|
||||
Reference in New Issue
Block a user