mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
feat(refactor): refactor DepthEncoderConfig quantization pipeline, so that the methods do not live in the config class. Add pixel format - channels validation.Move the default pixel format for depth in the config file.
This commit is contained in:
@@ -19,11 +19,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
@@ -45,7 +42,6 @@ VALID_VIDEO_CODECS: frozenset[str] = frozenset(
|
||||
# Aliases for legacy video codec names.
|
||||
VIDEO_CODECS_ALIASES: dict[str, str] = {"av1": "libsvtav1"}
|
||||
|
||||
|
||||
LIBSVTAV1_DEFAULT_PRESET: int = 12
|
||||
|
||||
# Keys persisted under ``features[*]["info"]`` as ``video.<name>`` (from :class:`VideoEncoderConfig`).
|
||||
@@ -57,6 +53,7 @@ VIDEO_ENCODER_INFO_KEYS: frozenset[str] = frozenset(
|
||||
f"video.{name}" for name in VIDEO_ENCODER_INFO_FIELD_NAMES
|
||||
)
|
||||
|
||||
# Default depth quantization and encoding parameters.
|
||||
DEPTH_QUANT_BITS: int = 12
|
||||
DEPTH_QMAX: int = (1 << DEPTH_QUANT_BITS) - 1 # 4095
|
||||
|
||||
@@ -64,6 +61,10 @@ DEFAULT_DEPTH_MIN: float = 0.01
|
||||
DEFAULT_DEPTH_MAX: float = 10.0
|
||||
DEFAULT_DEPTH_SHIFT: float = 3.5
|
||||
DEFAULT_DEPTH_USE_LOG: bool = True
|
||||
DEFAULT_DEPTH_PIX_FMT: str = "gray12le"
|
||||
|
||||
# Depth-specific tuning fields persisted under ``features[*]["info"]`` as ``video.<name>``.
|
||||
DEPTH_ENCODER_INFO_FIELD_NAMES: frozenset[str] = frozenset({"depth_min", "depth_max", "shift", "use_log"})
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -99,6 +100,10 @@ class VideoEncoderConfig:
|
||||
video_backend: str = "pyav"
|
||||
extra_options: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Source-data channel count this encoder is expected to handle (3 for RGB,
|
||||
# 1 for depth, etc.)
|
||||
_DEFAULT_CHANNELS: ClassVar[int] = 3
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.resolve_vcodec()
|
||||
# Empty-constructor ergonomics: ``VideoEncoderConfig()`` must "just work".
|
||||
@@ -151,7 +156,9 @@ class VideoEncoderConfig:
|
||||
require_package("av", extra="dataset")
|
||||
from lerobot.datasets import check_video_encoder_parameters_pyav
|
||||
|
||||
check_video_encoder_parameters_pyav(self.vcodec, self.pix_fmt, self.get_codec_options())
|
||||
check_video_encoder_parameters_pyav(
|
||||
self.vcodec, self.pix_fmt, self.get_codec_options(), channels=self._DEFAULT_CHANNELS
|
||||
)
|
||||
|
||||
def resolve_vcodec(self) -> None:
|
||||
"""Check ``vcodec`` and, when it is ``"auto"``, pick a concrete encoder.
|
||||
@@ -258,19 +265,11 @@ class DepthEncoderConfig(VideoEncoderConfig):
|
||||
|
||||
Inherits the full :class:`VideoEncoderConfig` surface (codec, GOP, CRF,
|
||||
preset, ``extra_options``…) and adds the four parameters of the depth
|
||||
quantization pipeline (:func:`quantize_depth`). Inheritance — rather
|
||||
than composition — keeps the CLI flat: ``--dataset.depth_encoder_config.<field>``
|
||||
works identically to its RGB counterpart.
|
||||
quantizer.
|
||||
|
||||
Defaults flip ``vcodec`` to ``"hevc"`` (Main 12 profile) and ``pix_fmt``
|
||||
to ``"yuv420p12le"``, the most widely available 12-bit pixel format.
|
||||
For archive-grade lossless storage use ``vcodec="ffv1"`` together with
|
||||
``pix_fmt="gray12le"`` (and clear ``crf``/``preset`` to ``None`` since
|
||||
``ffv1`` doesn't expose those tuning knobs).
|
||||
to ``"gray12le"``.
|
||||
|
||||
The :attr:`is_depth_map` marker is class-fixed to ``True`` (``init=False``,
|
||||
so it's hidden from CLI and constructor args) and is what the reader
|
||||
side keys on to tell depth datasets from RGB ones.
|
||||
|
||||
Attributes:
|
||||
depth_min: Minimum depth in physical units (e.g. metres) represented
|
||||
@@ -282,28 +281,33 @@ class DepthEncoderConfig(VideoEncoderConfig):
|
||||
"""
|
||||
|
||||
vcodec: str = "hevc"
|
||||
pix_fmt: str = "yuv420p12le"
|
||||
pix_fmt: str = "gray12le"
|
||||
|
||||
depth_min: float = DEFAULT_DEPTH_MIN
|
||||
depth_max: float = DEFAULT_DEPTH_MAX
|
||||
shift: float = DEFAULT_DEPTH_SHIFT
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG
|
||||
|
||||
# Class invariant — kept out of ``__init__`` (and CLI) but persisted
|
||||
# via ``asdict`` into ``info.json`` for the reader to detect depth.
|
||||
is_depth_map: bool = field(default=True, init=False)
|
||||
_DEFAULT_CHANNELS: ClassVar[int] = 1
|
||||
|
||||
def quantize(self, depth: torch.Tensor | np.ndarray) -> torch.Tensor:
|
||||
"""Apply :func:`quantize_depth` bound to this config's parameters."""
|
||||
from lerobot.datasets.depth_utils import quantize_depth
|
||||
@classmethod
|
||||
def from_video_info(cls, video_info: dict | None) -> DepthEncoderConfig:
|
||||
"""Reconstruct a :class:`DepthEncoderConfig` from a depth feature's ``info`` block.
|
||||
|
||||
return quantize_depth(depth, self.depth_min, self.depth_max, self.shift, self.use_log)
|
||||
Reuses :meth:`VideoEncoderConfig.from_video_info` for the base
|
||||
codec/tuning fields and then layers the depth-specific tuning
|
||||
(``depth_min`` / ``depth_max`` / ``shift`` / ``use_log``) on top.
|
||||
Missing keys fall back to the class defaults.
|
||||
"""
|
||||
base = VideoEncoderConfig.from_video_info(video_info)
|
||||
kwargs: dict[str, Any] = {f.name: getattr(base, f.name) for f in fields(base) if f.init}
|
||||
|
||||
def dequantize(self, quantized: torch.Tensor | np.ndarray) -> torch.Tensor:
|
||||
"""Apply :func:`dequantize_depth` bound to this config's parameters."""
|
||||
from lerobot.datasets.depth_utils import dequantize_depth
|
||||
|
||||
return dequantize_depth(quantized, self.depth_min, self.depth_max, self.shift, self.use_log)
|
||||
video_info = video_info or {}
|
||||
for name in DEPTH_ENCODER_INFO_FIELD_NAMES:
|
||||
value = video_info.get(f"video.{name}")
|
||||
if value is not None:
|
||||
kwargs[name] = value
|
||||
return cls(**kwargs)
|
||||
|
||||
|
||||
def depth_encoder_defaults() -> DepthEncoderConfig:
|
||||
|
||||
@@ -22,6 +22,8 @@ from pathlib import Path
|
||||
import datasets
|
||||
import torch
|
||||
|
||||
from lerobot.configs.video import DepthEncoderConfig
|
||||
|
||||
from .dataset_metadata import LeRobotDatasetMetadata
|
||||
from .depth_utils import dequantize_depth
|
||||
from .feature_utils import (
|
||||
@@ -87,17 +89,11 @@ class DatasetReader:
|
||||
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
|
||||
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
|
||||
|
||||
if self._meta.depth_keys:
|
||||
# TODO(CarolinePascal): make this decent, this is awful.
|
||||
self._dequantize_depth_configs = {
|
||||
vid_key: {
|
||||
"depth_min": self._meta.features[vid_key]["info"]["video.depth_min"],
|
||||
"depth_max": self._meta.features[vid_key]["info"]["video.depth_max"],
|
||||
"shift": self._meta.features[vid_key]["info"]["video.shift"],
|
||||
"use_log": self._meta.features[vid_key]["info"]["video.use_log"],
|
||||
}
|
||||
for vid_key in self._meta.depth_keys
|
||||
}
|
||||
##TODO(CarolinePascal): Should we rather use a more lightweight structure ?
|
||||
self._depth_encoder_configs: dict[str, DepthEncoderConfig] = {
|
||||
vid_key: DepthEncoderConfig.from_video_info(self._meta.features[vid_key].get("info"))
|
||||
for vid_key in self._meta.depth_keys
|
||||
}
|
||||
|
||||
def try_load(self) -> bool:
|
||||
"""Attempt to load from local cache. Returns True if data is sufficient."""
|
||||
@@ -263,8 +259,14 @@ class DatasetReader:
|
||||
is_depth=vid_key in self._meta.depth_keys,
|
||||
)
|
||||
if vid_key in self._meta.depth_keys:
|
||||
depth_encoder = self._depth_encoder_configs[vid_key]
|
||||
frames = dequantize_depth(
|
||||
frames, **self._dequantize_depth_configs[vid_key], output_tensor=True
|
||||
frames,
|
||||
depth_min=depth_encoder.depth_min,
|
||||
depth_max=depth_encoder.depth_max,
|
||||
shift=depth_encoder.shift,
|
||||
use_log=depth_encoder.use_log,
|
||||
output_tensor=True,
|
||||
)
|
||||
return vid_key, frames.squeeze(0)
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ from numpy.typing import NDArray
|
||||
from lerobot.configs.video import (
|
||||
DEFAULT_DEPTH_MAX,
|
||||
DEFAULT_DEPTH_MIN,
|
||||
DEFAULT_DEPTH_PIX_FMT,
|
||||
DEFAULT_DEPTH_SHIFT,
|
||||
DEFAULT_DEPTH_USE_LOG,
|
||||
DEPTH_QMAX,
|
||||
@@ -38,9 +39,6 @@ from .pyav_utils import write_u16_plane
|
||||
_MM_PER_METRE = 1000.0
|
||||
_UINT16_MAX = 65535
|
||||
|
||||
# Pixel format supported by the depth encode/decode helpers.
|
||||
DEPTH_PIX_FMT: str = "gray12le"
|
||||
|
||||
|
||||
def _validate_log_quant_params(depth_min: float, shift: float) -> None:
|
||||
"""Ensure ``log(depth_min + shift)`` is finite."""
|
||||
@@ -68,6 +66,7 @@ def quantize_depth(
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
|
||||
video_backend: str | None = "pyav",
|
||||
input_unit: Literal["auto", "m", "mm"] = "auto",
|
||||
) -> NDArray[np.uint16] | av.VideoFrame:
|
||||
@@ -136,7 +135,7 @@ def quantize_depth(
|
||||
quantized = np.rint(norm * DEPTH_QMAX).clip(0, DEPTH_QMAX).astype(np.uint16, copy=False)
|
||||
|
||||
if video_backend == "pyav":
|
||||
frame = av.VideoFrame.from_ndarray(quantized, format=DEPTH_PIX_FMT)
|
||||
frame = av.VideoFrame.from_ndarray(quantized, format=pix_fmt)
|
||||
write_u16_plane(frame.planes[0], quantized)
|
||||
return frame
|
||||
else:
|
||||
@@ -149,6 +148,7 @@ def dequantize_depth(
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
|
||||
output_unit: Literal["m", "mm"] = "mm",
|
||||
output_tensor: bool = False,
|
||||
) -> NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor:
|
||||
@@ -179,7 +179,7 @@ def dequantize_depth(
|
||||
raise ValueError(f"output_unit must be 'm' or 'mm', got {output_unit!r}")
|
||||
|
||||
if isinstance(quantized, av.VideoFrame):
|
||||
quantized = quantized.to_ndarray(format=DEPTH_PIX_FMT)
|
||||
quantized = quantized.to_ndarray(format=pix_fmt)
|
||||
|
||||
norm = np.asarray(quantized, dtype=np.float32, order="K") / DEPTH_QMAX
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ from lerobot.configs import (
|
||||
)
|
||||
from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||
|
||||
from .depth_utils import DEPTH_PIX_FMT, quantize_depth
|
||||
from .depth_utils import quantize_depth
|
||||
from .pyav_utils import get_pix_fmt_channels
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -151,7 +151,7 @@ def decode_video_frames_pyav(
|
||||
if log_loaded_timestamps:
|
||||
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
|
||||
if is_depth:
|
||||
arr = frame.to_ndarray(format=DEPTH_PIX_FMT) # (H, W) uint16
|
||||
arr = frame.to_ndarray(format="gray12le") # (H, W) uint12
|
||||
loaded_frames.append(torch.from_numpy(arr).unsqueeze(0).contiguous())
|
||||
else:
|
||||
arr = frame.to_ndarray(format="rgb24") # (H, W, 3)
|
||||
|
||||
@@ -53,6 +53,7 @@ IMAGE_FEATURES = {
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _make_dummy_stats(features: dict) -> dict:
|
||||
"""Create minimal episode stats matching the given features."""
|
||||
stats = {}
|
||||
|
||||
@@ -26,7 +26,7 @@ pytest.importorskip("av", reason="av is required (install lerobot[dataset])")
|
||||
|
||||
import av # noqa: E402
|
||||
|
||||
from lerobot.configs import VALID_VIDEO_CODECS, VideoEncoderConfig, DepthEncoderConfig
|
||||
from lerobot.configs import VALID_VIDEO_CODECS, DepthEncoderConfig, VideoEncoderConfig
|
||||
from lerobot.datasets.image_writer import write_image
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pyav_utils import get_codec
|
||||
|
||||
Reference in New Issue
Block a user