feat(batched dequantization): optimizing dequantize_depth for torch based batched dequantization

This commit is contained in:
CarolinePascal
2026-05-26 16:58:44 +02:00
parent 263108d6c1
commit 92497dfcd8
2 changed files with 77 additions and 35 deletions
-1
View File
@@ -266,7 +266,6 @@ class DatasetReader:
depth_max=depth_encoder.depth_max,
shift=depth_encoder.shift,
use_log=depth_encoder.use_log,
output_tensor=True,
)
return vid_key, frames.squeeze(0)
+77 -34
View File
@@ -143,72 +143,115 @@ def quantize_depth(
def dequantize_depth(
quantized: NDArray[np.uint16] | av.VideoFrame,
quantized: NDArray[np.uint16] | av.VideoFrame | torch.Tensor,
depth_min: float = DEFAULT_DEPTH_MIN,
depth_max: float = DEFAULT_DEPTH_MAX,
shift: float = DEFAULT_DEPTH_SHIFT,
use_log: bool = DEFAULT_DEPTH_USE_LOG,
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
output_unit: Literal["m", "mm"] = "mm",
output_tensor: bool = False,
output_tensor: bool = True,
output_channel_last: bool = False,
) -> NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor:
"""Inverse of :func:`quantize_depth`.
Tuning arguments **must match** :func:`quantize_depth`.
Decoding inverts the same normalized code mapping as :func:`quantize_depth`
using ``depth_min`` / ``depth_max`` / ``shift`` (in metres), then returns
the requested output unit.
the requested output unit. Tuning arguments **must match** :func:`quantize_depth`.
Accepted input layouts :
- ``(H, W, 1)`` or ``(H, W)`` — single frame with channel-last.
- ``(..., 1, H, W)`` — batched frames with channel-first.
- ``(..., H, W, 1)`` — batched frames with channel-last.
Output layout is determined by ``output_channel_last``.
Args:
quantized: 12-bit codes ``[0, DEPTH_QMAX]``, ``dtype=uint16``.
quantized: 12-bit codes in ``[0, DEPTH_QMAX]``. ``np.ndarray``,
``av.VideoFrame``, or ``torch.Tensor`` (any integer or float dtype).
depth_min, depth_max, shift, use_log: Same as :func:`quantize_depth` (metres).
output_unit: ``\"mm\"`` returns ``uint16`` millimetres (``rint``, clip
``[0, 65535]``). ``\"m\"`` returns ``float32`` metres in
pix_fmt: Pixel format used to extract the plane from an ``av.VideoFrame``.
output_unit: ``"mm"`` returns ``uint16`` millimetres (rint, clip
``[0, 65535]``) when returning a numpy array, or ``float32`` mm when
``output_tensor=True``. ``"m"`` returns ``float32`` metres in
``[depth_min, depth_max]``.
output_tensor: If True, return a torch.Tensor instead of a numpy array.
output_tensor: If True, return a ``torch.Tensor`` instead of a numpy array.
Returns:
Depth map in the requested unit and dtype.
Raises:
ValueError: If ``output_unit`` is not ``"m"`` or ``"mm"``.
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
ValueError: If ``output_unit`` is not ``\"m\"`` or ``\"mm\"``.
"""
if output_unit not in ("m", "mm"):
raise ValueError(f"output_unit must be 'm' or 'mm', got {output_unit!r}")
if use_log:
_validate_log_quant_params(depth_min, shift)
if isinstance(quantized, av.VideoFrame):
quantized = quantized.to_ndarray(format=pix_fmt)
norm = np.asarray(quantized, dtype=np.float32, order="K") / DEPTH_QMAX
depth_min_m = np.float32(depth_min)
depth_max_m = np.float32(depth_max)
shift_m = np.float32(shift)
# The de-normalization and de-quantization is performed in meters (convenience choice).
# Compute the scale and offset first.
depth_min_m = float(depth_min)
depth_max_m = float(depth_max)
shift_m = float(shift)
if use_log:
_validate_log_quant_params(depth_min, shift)
log_min = math.log(float(depth_min_m + shift_m))
log_max = math.log(float(depth_max_m + shift_m))
depth_m = np.exp(norm * (log_max - log_min) + log_min) - shift_m
log_min = math.log(depth_min_m + shift_m)
log_max = math.log(depth_max_m + shift_m)
scale = (log_max - log_min) / DEPTH_QMAX
offset = log_min
else:
depth_m = norm * (depth_max_m - depth_min_m) + depth_min_m
depth_m = np.clip(depth_m, depth_min_m, depth_max_m).astype(np.float32, copy=False)
scale = (depth_max_m - depth_min_m) / DEPTH_QMAX
offset = depth_min_m
# Add single-channel dim: (H, W) → (H, W, 1)
if depth_m.ndim == 2:
depth_m = depth_m[..., np.newaxis]
# ── Torch path: stay on the input device, single fp32 allocation. ────────
if isinstance(quantized, torch.Tensor):
if quantized.ndim >= 3:
# Drop the single-channel dimension so the math runs on (..., H, W).
quantized = quantized.squeeze(-3) if quantized.shape[-3] == 1 else quantized.squeeze(-1)
# Single allocation we own; everything else is in-place.
buf = quantized.to(dtype=torch.float32, copy=True)
buf.mul_(scale).add_(offset)
if use_log:
buf.exp_().sub_(shift_m)
buf.clamp_(depth_min_m, depth_max_m)
buf.unsqueeze_(-1) if output_channel_last else buf.unsqueeze_(-3)
if output_unit == "m":
return buf if output_tensor else buf.cpu().numpy()
# mm path: round + clamp in float32, skipping the uint16 round-trip
# when returning a tensor (torch.uint16 is poorly supported).
buf.mul_(_MM_PER_METRE).round_().clamp_(0.0, _UINT16_MAX)
if output_tensor:
return buf
return buf.cpu().numpy().astype(np.uint16, copy=False)
# ── NumPy path: single fp32 allocation, ``out=`` for in-place math. ─────
arr = np.asarray(quantized)
if arr.ndim >= 3:
# Drop the single-channel dimension so the math runs on (..., H, W).
arr = np.squeeze(arr, axis=-3) if arr.shape[-3] == 1 else np.squeeze(arr, axis=-1)
buf = np.empty(arr.shape, dtype=np.float32)
np.multiply(arr, scale, out=buf)
np.add(buf, offset, out=buf)
if use_log:
np.exp(buf, out=buf)
np.subtract(buf, shift_m, out=buf)
np.clip(buf, depth_min_m, depth_max_m, out=buf)
buf = np.expand_dims(buf, axis=-1) if output_channel_last else np.expand_dims(buf, axis=-3)
# Return depth as float32 meters.
if output_unit == "m":
return torch.from_numpy(depth_m) if output_tensor else depth_m
return torch.from_numpy(buf) if output_tensor else buf
# Return depth as uint16 millimeters.
mm = np.rint(depth_m * _MM_PER_METRE).clip(0, _UINT16_MAX).astype(np.uint16, copy=False)
np.multiply(buf, _MM_PER_METRE, out=buf)
np.rint(buf, out=buf)
np.clip(buf, 0.0, _UINT16_MAX, out=buf)
if output_tensor:
# torch.uint16 support is very limited, we convert to float32 instead.
return torch.from_numpy(mm.astype(np.float32))
else:
return mm
# torch.uint16 support is very limited; return float32 millimetres.
return torch.from_numpy(buf)
return buf.astype(np.uint16, copy=False)