feat(refactor): refactor DepthEncoderConfig quantization pipeline, so that the methods do not live in the config class. Add pixel format - channels validation.Move the default pixel format for depth in the config file.

This commit is contained in:
CarolinePascal
2026-05-22 02:06:37 +02:00
parent 7498f1cf61
commit 8e56797287
6 changed files with 57 additions and 50 deletions
+34 -30
View File
@@ -19,11 +19,8 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field, fields
from typing import Any from typing import Any, ClassVar
import numpy as np
import torch
from lerobot.utils.import_utils import require_package from lerobot.utils.import_utils import require_package
@@ -45,7 +42,6 @@ VALID_VIDEO_CODECS: frozenset[str] = frozenset(
# Aliases for legacy video codec names. # Aliases for legacy video codec names.
VIDEO_CODECS_ALIASES: dict[str, str] = {"av1": "libsvtav1"} VIDEO_CODECS_ALIASES: dict[str, str] = {"av1": "libsvtav1"}
LIBSVTAV1_DEFAULT_PRESET: int = 12 LIBSVTAV1_DEFAULT_PRESET: int = 12
# Keys persisted under ``features[*]["info"]`` as ``video.<name>`` (from :class:`VideoEncoderConfig`). # Keys persisted under ``features[*]["info"]`` as ``video.<name>`` (from :class:`VideoEncoderConfig`).
@@ -57,6 +53,7 @@ VIDEO_ENCODER_INFO_KEYS: frozenset[str] = frozenset(
f"video.{name}" for name in VIDEO_ENCODER_INFO_FIELD_NAMES f"video.{name}" for name in VIDEO_ENCODER_INFO_FIELD_NAMES
) )
# Default depth quantization and encoding parameters.
DEPTH_QUANT_BITS: int = 12 DEPTH_QUANT_BITS: int = 12
DEPTH_QMAX: int = (1 << DEPTH_QUANT_BITS) - 1 # 4095 DEPTH_QMAX: int = (1 << DEPTH_QUANT_BITS) - 1 # 4095
@@ -64,6 +61,10 @@ DEFAULT_DEPTH_MIN: float = 0.01
DEFAULT_DEPTH_MAX: float = 10.0 DEFAULT_DEPTH_MAX: float = 10.0
DEFAULT_DEPTH_SHIFT: float = 3.5 DEFAULT_DEPTH_SHIFT: float = 3.5
DEFAULT_DEPTH_USE_LOG: bool = True DEFAULT_DEPTH_USE_LOG: bool = True
DEFAULT_DEPTH_PIX_FMT: str = "gray12le"
# Depth-specific tuning fields persisted under ``features[*]["info"]`` as ``video.<name>``.
DEPTH_ENCODER_INFO_FIELD_NAMES: frozenset[str] = frozenset({"depth_min", "depth_max", "shift", "use_log"})
@dataclass @dataclass
@@ -99,6 +100,10 @@ class VideoEncoderConfig:
video_backend: str = "pyav" video_backend: str = "pyav"
extra_options: dict[str, Any] = field(default_factory=dict) extra_options: dict[str, Any] = field(default_factory=dict)
# Source-data channel count this encoder is expected to handle (3 for RGB,
# 1 for depth, etc.)
_DEFAULT_CHANNELS: ClassVar[int] = 3
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.resolve_vcodec() self.resolve_vcodec()
# Empty-constructor ergonomics: ``VideoEncoderConfig()`` must "just work". # Empty-constructor ergonomics: ``VideoEncoderConfig()`` must "just work".
@@ -151,7 +156,9 @@ class VideoEncoderConfig:
require_package("av", extra="dataset") require_package("av", extra="dataset")
from lerobot.datasets import check_video_encoder_parameters_pyav from lerobot.datasets import check_video_encoder_parameters_pyav
check_video_encoder_parameters_pyav(self.vcodec, self.pix_fmt, self.get_codec_options()) check_video_encoder_parameters_pyav(
self.vcodec, self.pix_fmt, self.get_codec_options(), channels=self._DEFAULT_CHANNELS
)
def resolve_vcodec(self) -> None: def resolve_vcodec(self) -> None:
"""Check ``vcodec`` and, when it is ``"auto"``, pick a concrete encoder. """Check ``vcodec`` and, when it is ``"auto"``, pick a concrete encoder.
@@ -258,19 +265,11 @@ class DepthEncoderConfig(VideoEncoderConfig):
Inherits the full :class:`VideoEncoderConfig` surface (codec, GOP, CRF, Inherits the full :class:`VideoEncoderConfig` surface (codec, GOP, CRF,
preset, ``extra_options``…) and adds the four parameters of the depth preset, ``extra_options``…) and adds the four parameters of the depth
quantization pipeline (:func:`quantize_depth`). Inheritance — rather quantizer.
than composition — keeps the CLI flat: ``--dataset.depth_encoder_config.<field>``
works identically to its RGB counterpart.
Defaults flip ``vcodec`` to ``"hevc"`` (Main 12 profile) and ``pix_fmt`` Defaults flip ``vcodec`` to ``"hevc"`` (Main 12 profile) and ``pix_fmt``
to ``"yuv420p12le"``, the most widely available 12-bit pixel format. to ``"gray12le"``.
For archive-grade lossless storage use ``vcodec="ffv1"`` together with
``pix_fmt="gray12le"`` (and clear ``crf``/``preset`` to ``None`` since
``ffv1`` doesn't expose those tuning knobs).
The :attr:`is_depth_map` marker is class-fixed to ``True`` (``init=False``,
so it's hidden from CLI and constructor args) and is what the reader
side keys on to tell depth datasets from RGB ones.
Attributes: Attributes:
depth_min: Minimum depth in physical units (e.g. metres) represented depth_min: Minimum depth in physical units (e.g. metres) represented
@@ -282,28 +281,33 @@ class DepthEncoderConfig(VideoEncoderConfig):
""" """
vcodec: str = "hevc" vcodec: str = "hevc"
pix_fmt: str = "yuv420p12le" pix_fmt: str = "gray12le"
depth_min: float = DEFAULT_DEPTH_MIN depth_min: float = DEFAULT_DEPTH_MIN
depth_max: float = DEFAULT_DEPTH_MAX depth_max: float = DEFAULT_DEPTH_MAX
shift: float = DEFAULT_DEPTH_SHIFT shift: float = DEFAULT_DEPTH_SHIFT
use_log: bool = DEFAULT_DEPTH_USE_LOG use_log: bool = DEFAULT_DEPTH_USE_LOG
# Class invariant — kept out of ``__init__`` (and CLI) but persisted _DEFAULT_CHANNELS: ClassVar[int] = 1
# via ``asdict`` into ``info.json`` for the reader to detect depth.
is_depth_map: bool = field(default=True, init=False)
def quantize(self, depth: torch.Tensor | np.ndarray) -> torch.Tensor: @classmethod
"""Apply :func:`quantize_depth` bound to this config's parameters.""" def from_video_info(cls, video_info: dict | None) -> DepthEncoderConfig:
from lerobot.datasets.depth_utils import quantize_depth """Reconstruct a :class:`DepthEncoderConfig` from a depth feature's ``info`` block.
return quantize_depth(depth, self.depth_min, self.depth_max, self.shift, self.use_log) Reuses :meth:`VideoEncoderConfig.from_video_info` for the base
codec/tuning fields and then layers the depth-specific tuning
(``depth_min`` / ``depth_max`` / ``shift`` / ``use_log``) on top.
Missing keys fall back to the class defaults.
"""
base = VideoEncoderConfig.from_video_info(video_info)
kwargs: dict[str, Any] = {f.name: getattr(base, f.name) for f in fields(base) if f.init}
def dequantize(self, quantized: torch.Tensor | np.ndarray) -> torch.Tensor: video_info = video_info or {}
"""Apply :func:`dequantize_depth` bound to this config's parameters.""" for name in DEPTH_ENCODER_INFO_FIELD_NAMES:
from lerobot.datasets.depth_utils import dequantize_depth value = video_info.get(f"video.{name}")
if value is not None:
return dequantize_depth(quantized, self.depth_min, self.depth_max, self.shift, self.use_log) kwargs[name] = value
return cls(**kwargs)
def depth_encoder_defaults() -> DepthEncoderConfig: def depth_encoder_defaults() -> DepthEncoderConfig:
+14 -12
View File
@@ -22,6 +22,8 @@ from pathlib import Path
import datasets import datasets
import torch import torch
from lerobot.configs.video import DepthEncoderConfig
from .dataset_metadata import LeRobotDatasetMetadata from .dataset_metadata import LeRobotDatasetMetadata
from .depth_utils import dequantize_depth from .depth_utils import dequantize_depth
from .feature_utils import ( from .feature_utils import (
@@ -87,17 +89,11 @@ class DatasetReader:
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s) check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps) self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
if self._meta.depth_keys: ##TODO(CarolinePascal): Should we rather use a more lightweight structure ?
# TODO(CarolinePascal): make this decent, this is awful. self._depth_encoder_configs: dict[str, DepthEncoderConfig] = {
self._dequantize_depth_configs = { vid_key: DepthEncoderConfig.from_video_info(self._meta.features[vid_key].get("info"))
vid_key: { for vid_key in self._meta.depth_keys
"depth_min": self._meta.features[vid_key]["info"]["video.depth_min"], }
"depth_max": self._meta.features[vid_key]["info"]["video.depth_max"],
"shift": self._meta.features[vid_key]["info"]["video.shift"],
"use_log": self._meta.features[vid_key]["info"]["video.use_log"],
}
for vid_key in self._meta.depth_keys
}
def try_load(self) -> bool: def try_load(self) -> bool:
"""Attempt to load from local cache. Returns True if data is sufficient.""" """Attempt to load from local cache. Returns True if data is sufficient."""
@@ -263,8 +259,14 @@ class DatasetReader:
is_depth=vid_key in self._meta.depth_keys, is_depth=vid_key in self._meta.depth_keys,
) )
if vid_key in self._meta.depth_keys: if vid_key in self._meta.depth_keys:
depth_encoder = self._depth_encoder_configs[vid_key]
frames = dequantize_depth( frames = dequantize_depth(
frames, **self._dequantize_depth_configs[vid_key], output_tensor=True frames,
depth_min=depth_encoder.depth_min,
depth_max=depth_encoder.depth_max,
shift=depth_encoder.shift,
use_log=depth_encoder.use_log,
output_tensor=True,
) )
return vid_key, frames.squeeze(0) return vid_key, frames.squeeze(0)
+5 -5
View File
@@ -28,6 +28,7 @@ from numpy.typing import NDArray
from lerobot.configs.video import ( from lerobot.configs.video import (
DEFAULT_DEPTH_MAX, DEFAULT_DEPTH_MAX,
DEFAULT_DEPTH_MIN, DEFAULT_DEPTH_MIN,
DEFAULT_DEPTH_PIX_FMT,
DEFAULT_DEPTH_SHIFT, DEFAULT_DEPTH_SHIFT,
DEFAULT_DEPTH_USE_LOG, DEFAULT_DEPTH_USE_LOG,
DEPTH_QMAX, DEPTH_QMAX,
@@ -38,9 +39,6 @@ from .pyav_utils import write_u16_plane
_MM_PER_METRE = 1000.0 _MM_PER_METRE = 1000.0
_UINT16_MAX = 65535 _UINT16_MAX = 65535
# Pixel format supported by the depth encode/decode helpers.
DEPTH_PIX_FMT: str = "gray12le"
def _validate_log_quant_params(depth_min: float, shift: float) -> None: def _validate_log_quant_params(depth_min: float, shift: float) -> None:
"""Ensure ``log(depth_min + shift)`` is finite.""" """Ensure ``log(depth_min + shift)`` is finite."""
@@ -68,6 +66,7 @@ def quantize_depth(
depth_max: float = DEFAULT_DEPTH_MAX, depth_max: float = DEFAULT_DEPTH_MAX,
shift: float = DEFAULT_DEPTH_SHIFT, shift: float = DEFAULT_DEPTH_SHIFT,
use_log: bool = DEFAULT_DEPTH_USE_LOG, use_log: bool = DEFAULT_DEPTH_USE_LOG,
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
video_backend: str | None = "pyav", video_backend: str | None = "pyav",
input_unit: Literal["auto", "m", "mm"] = "auto", input_unit: Literal["auto", "m", "mm"] = "auto",
) -> NDArray[np.uint16] | av.VideoFrame: ) -> NDArray[np.uint16] | av.VideoFrame:
@@ -136,7 +135,7 @@ def quantize_depth(
quantized = np.rint(norm * DEPTH_QMAX).clip(0, DEPTH_QMAX).astype(np.uint16, copy=False) quantized = np.rint(norm * DEPTH_QMAX).clip(0, DEPTH_QMAX).astype(np.uint16, copy=False)
if video_backend == "pyav": if video_backend == "pyav":
frame = av.VideoFrame.from_ndarray(quantized, format=DEPTH_PIX_FMT) frame = av.VideoFrame.from_ndarray(quantized, format=pix_fmt)
write_u16_plane(frame.planes[0], quantized) write_u16_plane(frame.planes[0], quantized)
return frame return frame
else: else:
@@ -149,6 +148,7 @@ def dequantize_depth(
depth_max: float = DEFAULT_DEPTH_MAX, depth_max: float = DEFAULT_DEPTH_MAX,
shift: float = DEFAULT_DEPTH_SHIFT, shift: float = DEFAULT_DEPTH_SHIFT,
use_log: bool = DEFAULT_DEPTH_USE_LOG, use_log: bool = DEFAULT_DEPTH_USE_LOG,
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
output_unit: Literal["m", "mm"] = "mm", output_unit: Literal["m", "mm"] = "mm",
output_tensor: bool = False, output_tensor: bool = False,
) -> NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor: ) -> NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor:
@@ -179,7 +179,7 @@ def dequantize_depth(
raise ValueError(f"output_unit must be 'm' or 'mm', got {output_unit!r}") raise ValueError(f"output_unit must be 'm' or 'mm', got {output_unit!r}")
if isinstance(quantized, av.VideoFrame): if isinstance(quantized, av.VideoFrame):
quantized = quantized.to_ndarray(format=DEPTH_PIX_FMT) quantized = quantized.to_ndarray(format=pix_fmt)
norm = np.asarray(quantized, dtype=np.float32, order="K") / DEPTH_QMAX norm = np.asarray(quantized, dtype=np.float32, order="K") / DEPTH_QMAX
+2 -2
View File
@@ -44,7 +44,7 @@ from lerobot.configs import (
) )
from lerobot.utils.import_utils import get_safe_default_video_backend from lerobot.utils.import_utils import get_safe_default_video_backend
from .depth_utils import DEPTH_PIX_FMT, quantize_depth from .depth_utils import quantize_depth
from .pyav_utils import get_pix_fmt_channels from .pyav_utils import get_pix_fmt_channels
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -151,7 +151,7 @@ def decode_video_frames_pyav(
if log_loaded_timestamps: if log_loaded_timestamps:
logger.info(f"frame loaded at timestamp={current_ts:.4f}") logger.info(f"frame loaded at timestamp={current_ts:.4f}")
if is_depth: if is_depth:
arr = frame.to_ndarray(format=DEPTH_PIX_FMT) # (H, W) uint16 arr = frame.to_ndarray(format="gray12le") # (H, W) uint12
loaded_frames.append(torch.from_numpy(arr).unsqueeze(0).contiguous()) loaded_frames.append(torch.from_numpy(arr).unsqueeze(0).contiguous())
else: else:
arr = frame.to_ndarray(format="rgb24") # (H, W, 3) arr = frame.to_ndarray(format="rgb24") # (H, W, 3)
+1
View File
@@ -53,6 +53,7 @@ IMAGE_FEATURES = {
}, },
} }
def _make_dummy_stats(features: dict) -> dict: def _make_dummy_stats(features: dict) -> dict:
"""Create minimal episode stats matching the given features.""" """Create minimal episode stats matching the given features."""
stats = {} stats = {}
+1 -1
View File
@@ -26,7 +26,7 @@ pytest.importorskip("av", reason="av is required (install lerobot[dataset])")
import av # noqa: E402 import av # noqa: E402
from lerobot.configs import VALID_VIDEO_CODECS, VideoEncoderConfig, DepthEncoderConfig from lerobot.configs import VALID_VIDEO_CODECS, DepthEncoderConfig, VideoEncoderConfig
from lerobot.datasets.image_writer import write_image from lerobot.datasets.image_writer import write_image
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pyav_utils import get_codec from lerobot.datasets.pyav_utils import get_codec