mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 08:17:02 +00:00
feat(batched dequantization): optimizing dequantize_depth for torch based batched dequantization
This commit is contained in:
@@ -266,7 +266,6 @@ class DatasetReader:
|
||||
depth_max=depth_encoder.depth_max,
|
||||
shift=depth_encoder.shift,
|
||||
use_log=depth_encoder.use_log,
|
||||
output_tensor=True,
|
||||
)
|
||||
return vid_key, frames.squeeze(0)
|
||||
|
||||
|
||||
@@ -143,72 +143,115 @@ def quantize_depth(
|
||||
|
||||
|
||||
def dequantize_depth(
|
||||
quantized: NDArray[np.uint16] | av.VideoFrame,
|
||||
quantized: NDArray[np.uint16] | av.VideoFrame | torch.Tensor,
|
||||
depth_min: float = DEFAULT_DEPTH_MIN,
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
|
||||
output_unit: Literal["m", "mm"] = "mm",
|
||||
output_tensor: bool = False,
|
||||
output_tensor: bool = True,
|
||||
output_channel_last: bool = False,
|
||||
) -> NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor:
|
||||
"""Inverse of :func:`quantize_depth`.
|
||||
|
||||
Tuning arguments **must match** :func:`quantize_depth`.
|
||||
|
||||
Decoding inverts the same normalized code mapping as :func:`quantize_depth`
|
||||
using ``depth_min`` / ``depth_max`` / ``shift`` (in metres), then returns
|
||||
the requested output unit.
|
||||
the requested output unit. Tuning arguments **must match** :func:`quantize_depth`.
|
||||
|
||||
Accepted input layouts :
|
||||
|
||||
- ``(H, W, 1)`` or ``(H, W)`` — single frame with channel-last.
|
||||
- ``(..., 1, H, W)`` — batched frames with channel-first.
|
||||
- ``(..., H, W, 1)`` — batched frames with channel-last.
|
||||
Output layout is determined by ``output_channel_last``.
|
||||
|
||||
Args:
|
||||
quantized: 12-bit codes ``[0, DEPTH_QMAX]``, ``dtype=uint16``.
|
||||
quantized: 12-bit codes in ``[0, DEPTH_QMAX]``. ``np.ndarray``,
|
||||
``av.VideoFrame``, or ``torch.Tensor`` (any integer or float dtype).
|
||||
depth_min, depth_max, shift, use_log: Same as :func:`quantize_depth` (metres).
|
||||
output_unit: ``\"mm\"`` returns ``uint16`` millimetres (``rint``, clip
|
||||
``[0, 65535]``). ``\"m\"`` returns ``float32`` metres in
|
||||
pix_fmt: Pixel format used to extract the plane from an ``av.VideoFrame``.
|
||||
output_unit: ``"mm"`` returns ``uint16`` millimetres (rint, clip
|
||||
``[0, 65535]``) when returning a numpy array, or ``float32`` mm when
|
||||
``output_tensor=True``. ``"m"`` returns ``float32`` metres in
|
||||
``[depth_min, depth_max]``.
|
||||
output_tensor: If True, return a torch.Tensor instead of a numpy array.
|
||||
output_tensor: If True, return a ``torch.Tensor`` instead of a numpy array.
|
||||
|
||||
Returns:
|
||||
Depth map in the requested unit and dtype.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``output_unit`` is not ``"m"`` or ``"mm"``.
|
||||
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
|
||||
ValueError: If ``output_unit`` is not ``\"m\"`` or ``\"mm\"``.
|
||||
"""
|
||||
if output_unit not in ("m", "mm"):
|
||||
raise ValueError(f"output_unit must be 'm' or 'mm', got {output_unit!r}")
|
||||
if use_log:
|
||||
_validate_log_quant_params(depth_min, shift)
|
||||
|
||||
if isinstance(quantized, av.VideoFrame):
|
||||
quantized = quantized.to_ndarray(format=pix_fmt)
|
||||
|
||||
norm = np.asarray(quantized, dtype=np.float32, order="K") / DEPTH_QMAX
|
||||
|
||||
depth_min_m = np.float32(depth_min)
|
||||
depth_max_m = np.float32(depth_max)
|
||||
shift_m = np.float32(shift)
|
||||
|
||||
# The de-normalization and de-quantization is performed in meters (convenience choice).
|
||||
# Compute the scale and offset first.
|
||||
depth_min_m = float(depth_min)
|
||||
depth_max_m = float(depth_max)
|
||||
shift_m = float(shift)
|
||||
if use_log:
|
||||
_validate_log_quant_params(depth_min, shift)
|
||||
log_min = math.log(float(depth_min_m + shift_m))
|
||||
log_max = math.log(float(depth_max_m + shift_m))
|
||||
depth_m = np.exp(norm * (log_max - log_min) + log_min) - shift_m
|
||||
log_min = math.log(depth_min_m + shift_m)
|
||||
log_max = math.log(depth_max_m + shift_m)
|
||||
scale = (log_max - log_min) / DEPTH_QMAX
|
||||
offset = log_min
|
||||
else:
|
||||
depth_m = norm * (depth_max_m - depth_min_m) + depth_min_m
|
||||
depth_m = np.clip(depth_m, depth_min_m, depth_max_m).astype(np.float32, copy=False)
|
||||
scale = (depth_max_m - depth_min_m) / DEPTH_QMAX
|
||||
offset = depth_min_m
|
||||
|
||||
# Add single-channel dim: (H, W) → (H, W, 1)
|
||||
if depth_m.ndim == 2:
|
||||
depth_m = depth_m[..., np.newaxis]
|
||||
# ── Torch path: stay on the input device, single fp32 allocation. ────────
|
||||
if isinstance(quantized, torch.Tensor):
|
||||
|
||||
if quantized.ndim >= 3:
|
||||
# Drop the single-channel dimension so the math runs on (..., H, W).
|
||||
quantized = quantized.squeeze(-3) if quantized.shape[-3] == 1 else quantized.squeeze(-1)
|
||||
|
||||
# Single allocation we own; everything else is in-place.
|
||||
buf = quantized.to(dtype=torch.float32, copy=True)
|
||||
buf.mul_(scale).add_(offset)
|
||||
if use_log:
|
||||
buf.exp_().sub_(shift_m)
|
||||
buf.clamp_(depth_min_m, depth_max_m)
|
||||
buf.unsqueeze_(-1) if output_channel_last else buf.unsqueeze_(-3)
|
||||
|
||||
if output_unit == "m":
|
||||
return buf if output_tensor else buf.cpu().numpy()
|
||||
|
||||
# mm path: round + clamp in float32, skipping the uint16 round-trip
|
||||
# when returning a tensor (torch.uint16 is poorly supported).
|
||||
buf.mul_(_MM_PER_METRE).round_().clamp_(0.0, _UINT16_MAX)
|
||||
if output_tensor:
|
||||
return buf
|
||||
return buf.cpu().numpy().astype(np.uint16, copy=False)
|
||||
|
||||
# ── NumPy path: single fp32 allocation, ``out=`` for in-place math. ─────
|
||||
arr = np.asarray(quantized)
|
||||
if arr.ndim >= 3:
|
||||
# Drop the single-channel dimension so the math runs on (..., H, W).
|
||||
arr = np.squeeze(arr, axis=-3) if arr.shape[-3] == 1 else np.squeeze(arr, axis=-1)
|
||||
|
||||
buf = np.empty(arr.shape, dtype=np.float32)
|
||||
np.multiply(arr, scale, out=buf)
|
||||
np.add(buf, offset, out=buf)
|
||||
if use_log:
|
||||
np.exp(buf, out=buf)
|
||||
np.subtract(buf, shift_m, out=buf)
|
||||
np.clip(buf, depth_min_m, depth_max_m, out=buf)
|
||||
buf = np.expand_dims(buf, axis=-1) if output_channel_last else np.expand_dims(buf, axis=-3)
|
||||
|
||||
# Return depth as float32 meters.
|
||||
if output_unit == "m":
|
||||
return torch.from_numpy(depth_m) if output_tensor else depth_m
|
||||
return torch.from_numpy(buf) if output_tensor else buf
|
||||
|
||||
# Return depth as uint16 millimeters.
|
||||
mm = np.rint(depth_m * _MM_PER_METRE).clip(0, _UINT16_MAX).astype(np.uint16, copy=False)
|
||||
np.multiply(buf, _MM_PER_METRE, out=buf)
|
||||
np.rint(buf, out=buf)
|
||||
np.clip(buf, 0.0, _UINT16_MAX, out=buf)
|
||||
if output_tensor:
|
||||
# torch.uint16 support is very limited, we convert to float32 instead.
|
||||
return torch.from_numpy(mm.astype(np.float32))
|
||||
else:
|
||||
return mm
|
||||
# torch.uint16 support is very limited; return float32 millimetres.
|
||||
return torch.from_numpy(buf)
|
||||
return buf.astype(np.uint16, copy=False)
|
||||
|
||||
Reference in New Issue
Block a user