mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-25 20:27:05 +00:00
feat(units): adding constants for depth frames units (m and mm)
This commit is contained in:
@@ -19,7 +19,7 @@ from dataclasses import dataclass, field
|
||||
from lerobot.transforms import ImageTransformsConfig
|
||||
from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||
|
||||
from .video import DEFAULT_DEPTH_UNIT
|
||||
from .video import DEFAULT_DEPTH_UNIT, DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -46,8 +46,10 @@ class DatasetConfig:
|
||||
streaming: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.depth_output_unit not in ("m", "mm"):
|
||||
raise ValueError(f"depth_output_unit must be 'm' or 'mm', got {self.depth_output_unit!r}")
|
||||
if self.depth_output_unit not in (DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT):
|
||||
raise ValueError(
|
||||
f"depth_output_unit must be '{DEPTH_METER_UNIT}' or '{DEPTH_MILLIMETER_UNIT}', got {self.depth_output_unit!r}"
|
||||
)
|
||||
if self.episodes is not None:
|
||||
if any(ep < 0 for ep in self.episodes):
|
||||
raise ValueError(
|
||||
|
||||
@@ -62,7 +62,10 @@ DEFAULT_DEPTH_MAX: float = 10.0
|
||||
DEFAULT_DEPTH_SHIFT: float = 3.5
|
||||
DEFAULT_DEPTH_USE_LOG: bool = True
|
||||
DEFAULT_DEPTH_PIX_FMT: str = "gray12le"
|
||||
DEFAULT_DEPTH_UNIT = "mm"
|
||||
|
||||
DEPTH_METER_UNIT: str = "m"
|
||||
DEPTH_MILLIMETER_UNIT: str = "mm"
|
||||
DEFAULT_DEPTH_UNIT: str = DEPTH_MILLIMETER_UNIT
|
||||
|
||||
# Depth-specific tuning fields persisted under ``features[*]["info"]`` as ``video.<name>``.
|
||||
DEPTH_ENCODER_INFO_FIELD_NAMES: frozenset[str] = frozenset({"depth_min", "depth_max", "shift", "use_log"})
|
||||
|
||||
@@ -31,6 +31,8 @@ from lerobot.configs.video import (
|
||||
DEFAULT_DEPTH_PIX_FMT,
|
||||
DEFAULT_DEPTH_SHIFT,
|
||||
DEFAULT_DEPTH_USE_LOG,
|
||||
DEPTH_METER_UNIT,
|
||||
DEPTH_MILLIMETER_UNIT,
|
||||
DEPTH_QMAX,
|
||||
)
|
||||
|
||||
@@ -51,11 +53,13 @@ def _validate_log_quant_params(depth_min: float, shift: float) -> None:
|
||||
|
||||
def _depth_input_to_float32_and_unit(
|
||||
depth: NDArray[np.integer] | NDArray[np.floating],
|
||||
input_unit: Literal["auto", "m", "mm"],
|
||||
) -> tuple[NDArray[np.float32], Literal["m", "mm"]]:
|
||||
input_unit: Literal["auto", DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT],
|
||||
) -> tuple[NDArray[np.float32], Literal[DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT]]:
|
||||
"""Convert depth to float32 in the chosen unit, and return the resolved unit."""
|
||||
resolved_unit = (
|
||||
("m" if np.issubdtype(depth.dtype, np.floating) else "mm") if input_unit == "auto" else input_unit
|
||||
(DEPTH_METER_UNIT if np.issubdtype(depth.dtype, np.floating) else DEPTH_MILLIMETER_UNIT)
|
||||
if input_unit == "auto"
|
||||
else input_unit
|
||||
)
|
||||
return depth.astype(np.float32, order="K"), resolved_unit
|
||||
|
||||
@@ -68,7 +72,7 @@ def quantize_depth(
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
|
||||
video_backend: str | None = "pyav",
|
||||
input_unit: Literal["auto", "m", "mm"] = "auto",
|
||||
input_unit: Literal["auto", DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT] = "auto",
|
||||
) -> NDArray[np.uint16] | av.VideoFrame:
|
||||
"""Quantize depth to 12-bit codes (``uint16``, values ``0…DEPTH_QMAX``).
|
||||
|
||||
@@ -106,8 +110,10 @@ def quantize_depth(
|
||||
ValueError: If ``input_unit`` is not ``"auto"``, ``"mm"``, or ``"m"``.
|
||||
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
|
||||
"""
|
||||
if input_unit not in ("auto", "m", "mm"):
|
||||
raise ValueError(f"input_unit must be 'auto', 'm', or 'mm', got {input_unit!r}")
|
||||
if input_unit not in ("auto", DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT):
|
||||
raise ValueError(
|
||||
f"input_unit must be 'auto', '{DEPTH_METER_UNIT}', or '{DEPTH_MILLIMETER_UNIT}', got {input_unit!r}"
|
||||
)
|
||||
|
||||
if isinstance(depth, torch.Tensor):
|
||||
depth = depth.detach().cpu().numpy()
|
||||
@@ -119,9 +125,13 @@ def quantize_depth(
|
||||
depth_f, resolved_unit = _depth_input_to_float32_and_unit(depth, input_unit=input_unit)
|
||||
|
||||
# Convert depth_min, depth_max, and shift to the resolved input unit.
|
||||
depth_min_u = np.float32(depth_min) if resolved_unit == "m" else np.float32(depth_min * _MM_PER_METRE)
|
||||
depth_max_u = np.float32(depth_max) if resolved_unit == "m" else np.float32(depth_max * _MM_PER_METRE)
|
||||
shift_u = np.float32(shift) if resolved_unit == "m" else np.float32(shift * _MM_PER_METRE)
|
||||
depth_min_u = (
|
||||
np.float32(depth_min) if resolved_unit == DEPTH_METER_UNIT else np.float32(depth_min * _MM_PER_METRE)
|
||||
)
|
||||
depth_max_u = (
|
||||
np.float32(depth_max) if resolved_unit == DEPTH_METER_UNIT else np.float32(depth_max * _MM_PER_METRE)
|
||||
)
|
||||
shift_u = np.float32(shift) if resolved_unit == DEPTH_METER_UNIT else np.float32(shift * _MM_PER_METRE)
|
||||
|
||||
# Normalization and quantization is performed in the resolved input unit.
|
||||
if use_log:
|
||||
@@ -149,7 +159,7 @@ def dequantize_depth(
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
|
||||
output_unit: Literal["m", "mm"] = "mm",
|
||||
output_unit: Literal[DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT] = DEPTH_MILLIMETER_UNIT,
|
||||
output_tensor: bool = True,
|
||||
output_channel_last: bool = False,
|
||||
) -> NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor:
|
||||
@@ -184,8 +194,10 @@ def dequantize_depth(
|
||||
ValueError: If ``output_unit`` is not ``"m"`` or ``"mm"``.
|
||||
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
|
||||
"""
|
||||
if output_unit not in ("m", "mm"):
|
||||
raise ValueError(f"output_unit must be 'm' or 'mm', got {output_unit!r}")
|
||||
if output_unit not in (DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT):
|
||||
raise ValueError(
|
||||
f"output_unit must be '{DEPTH_METER_UNIT}' or '{DEPTH_MILLIMETER_UNIT}', got {output_unit!r}"
|
||||
)
|
||||
if use_log:
|
||||
_validate_log_quant_params(depth_min, shift)
|
||||
|
||||
@@ -219,7 +231,7 @@ def dequantize_depth(
|
||||
buf.clamp_(depth_min_m, depth_max_m)
|
||||
buf.unsqueeze_(-1) if output_channel_last else buf.unsqueeze_(-3)
|
||||
|
||||
if output_unit == "m":
|
||||
if output_unit == DEPTH_METER_UNIT:
|
||||
return buf if output_tensor else buf.cpu().numpy()
|
||||
|
||||
# mm path: round + clamp in float32, skipping the uint16 round-trip
|
||||
@@ -244,7 +256,7 @@ def dequantize_depth(
|
||||
np.clip(buf, depth_min_m, depth_max_m, out=buf)
|
||||
buf = np.expand_dims(buf, axis=-1) if output_channel_last else np.expand_dims(buf, axis=-3)
|
||||
|
||||
if output_unit == "m":
|
||||
if output_unit == DEPTH_METER_UNIT:
|
||||
return torch.from_numpy(buf) if output_tensor else buf
|
||||
|
||||
np.multiply(buf, _MM_PER_METRE, out=buf)
|
||||
|
||||
@@ -23,7 +23,13 @@ import PIL.Image
|
||||
import torch
|
||||
|
||||
from lerobot.configs import DepthEncoderConfig
|
||||
from lerobot.configs.video import DEFAULT_DEPTH_MAX, DEFAULT_DEPTH_MIN, DEPTH_QMAX
|
||||
from lerobot.configs.video import (
|
||||
DEFAULT_DEPTH_MAX,
|
||||
DEFAULT_DEPTH_MIN,
|
||||
DEPTH_METER_UNIT,
|
||||
DEPTH_MILLIMETER_UNIT,
|
||||
DEPTH_QMAX,
|
||||
)
|
||||
from lerobot.datasets.depth_utils import dequantize_depth, quantize_depth
|
||||
from lerobot.datasets.image_writer import image_array_to_pil_image, write_image
|
||||
from tests.fixtures.constants import (
|
||||
@@ -51,7 +57,7 @@ class TestQuantizeDequantize:
|
||||
"""Numerical contract of ``quantize_depth`` / ``dequantize_depth``."""
|
||||
|
||||
@pytest.mark.parametrize("use_log", [False, True])
|
||||
@pytest.mark.parametrize("output_unit", ["m", "mm"])
|
||||
@pytest.mark.parametrize("output_unit", [DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT])
|
||||
@pytest.mark.parametrize("output_channel_last", [False, True])
|
||||
def test_roundtrip(self, use_log, output_unit, output_channel_last):
|
||||
"""quantize → dequantize recovers depth; layout and unit are honored."""
|
||||
@@ -69,7 +75,7 @@ class TestQuantizeDequantize:
|
||||
assert recovered.shape == expected_shape
|
||||
|
||||
recovered_m = recovered.astype(np.float32)
|
||||
if output_unit == "mm":
|
||||
if output_unit == DEPTH_MILLIMETER_UNIT:
|
||||
recovered_m = recovered_m / 1000.0
|
||||
recovered_2d = recovered_m[..., 0] if output_channel_last else recovered_m[0]
|
||||
|
||||
@@ -86,7 +92,7 @@ class TestQuantizeDequantize:
|
||||
np.testing.assert_allclose(recovered_2d, depth, atol=tol)
|
||||
|
||||
@pytest.mark.parametrize("use_log", [False, True])
|
||||
@pytest.mark.parametrize("output_unit", ["m", "mm"])
|
||||
@pytest.mark.parametrize("output_unit", [DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT])
|
||||
def test_numpy_torch_agree(self, use_log, output_unit):
|
||||
"""Batched torch path produces the same values as the numpy path."""
|
||||
batch_size = 3
|
||||
@@ -101,7 +107,7 @@ class TestQuantizeDequantize:
|
||||
assert out.shape == (batch_size, 1, H, W)
|
||||
# ``m``: float32 noise (~10 µm in log mode, after ``exp``) — still 200× below the ~2 mm quant step.
|
||||
# ``mm`` + tensor stays in float32 (no uint16 round-trip), so allow 1 mm slop.
|
||||
atol = 1e-5 if output_unit == "m" else 1.0
|
||||
atol = 1e-5 if output_unit == DEPTH_METER_UNIT else 1.0
|
||||
np.testing.assert_allclose(out.cpu().numpy().astype(np.float64), ref.astype(np.float64), atol=atol)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -117,7 +123,7 @@ class TestQuantizeDequantize:
|
||||
def test_input_layouts_accepted(self, input_shape, output_shape):
|
||||
"""All documented input layouts decode to the channel-first default."""
|
||||
quantized = np.full(input_shape, DEPTH_QMAX // 2, dtype=np.uint16)
|
||||
out = dequantize_depth(quantized, output_unit="m", output_tensor=False)
|
||||
out = dequantize_depth(quantized, output_unit=DEPTH_METER_UNIT, output_tensor=False)
|
||||
assert out.shape == output_shape
|
||||
|
||||
def test_pyav_frame_roundtrip(self):
|
||||
@@ -126,7 +132,7 @@ class TestQuantizeDequantize:
|
||||
frame = quantize_depth(depth, use_log=False, video_backend="pyav")
|
||||
assert isinstance(frame, av.VideoFrame)
|
||||
|
||||
recovered = dequantize_depth(frame, use_log=False, output_unit="m", output_tensor=False)
|
||||
recovered = dequantize_depth(frame, use_log=False, output_unit=DEPTH_METER_UNIT, output_tensor=False)
|
||||
assert recovered.shape == (1, H, W)
|
||||
tol = (DEFAULT_DEPTH_MAX - DEFAULT_DEPTH_MIN) / DEPTH_QMAX + 1e-3
|
||||
np.testing.assert_allclose(recovered[0], depth, atol=tol)
|
||||
|
||||
Reference in New Issue
Block a user