feat(refactor): refactor DepthEncoderConfig quantization pipeline, so that the methods do not live in the config class. Add pixel format - channels validation.Move the default pixel format for depth in the config file.

This commit is contained in:
CarolinePascal
2026-05-22 02:06:37 +02:00
parent 7498f1cf61
commit 8e56797287
6 changed files with 57 additions and 50 deletions
+34 -30
View File
@@ -19,11 +19,8 @@
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import Any
import numpy as np
import torch
from dataclasses import dataclass, field, fields
from typing import Any, ClassVar
from lerobot.utils.import_utils import require_package
@@ -45,7 +42,6 @@ VALID_VIDEO_CODECS: frozenset[str] = frozenset(
# Aliases for legacy video codec names.
VIDEO_CODECS_ALIASES: dict[str, str] = {"av1": "libsvtav1"}
LIBSVTAV1_DEFAULT_PRESET: int = 12
# Keys persisted under ``features[*]["info"]`` as ``video.<name>`` (from :class:`VideoEncoderConfig`).
@@ -57,6 +53,7 @@ VIDEO_ENCODER_INFO_KEYS: frozenset[str] = frozenset(
f"video.{name}" for name in VIDEO_ENCODER_INFO_FIELD_NAMES
)
# Default depth quantization and encoding parameters.
DEPTH_QUANT_BITS: int = 12
DEPTH_QMAX: int = (1 << DEPTH_QUANT_BITS) - 1 # 4095
@@ -64,6 +61,10 @@ DEFAULT_DEPTH_MIN: float = 0.01
DEFAULT_DEPTH_MAX: float = 10.0
DEFAULT_DEPTH_SHIFT: float = 3.5
DEFAULT_DEPTH_USE_LOG: bool = True
DEFAULT_DEPTH_PIX_FMT: str = "gray12le"
# Depth-specific tuning fields persisted under ``features[*]["info"]`` as ``video.<name>``.
DEPTH_ENCODER_INFO_FIELD_NAMES: frozenset[str] = frozenset({"depth_min", "depth_max", "shift", "use_log"})
@dataclass
@@ -99,6 +100,10 @@ class VideoEncoderConfig:
video_backend: str = "pyav"
extra_options: dict[str, Any] = field(default_factory=dict)
# Source-data channel count this encoder is expected to handle (3 for RGB,
# 1 for depth, etc.)
_DEFAULT_CHANNELS: ClassVar[int] = 3
def __post_init__(self) -> None:
self.resolve_vcodec()
# Empty-constructor ergonomics: ``VideoEncoderConfig()`` must "just work".
@@ -151,7 +156,9 @@ class VideoEncoderConfig:
require_package("av", extra="dataset")
from lerobot.datasets import check_video_encoder_parameters_pyav
check_video_encoder_parameters_pyav(self.vcodec, self.pix_fmt, self.get_codec_options())
check_video_encoder_parameters_pyav(
self.vcodec, self.pix_fmt, self.get_codec_options(), channels=self._DEFAULT_CHANNELS
)
def resolve_vcodec(self) -> None:
"""Check ``vcodec`` and, when it is ``"auto"``, pick a concrete encoder.
@@ -258,19 +265,11 @@ class DepthEncoderConfig(VideoEncoderConfig):
Inherits the full :class:`VideoEncoderConfig` surface (codec, GOP, CRF,
preset, ``extra_options``…) and adds the four parameters of the depth
quantization pipeline (:func:`quantize_depth`). Inheritance — rather
than composition — keeps the CLI flat: ``--dataset.depth_encoder_config.<field>``
works identically to its RGB counterpart.
quantizer.
Defaults flip ``vcodec`` to ``"hevc"`` (Main 12 profile) and ``pix_fmt``
to ``"yuv420p12le"``, the most widely available 12-bit pixel format.
For archive-grade lossless storage use ``vcodec="ffv1"`` together with
``pix_fmt="gray12le"`` (and clear ``crf``/``preset`` to ``None`` since
``ffv1`` doesn't expose those tuning knobs).
to ``"gray12le"``.
The :attr:`is_depth_map` marker is class-fixed to ``True`` (``init=False``,
so it's hidden from CLI and constructor args) and is what the reader
side keys on to tell depth datasets from RGB ones.
Attributes:
depth_min: Minimum depth in physical units (e.g. metres) represented
@@ -282,28 +281,33 @@ class DepthEncoderConfig(VideoEncoderConfig):
"""
vcodec: str = "hevc"
pix_fmt: str = "yuv420p12le"
pix_fmt: str = "gray12le"
depth_min: float = DEFAULT_DEPTH_MIN
depth_max: float = DEFAULT_DEPTH_MAX
shift: float = DEFAULT_DEPTH_SHIFT
use_log: bool = DEFAULT_DEPTH_USE_LOG
# Class invariant — kept out of ``__init__`` (and CLI) but persisted
# via ``asdict`` into ``info.json`` for the reader to detect depth.
is_depth_map: bool = field(default=True, init=False)
_DEFAULT_CHANNELS: ClassVar[int] = 1
def quantize(self, depth: torch.Tensor | np.ndarray) -> torch.Tensor:
"""Apply :func:`quantize_depth` bound to this config's parameters."""
from lerobot.datasets.depth_utils import quantize_depth
@classmethod
def from_video_info(cls, video_info: dict | None) -> DepthEncoderConfig:
"""Reconstruct a :class:`DepthEncoderConfig` from a depth feature's ``info`` block.
return quantize_depth(depth, self.depth_min, self.depth_max, self.shift, self.use_log)
Reuses :meth:`VideoEncoderConfig.from_video_info` for the base
codec/tuning fields and then layers the depth-specific tuning
(``depth_min`` / ``depth_max`` / ``shift`` / ``use_log``) on top.
Missing keys fall back to the class defaults.
"""
base = VideoEncoderConfig.from_video_info(video_info)
kwargs: dict[str, Any] = {f.name: getattr(base, f.name) for f in fields(base) if f.init}
def dequantize(self, quantized: torch.Tensor | np.ndarray) -> torch.Tensor:
"""Apply :func:`dequantize_depth` bound to this config's parameters."""
from lerobot.datasets.depth_utils import dequantize_depth
return dequantize_depth(quantized, self.depth_min, self.depth_max, self.shift, self.use_log)
video_info = video_info or {}
for name in DEPTH_ENCODER_INFO_FIELD_NAMES:
value = video_info.get(f"video.{name}")
if value is not None:
kwargs[name] = value
return cls(**kwargs)
def depth_encoder_defaults() -> DepthEncoderConfig:
+14 -12
View File
@@ -22,6 +22,8 @@ from pathlib import Path
import datasets
import torch
from lerobot.configs.video import DepthEncoderConfig
from .dataset_metadata import LeRobotDatasetMetadata
from .depth_utils import dequantize_depth
from .feature_utils import (
@@ -87,17 +89,11 @@ class DatasetReader:
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
if self._meta.depth_keys:
# TODO(CarolinePascal): make this decent, this is awful.
self._dequantize_depth_configs = {
vid_key: {
"depth_min": self._meta.features[vid_key]["info"]["video.depth_min"],
"depth_max": self._meta.features[vid_key]["info"]["video.depth_max"],
"shift": self._meta.features[vid_key]["info"]["video.shift"],
"use_log": self._meta.features[vid_key]["info"]["video.use_log"],
}
for vid_key in self._meta.depth_keys
}
##TODO(CarolinePascal): Should we rather use a more lightweight structure ?
self._depth_encoder_configs: dict[str, DepthEncoderConfig] = {
vid_key: DepthEncoderConfig.from_video_info(self._meta.features[vid_key].get("info"))
for vid_key in self._meta.depth_keys
}
def try_load(self) -> bool:
"""Attempt to load from local cache. Returns True if data is sufficient."""
@@ -263,8 +259,14 @@ class DatasetReader:
is_depth=vid_key in self._meta.depth_keys,
)
if vid_key in self._meta.depth_keys:
depth_encoder = self._depth_encoder_configs[vid_key]
frames = dequantize_depth(
frames, **self._dequantize_depth_configs[vid_key], output_tensor=True
frames,
depth_min=depth_encoder.depth_min,
depth_max=depth_encoder.depth_max,
shift=depth_encoder.shift,
use_log=depth_encoder.use_log,
output_tensor=True,
)
return vid_key, frames.squeeze(0)
+5 -5
View File
@@ -28,6 +28,7 @@ from numpy.typing import NDArray
from lerobot.configs.video import (
DEFAULT_DEPTH_MAX,
DEFAULT_DEPTH_MIN,
DEFAULT_DEPTH_PIX_FMT,
DEFAULT_DEPTH_SHIFT,
DEFAULT_DEPTH_USE_LOG,
DEPTH_QMAX,
@@ -38,9 +39,6 @@ from .pyav_utils import write_u16_plane
_MM_PER_METRE = 1000.0
_UINT16_MAX = 65535
# Pixel format supported by the depth encode/decode helpers.
DEPTH_PIX_FMT: str = "gray12le"
def _validate_log_quant_params(depth_min: float, shift: float) -> None:
"""Ensure ``log(depth_min + shift)`` is finite."""
@@ -68,6 +66,7 @@ def quantize_depth(
depth_max: float = DEFAULT_DEPTH_MAX,
shift: float = DEFAULT_DEPTH_SHIFT,
use_log: bool = DEFAULT_DEPTH_USE_LOG,
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
video_backend: str | None = "pyav",
input_unit: Literal["auto", "m", "mm"] = "auto",
) -> NDArray[np.uint16] | av.VideoFrame:
@@ -136,7 +135,7 @@ def quantize_depth(
quantized = np.rint(norm * DEPTH_QMAX).clip(0, DEPTH_QMAX).astype(np.uint16, copy=False)
if video_backend == "pyav":
frame = av.VideoFrame.from_ndarray(quantized, format=DEPTH_PIX_FMT)
frame = av.VideoFrame.from_ndarray(quantized, format=pix_fmt)
write_u16_plane(frame.planes[0], quantized)
return frame
else:
@@ -149,6 +148,7 @@ def dequantize_depth(
depth_max: float = DEFAULT_DEPTH_MAX,
shift: float = DEFAULT_DEPTH_SHIFT,
use_log: bool = DEFAULT_DEPTH_USE_LOG,
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
output_unit: Literal["m", "mm"] = "mm",
output_tensor: bool = False,
) -> NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor:
@@ -179,7 +179,7 @@ def dequantize_depth(
raise ValueError(f"output_unit must be 'm' or 'mm', got {output_unit!r}")
if isinstance(quantized, av.VideoFrame):
quantized = quantized.to_ndarray(format=DEPTH_PIX_FMT)
quantized = quantized.to_ndarray(format=pix_fmt)
norm = np.asarray(quantized, dtype=np.float32, order="K") / DEPTH_QMAX
+2 -2
View File
@@ -44,7 +44,7 @@ from lerobot.configs import (
)
from lerobot.utils.import_utils import get_safe_default_video_backend
from .depth_utils import DEPTH_PIX_FMT, quantize_depth
from .depth_utils import quantize_depth
from .pyav_utils import get_pix_fmt_channels
logger = logging.getLogger(__name__)
@@ -151,7 +151,7 @@ def decode_video_frames_pyav(
if log_loaded_timestamps:
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
if is_depth:
arr = frame.to_ndarray(format=DEPTH_PIX_FMT) # (H, W) uint16
arr = frame.to_ndarray(format="gray12le") # (H, W) uint12
loaded_frames.append(torch.from_numpy(arr).unsqueeze(0).contiguous())
else:
arr = frame.to_ndarray(format="rgb24") # (H, W, 3)
+1
View File
@@ -53,6 +53,7 @@ IMAGE_FEATURES = {
},
}
def _make_dummy_stats(features: dict) -> dict:
"""Create minimal episode stats matching the given features."""
stats = {}
+1 -1
View File
@@ -26,7 +26,7 @@ pytest.importorskip("av", reason="av is required (install lerobot[dataset])")
import av # noqa: E402
from lerobot.configs import VALID_VIDEO_CODECS, VideoEncoderConfig, DepthEncoderConfig
from lerobot.configs import VALID_VIDEO_CODECS, DepthEncoderConfig, VideoEncoderConfig
from lerobot.datasets.image_writer import write_image
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pyav_utils import get_codec