diff --git a/src/lerobot/configs/video.py b/src/lerobot/configs/video.py index 0eb5e8ed9..0c2d3e09f 100644 --- a/src/lerobot/configs/video.py +++ b/src/lerobot/configs/video.py @@ -19,11 +19,8 @@ from __future__ import annotations import logging -from dataclasses import dataclass, field -from typing import Any - -import numpy as np -import torch +from dataclasses import dataclass, field, fields +from typing import Any, ClassVar from lerobot.utils.import_utils import require_package @@ -45,7 +42,6 @@ VALID_VIDEO_CODECS: frozenset[str] = frozenset( # Aliases for legacy video codec names. VIDEO_CODECS_ALIASES: dict[str, str] = {"av1": "libsvtav1"} - LIBSVTAV1_DEFAULT_PRESET: int = 12 # Keys persisted under ``features[*]["info"]`` as ``video.`` (from :class:`VideoEncoderConfig`). @@ -57,6 +53,7 @@ VIDEO_ENCODER_INFO_KEYS: frozenset[str] = frozenset( f"video.{name}" for name in VIDEO_ENCODER_INFO_FIELD_NAMES ) +# Default depth quantization and encoding parameters. DEPTH_QUANT_BITS: int = 12 DEPTH_QMAX: int = (1 << DEPTH_QUANT_BITS) - 1 # 4095 @@ -64,6 +61,10 @@ DEFAULT_DEPTH_MIN: float = 0.01 DEFAULT_DEPTH_MAX: float = 10.0 DEFAULT_DEPTH_SHIFT: float = 3.5 DEFAULT_DEPTH_USE_LOG: bool = True +DEFAULT_DEPTH_PIX_FMT: str = "gray12le" + +# Depth-specific tuning fields persisted under ``features[*]["info"]`` as ``video.``. +DEPTH_ENCODER_INFO_FIELD_NAMES: frozenset[str] = frozenset({"depth_min", "depth_max", "shift", "use_log"}) @dataclass @@ -99,6 +100,10 @@ class VideoEncoderConfig: video_backend: str = "pyav" extra_options: dict[str, Any] = field(default_factory=dict) + # Source-data channel count this encoder is expected to handle (3 for RGB, + # 1 for depth, etc.) + _DEFAULT_CHANNELS: ClassVar[int] = 3 + def __post_init__(self) -> None: self.resolve_vcodec() # Empty-constructor ergonomics: ``VideoEncoderConfig()`` must "just work". @@ -151,7 +156,9 @@ class VideoEncoderConfig: require_package("av", extra="dataset") from lerobot.datasets import check_video_encoder_parameters_pyav - check_video_encoder_parameters_pyav(self.vcodec, self.pix_fmt, self.get_codec_options()) + check_video_encoder_parameters_pyav( + self.vcodec, self.pix_fmt, self.get_codec_options(), channels=self._DEFAULT_CHANNELS + ) def resolve_vcodec(self) -> None: """Check ``vcodec`` and, when it is ``"auto"``, pick a concrete encoder. @@ -258,19 +265,11 @@ class DepthEncoderConfig(VideoEncoderConfig): Inherits the full :class:`VideoEncoderConfig` surface (codec, GOP, CRF, preset, ``extra_options``…) and adds the four parameters of the depth - quantization pipeline (:func:`quantize_depth`). Inheritance — rather - than composition — keeps the CLI flat: ``--dataset.depth_encoder_config.`` - works identically to its RGB counterpart. + quantizer. Defaults flip ``vcodec`` to ``"hevc"`` (Main 12 profile) and ``pix_fmt`` - to ``"yuv420p12le"``, the most widely available 12-bit pixel format. - For archive-grade lossless storage use ``vcodec="ffv1"`` together with - ``pix_fmt="gray12le"`` (and clear ``crf``/``preset`` to ``None`` since - ``ffv1`` doesn't expose those tuning knobs). + to ``"gray12le"``. - The :attr:`is_depth_map` marker is class-fixed to ``True`` (``init=False``, - so it's hidden from CLI and constructor args) and is what the reader - side keys on to tell depth datasets from RGB ones. Attributes: depth_min: Minimum depth in physical units (e.g. metres) represented @@ -282,28 +281,33 @@ class DepthEncoderConfig(VideoEncoderConfig): """ vcodec: str = "hevc" - pix_fmt: str = "yuv420p12le" + pix_fmt: str = "gray12le" depth_min: float = DEFAULT_DEPTH_MIN depth_max: float = DEFAULT_DEPTH_MAX shift: float = DEFAULT_DEPTH_SHIFT use_log: bool = DEFAULT_DEPTH_USE_LOG - # Class invariant — kept out of ``__init__`` (and CLI) but persisted - # via ``asdict`` into ``info.json`` for the reader to detect depth. - is_depth_map: bool = field(default=True, init=False) + _DEFAULT_CHANNELS: ClassVar[int] = 1 - def quantize(self, depth: torch.Tensor | np.ndarray) -> torch.Tensor: - """Apply :func:`quantize_depth` bound to this config's parameters.""" - from lerobot.datasets.depth_utils import quantize_depth + @classmethod + def from_video_info(cls, video_info: dict | None) -> DepthEncoderConfig: + """Reconstruct a :class:`DepthEncoderConfig` from a depth feature's ``info`` block. - return quantize_depth(depth, self.depth_min, self.depth_max, self.shift, self.use_log) + Reuses :meth:`VideoEncoderConfig.from_video_info` for the base + codec/tuning fields and then layers the depth-specific tuning + (``depth_min`` / ``depth_max`` / ``shift`` / ``use_log``) on top. + Missing keys fall back to the class defaults. + """ + base = VideoEncoderConfig.from_video_info(video_info) + kwargs: dict[str, Any] = {f.name: getattr(base, f.name) for f in fields(base) if f.init} - def dequantize(self, quantized: torch.Tensor | np.ndarray) -> torch.Tensor: - """Apply :func:`dequantize_depth` bound to this config's parameters.""" - from lerobot.datasets.depth_utils import dequantize_depth - - return dequantize_depth(quantized, self.depth_min, self.depth_max, self.shift, self.use_log) + video_info = video_info or {} + for name in DEPTH_ENCODER_INFO_FIELD_NAMES: + value = video_info.get(f"video.{name}") + if value is not None: + kwargs[name] = value + return cls(**kwargs) def depth_encoder_defaults() -> DepthEncoderConfig: diff --git a/src/lerobot/datasets/dataset_reader.py b/src/lerobot/datasets/dataset_reader.py index e6419c18e..ba94b2b99 100644 --- a/src/lerobot/datasets/dataset_reader.py +++ b/src/lerobot/datasets/dataset_reader.py @@ -22,6 +22,8 @@ from pathlib import Path import datasets import torch +from lerobot.configs.video import DepthEncoderConfig + from .dataset_metadata import LeRobotDatasetMetadata from .depth_utils import dequantize_depth from .feature_utils import ( @@ -87,17 +89,11 @@ class DatasetReader: check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s) self.delta_indices = get_delta_indices(delta_timestamps, meta.fps) - if self._meta.depth_keys: - # TODO(CarolinePascal): make this decent, this is awful. - self._dequantize_depth_configs = { - vid_key: { - "depth_min": self._meta.features[vid_key]["info"]["video.depth_min"], - "depth_max": self._meta.features[vid_key]["info"]["video.depth_max"], - "shift": self._meta.features[vid_key]["info"]["video.shift"], - "use_log": self._meta.features[vid_key]["info"]["video.use_log"], - } - for vid_key in self._meta.depth_keys - } + ##TODO(CarolinePascal): Should we rather use a more lightweight structure ? + self._depth_encoder_configs: dict[str, DepthEncoderConfig] = { + vid_key: DepthEncoderConfig.from_video_info(self._meta.features[vid_key].get("info")) + for vid_key in self._meta.depth_keys + } def try_load(self) -> bool: """Attempt to load from local cache. Returns True if data is sufficient.""" @@ -263,8 +259,14 @@ class DatasetReader: is_depth=vid_key in self._meta.depth_keys, ) if vid_key in self._meta.depth_keys: + depth_encoder = self._depth_encoder_configs[vid_key] frames = dequantize_depth( - frames, **self._dequantize_depth_configs[vid_key], output_tensor=True + frames, + depth_min=depth_encoder.depth_min, + depth_max=depth_encoder.depth_max, + shift=depth_encoder.shift, + use_log=depth_encoder.use_log, + output_tensor=True, ) return vid_key, frames.squeeze(0) diff --git a/src/lerobot/datasets/depth_utils.py b/src/lerobot/datasets/depth_utils.py index 1c641f7f4..138df9009 100644 --- a/src/lerobot/datasets/depth_utils.py +++ b/src/lerobot/datasets/depth_utils.py @@ -28,6 +28,7 @@ from numpy.typing import NDArray from lerobot.configs.video import ( DEFAULT_DEPTH_MAX, DEFAULT_DEPTH_MIN, + DEFAULT_DEPTH_PIX_FMT, DEFAULT_DEPTH_SHIFT, DEFAULT_DEPTH_USE_LOG, DEPTH_QMAX, @@ -38,9 +39,6 @@ from .pyav_utils import write_u16_plane _MM_PER_METRE = 1000.0 _UINT16_MAX = 65535 -# Pixel format supported by the depth encode/decode helpers. -DEPTH_PIX_FMT: str = "gray12le" - def _validate_log_quant_params(depth_min: float, shift: float) -> None: """Ensure ``log(depth_min + shift)`` is finite.""" @@ -68,6 +66,7 @@ def quantize_depth( depth_max: float = DEFAULT_DEPTH_MAX, shift: float = DEFAULT_DEPTH_SHIFT, use_log: bool = DEFAULT_DEPTH_USE_LOG, + pix_fmt: str = DEFAULT_DEPTH_PIX_FMT, video_backend: str | None = "pyav", input_unit: Literal["auto", "m", "mm"] = "auto", ) -> NDArray[np.uint16] | av.VideoFrame: @@ -136,7 +135,7 @@ def quantize_depth( quantized = np.rint(norm * DEPTH_QMAX).clip(0, DEPTH_QMAX).astype(np.uint16, copy=False) if video_backend == "pyav": - frame = av.VideoFrame.from_ndarray(quantized, format=DEPTH_PIX_FMT) + frame = av.VideoFrame.from_ndarray(quantized, format=pix_fmt) write_u16_plane(frame.planes[0], quantized) return frame else: @@ -149,6 +148,7 @@ def dequantize_depth( depth_max: float = DEFAULT_DEPTH_MAX, shift: float = DEFAULT_DEPTH_SHIFT, use_log: bool = DEFAULT_DEPTH_USE_LOG, + pix_fmt: str = DEFAULT_DEPTH_PIX_FMT, output_unit: Literal["m", "mm"] = "mm", output_tensor: bool = False, ) -> NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor: @@ -179,7 +179,7 @@ def dequantize_depth( raise ValueError(f"output_unit must be 'm' or 'mm', got {output_unit!r}") if isinstance(quantized, av.VideoFrame): - quantized = quantized.to_ndarray(format=DEPTH_PIX_FMT) + quantized = quantized.to_ndarray(format=pix_fmt) norm = np.asarray(quantized, dtype=np.float32, order="K") / DEPTH_QMAX diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index ab741ba0a..ab373412a 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -44,7 +44,7 @@ from lerobot.configs import ( ) from lerobot.utils.import_utils import get_safe_default_video_backend -from .depth_utils import DEPTH_PIX_FMT, quantize_depth +from .depth_utils import quantize_depth from .pyav_utils import get_pix_fmt_channels logger = logging.getLogger(__name__) @@ -151,7 +151,7 @@ def decode_video_frames_pyav( if log_loaded_timestamps: logger.info(f"frame loaded at timestamp={current_ts:.4f}") if is_depth: - arr = frame.to_ndarray(format=DEPTH_PIX_FMT) # (H, W) uint16 + arr = frame.to_ndarray(format="gray12le") # (H, W) uint12 loaded_frames.append(torch.from_numpy(arr).unsqueeze(0).contiguous()) else: arr = frame.to_ndarray(format="rgb24") # (H, W, 3) diff --git a/tests/datasets/test_dataset_metadata.py b/tests/datasets/test_dataset_metadata.py index efedf38d4..5e9fb884c 100644 --- a/tests/datasets/test_dataset_metadata.py +++ b/tests/datasets/test_dataset_metadata.py @@ -53,6 +53,7 @@ IMAGE_FEATURES = { }, } + def _make_dummy_stats(features: dict) -> dict: """Create minimal episode stats matching the given features.""" stats = {} diff --git a/tests/datasets/test_video_encoding.py b/tests/datasets/test_video_encoding.py index 00e3eed63..c9834d60b 100644 --- a/tests/datasets/test_video_encoding.py +++ b/tests/datasets/test_video_encoding.py @@ -26,7 +26,7 @@ pytest.importorskip("av", reason="av is required (install lerobot[dataset])") import av # noqa: E402 -from lerobot.configs import VALID_VIDEO_CODECS, VideoEncoderConfig, DepthEncoderConfig +from lerobot.configs import VALID_VIDEO_CODECS, DepthEncoderConfig, VideoEncoderConfig from lerobot.datasets.image_writer import write_image from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.pyav_utils import get_codec