mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-12 14:09:51 +00:00
Compare commits
36 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6610819c4a | |||
| 6ab5e667c0 | |||
| 55ff277af9 | |||
| 71546887c5 | |||
| 1cc9fe0139 | |||
| f1859d8c65 | |||
| 3d67784f8c | |||
| 621f4191be | |||
| de44bc9711 | |||
| bb47f5c463 | |||
| e06d83fbfd | |||
| 6179b5fd7e | |||
| 72ce8b099c | |||
| 2efa8e5540 | |||
| aff1252b90 | |||
| bf78fb11d3 | |||
| d8eea83a89 | |||
| b183571586 | |||
| 6bf4ffaabc | |||
| 45683aeccc | |||
| 5b91fe5d1a | |||
| ca9fad1f12 | |||
| efdec0ed3a | |||
| 1ba832e6e2 | |||
| 44280c4c56 | |||
| 8825bb77e4 | |||
| ab6911b781 | |||
| 929113c0d7 | |||
| b4e7d4c63a | |||
| 3f1b6f1ce1 | |||
| 226efd51f2 | |||
| c49c7536f2 | |||
| 93d1fac4d1 | |||
| 01081d2156 | |||
| 119169e417 | |||
| 7402980d7d |
@@ -82,7 +82,7 @@ After the first episode of a video stream is encoded, the encoder configuration
|
||||
"video.pix_fmt": "yuv420p",
|
||||
"video.fps": 30,
|
||||
"video.channels": 3,
|
||||
"video.is_depth_map": false,
|
||||
"is_depth_map": false,
|
||||
"video.g": 2,
|
||||
"video.crf": 30,
|
||||
"video.preset": "fast",
|
||||
@@ -97,7 +97,7 @@ After the first episode of a video stream is encoded, the encoder configuration
|
||||
|
||||
Two sources contribute to the `info` block:
|
||||
|
||||
- **Stream-derived** (read back from the encoded MP4 with PyAV): `video.height`, `video.width`, `video.codec`, `video.pix_fmt`, `video.fps`, `video.channels`, `video.is_depth_map`, plus `audio.*` if an audio stream is present.
|
||||
- **Stream-derived** (read back from the encoded MP4 with PyAV): `video.height`, `video.width`, `video.codec`, `video.pix_fmt`, `video.fps`, `video.channels`, `is_depth_map`, plus `audio.*` if an audio stream is present.
|
||||
- **Encoder-derived** (taken from `VideoEncoderConfig`): `video.g`, `video.crf`, `video.preset`, `video.fast_decode`, `video.video_backend`, `video.extra_options`.
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -105,8 +105,9 @@ def raw_observation_to_observation(
|
||||
|
||||
|
||||
def prepare_image(image: torch.Tensor) -> torch.Tensor:
|
||||
"""Minimal preprocessing to turn int8 images to float32 in [0, 1], and create a memory-contiguous tensor"""
|
||||
image = image.type(torch.float32) / 255
|
||||
"""Minimal preprocessing to turn RGB uint8 images to float32 in [0, 1], and create a memory-contiguous tensor"""
|
||||
if image.dtype == torch.uint8:
|
||||
image = image.type(torch.float32) / 255
|
||||
image = image.contiguous()
|
||||
|
||||
return image
|
||||
|
||||
@@ -436,7 +436,7 @@ class OpenCVCamera(Camera):
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
|
||||
On each iteration:
|
||||
1. Reads a color frame
|
||||
1. Reads a color frame (blocking call)
|
||||
2. Stores result in latest_frame and updates timestamp (thread-safe)
|
||||
3. Sets new_frame_event to notify listeners
|
||||
|
||||
@@ -445,8 +445,9 @@ class OpenCVCamera(Camera):
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
stop_event = self.stop_event
|
||||
failure_count = 0
|
||||
while not self.stop_event.is_set():
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
raw_frame = self._read_from_hardware()
|
||||
processed_frame = self._postprocess_image(raw_frame)
|
||||
@@ -484,6 +485,8 @@ class OpenCVCamera(Camera):
|
||||
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
if self.thread.is_alive():
|
||||
logger.warning(f"{self} read thread did not terminate within timeout.")
|
||||
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
@@ -268,13 +268,13 @@ class RealSenseCamera(Camera):
|
||||
)
|
||||
|
||||
if len(found_devices) > 1:
|
||||
serial_numbers = [dev["serial_number"] for dev in found_devices]
|
||||
serial_numbers = [dev["id"] for dev in found_devices]
|
||||
raise ValueError(
|
||||
f"Multiple RealSense cameras found with name '{name}'. "
|
||||
f"Please use a unique serial number instead. Found SNs: {serial_numbers}"
|
||||
)
|
||||
|
||||
serial_number = str(found_devices[0]["serial_number"])
|
||||
serial_number = str(found_devices[0]["id"])
|
||||
return serial_number
|
||||
|
||||
def _configure_rs_pipeline_config(self, rs_config: Any) -> None:
|
||||
@@ -332,8 +332,8 @@ class RealSenseCamera(Camera):
|
||||
from the camera hardware via the RealSense pipeline.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The depth map as a NumPy array (height, width)
|
||||
of type `np.uint16` (raw depth values in millimeters) and rotation.
|
||||
np.ndarray: The depth map as a NumPy array (height, width, 1)
|
||||
of type `np.uint16` (raw depth values in millimeters).
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
@@ -465,8 +465,8 @@ class RealSenseCamera(Camera):
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
|
||||
On each iteration:
|
||||
1. Reads a color frame with 500ms timeout
|
||||
2. Stores result in latest_frame and updates timestamp (thread-safe)
|
||||
1. Reads a color/depth frame (blocking call with 10s timeout)
|
||||
2. Stores result in latest_color_frame/latest_depth_frame and updates timestamp (thread-safe)
|
||||
3. Sets new_frame_event to notify listeners
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
@@ -474,8 +474,9 @@ class RealSenseCamera(Camera):
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
stop_event = self.stop_event
|
||||
failure_count = 0
|
||||
while not self.stop_event.is_set():
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
frame = self._read_from_hardware()
|
||||
color_frame_raw = frame.get_color_frame()
|
||||
@@ -486,6 +487,8 @@ class RealSenseCamera(Camera):
|
||||
depth_frame_raw = frame.get_depth_frame()
|
||||
depth_frame = np.asanyarray(depth_frame_raw.get_data())
|
||||
processed_depth_frame = self._postprocess_image(depth_frame, depth_frame=True)
|
||||
if processed_depth_frame.ndim == 2: # (H, W) -> (H, W, 1)
|
||||
processed_depth_frame = processed_depth_frame[..., np.newaxis]
|
||||
|
||||
capture_time = time.perf_counter()
|
||||
|
||||
@@ -522,6 +525,8 @@ class RealSenseCamera(Camera):
|
||||
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
if self.thread.is_alive(): # pragma: no cover
|
||||
logger.warning(f"{self} read thread did not terminate within timeout.")
|
||||
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
@@ -532,7 +537,6 @@ class RealSenseCamera(Camera):
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
@check_if_not_connected
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
@@ -575,7 +579,6 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
"""Return the most recent (color) frame captured immediately (Peeking).
|
||||
@@ -611,6 +614,71 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
@check_if_not_connected
|
||||
def async_read_depth(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""Read the latest depth frame asynchronously, in metric meters.
|
||||
|
||||
Mirrors :meth:`async_read` but returns the depth stream rather than the
|
||||
color stream. Output is ``np.uint16`` of shape ``(H, W, 1)``.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If ``use_depth`` is ``False`` for this camera, or if
|
||||
the background read thread is not running.
|
||||
TimeoutError: If no frame becomes available within ``timeout_ms``.
|
||||
"""
|
||||
if not self.use_depth:
|
||||
raise RuntimeError(f"{self}: cannot read depth — camera was configured with use_depth=False.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
|
||||
raise TimeoutError(f"Timed out waiting for depth frame from camera {self} after {timeout_ms} ms.")
|
||||
|
||||
with self.frame_lock:
|
||||
depth_frame = self.latest_depth_frame
|
||||
self.new_frame_event.clear()
|
||||
|
||||
if depth_frame is None:
|
||||
raise RuntimeError(f"Internal error: Event set but no depth frame available for {self}.")
|
||||
|
||||
return depth_frame
|
||||
|
||||
@check_if_not_connected
|
||||
def read_latest_depth(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
"""Return the most recent depth frame in metric meters (peeking).
|
||||
|
||||
Non-blocking counterpart of :meth:`read_latest` for the depth stream.
|
||||
Output is ``np.uint16`` of shape ``(H, W, 1)`` in millimeters.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If ``use_depth`` is ``False`` for this camera, or if
|
||||
no depth frame has been captured yet.
|
||||
TimeoutError: If the latest depth frame is older than ``max_age_ms``.
|
||||
"""
|
||||
if not self.use_depth:
|
||||
raise RuntimeError(f"{self}: cannot read depth — camera was configured with use_depth=False.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
with self.frame_lock:
|
||||
depth_frame = self.latest_depth_frame
|
||||
timestamp = self.latest_timestamp
|
||||
|
||||
if depth_frame is None or timestamp is None:
|
||||
raise RuntimeError(f"{self} has not captured any depth frames yet.")
|
||||
|
||||
age_ms = (time.perf_counter() - timestamp) * 1e3
|
||||
if age_ms > max_age_ms:
|
||||
raise TimeoutError(
|
||||
f"{self} latest depth frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
|
||||
)
|
||||
|
||||
return depth_frame
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnects from the camera, stops the pipeline, and cleans up resources.
|
||||
|
||||
@@ -249,8 +249,9 @@ class ZMQCamera(Camera):
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized.")
|
||||
|
||||
stop_event = self.stop_event
|
||||
failure_count = 0
|
||||
while not self.stop_event.is_set():
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
frame = self._read_from_hardware()
|
||||
capture_time = time.perf_counter()
|
||||
@@ -292,6 +293,8 @@ class ZMQCamera(Camera):
|
||||
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
if self.thread.is_alive():
|
||||
logger.warning(f"{self} read thread did not terminate within timeout.")
|
||||
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
@@ -35,8 +35,11 @@ from .types import (
|
||||
from .video import (
|
||||
VALID_VIDEO_CODECS,
|
||||
VIDEO_ENCODER_INFO_KEYS,
|
||||
DepthEncoderConfig,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
depth_encoder_defaults,
|
||||
encoder_config_from_video_info,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -57,8 +60,12 @@ __all__ = [
|
||||
"WandBConfig",
|
||||
"load_recipe",
|
||||
"VideoEncoderConfig",
|
||||
"DepthEncoderConfig",
|
||||
# Defaults
|
||||
"camera_encoder_defaults",
|
||||
"depth_encoder_defaults",
|
||||
# Factories
|
||||
"encoder_config_from_video_info",
|
||||
# Constants
|
||||
"VALID_VIDEO_CODECS",
|
||||
"VIDEO_ENCODER_INFO_KEYS",
|
||||
|
||||
@@ -18,7 +18,7 @@ from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from .video import VideoEncoderConfig, camera_encoder_defaults
|
||||
from .video import DepthEncoderConfig, VideoEncoderConfig, camera_encoder_defaults, depth_encoder_defaults
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -60,6 +60,8 @@ class DatasetRecordConfig:
|
||||
# Video encoder settings for camera MP4s (codec, quality, GOP, etc.). Tuned via CLI nested keys,
|
||||
# e.g. ``--dataset.camera_encoder.vcodec=h264`` (see ``VideoEncoderConfig``).
|
||||
camera_encoder: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
|
||||
# Video encoder settings for depth-map MP4s (codec, quality, GOP, etc.). Tuned via CLI nested keys.
|
||||
depth_encoder: DepthEncoderConfig = field(default_factory=depth_encoder_defaults)
|
||||
# Enable streaming video encoding: encode frames in real-time during capture instead
|
||||
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
|
||||
streaming_encoding: bool = False
|
||||
|
||||
@@ -20,7 +20,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from typing import Any, ClassVar, Self
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
@@ -36,11 +36,12 @@ HW_VIDEO_CODECS = [
|
||||
"h264_vaapi", # Linux Intel/AMD
|
||||
"h264_qsv", # Intel Quick Sync
|
||||
]
|
||||
VALID_VIDEO_CODECS: frozenset[str] = frozenset({"h264", "hevc", "libsvtav1", "auto", *HW_VIDEO_CODECS})
|
||||
VALID_VIDEO_CODECS: frozenset[str] = frozenset(
|
||||
{"h264", "hevc", "libsvtav1", "ffv1", "auto", *HW_VIDEO_CODECS}
|
||||
)
|
||||
# Aliases for legacy video codec names.
|
||||
VIDEO_CODECS_ALIASES: dict[str, str] = {"av1": "libsvtav1"}
|
||||
|
||||
|
||||
LIBSVTAV1_DEFAULT_PRESET: int = 12
|
||||
|
||||
# Keys persisted under ``features[*]["info"]`` as ``video.<name>`` (from :class:`VideoEncoderConfig`).
|
||||
@@ -52,6 +53,19 @@ VIDEO_ENCODER_INFO_KEYS: frozenset[str] = frozenset(
|
||||
f"video.{name}" for name in VIDEO_ENCODER_INFO_FIELD_NAMES
|
||||
)
|
||||
|
||||
# Default depth quantization and encoding parameters.
|
||||
DEPTH_QUANT_BITS: int = 12
|
||||
DEPTH_QMAX: int = (1 << DEPTH_QUANT_BITS) - 1 # 4095
|
||||
|
||||
DEFAULT_DEPTH_MIN: float = 0.01
|
||||
DEFAULT_DEPTH_MAX: float = 10.0
|
||||
DEFAULT_DEPTH_SHIFT: float = 3.5
|
||||
DEFAULT_DEPTH_USE_LOG: bool = True
|
||||
DEFAULT_DEPTH_PIX_FMT: str = "gray12le"
|
||||
|
||||
# Depth-specific tuning fields persisted under ``features[*]["info"]`` as ``video.<name>``.
|
||||
DEPTH_ENCODER_INFO_FIELD_NAMES: frozenset[str] = frozenset({"depth_min", "depth_max", "shift", "use_log"})
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoEncoderConfig:
|
||||
@@ -86,6 +100,10 @@ class VideoEncoderConfig:
|
||||
video_backend: str = "pyav"
|
||||
extra_options: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Source-data channel count this encoder is expected to handle (3 for RGB,
|
||||
# 1 for depth, etc.)
|
||||
_DEFAULT_CHANNELS: ClassVar[int] = 3
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.resolve_vcodec()
|
||||
# Empty-constructor ergonomics: ``VideoEncoderConfig()`` must "just work".
|
||||
@@ -94,9 +112,9 @@ class VideoEncoderConfig:
|
||||
self.validate()
|
||||
|
||||
@classmethod
|
||||
def from_video_info(cls, video_info: dict | None) -> VideoEncoderConfig:
|
||||
"""Reconstruct a :class:`VideoEncoderConfig` from a video feature's ``info`` block.
|
||||
Missing or ``None`` values fall back to the class defaults.
|
||||
def _kwargs_from_video_info(cls, video_info: dict | None) -> dict[str, Any]:
|
||||
"""Parse the ``video.*`` keys of a feature ``info`` block into
|
||||
constructor kwargs.
|
||||
"""
|
||||
video_info = video_info or {}
|
||||
kwargs: dict[str, Any] = {}
|
||||
@@ -115,7 +133,15 @@ class VideoEncoderConfig:
|
||||
continue
|
||||
kwargs[field_name] = value
|
||||
|
||||
return cls(**kwargs)
|
||||
return kwargs
|
||||
|
||||
@classmethod
|
||||
def from_video_info(cls, video_info: dict | None) -> Self:
|
||||
"""Reconstruct an encoder config from a video feature's ``info`` block.
|
||||
|
||||
Missing or ``None`` values fall back to the class defaults.
|
||||
"""
|
||||
return cls(**cls._kwargs_from_video_info(video_info))
|
||||
|
||||
def detect_available_encoders(self, encoders: list[str] | str) -> list[str]:
|
||||
"""Return the subset of available encoders based on the specified video backend.
|
||||
@@ -138,7 +164,9 @@ class VideoEncoderConfig:
|
||||
require_package("av", extra="dataset")
|
||||
from lerobot.datasets import check_video_encoder_parameters_pyav
|
||||
|
||||
check_video_encoder_parameters_pyav(self.vcodec, self.pix_fmt, self.get_codec_options())
|
||||
check_video_encoder_parameters_pyav(
|
||||
self.vcodec, self.pix_fmt, self.get_codec_options(), channels=self._DEFAULT_CHANNELS
|
||||
)
|
||||
|
||||
def resolve_vcodec(self) -> None:
|
||||
"""Check ``vcodec`` and, when it is ``"auto"``, pick a concrete encoder.
|
||||
@@ -218,6 +246,10 @@ class VideoEncoderConfig:
|
||||
elif self.vcodec == "h264_qsv":
|
||||
set_if("global_quality", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
elif self.vcodec == "ffv1":
|
||||
# Lossless intra-frame codec. ``crf``/``preset``/``fast_decode``
|
||||
# are not meaningful.
|
||||
set_if("threads", encoder_threads)
|
||||
else:
|
||||
set_if("crf", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
@@ -233,3 +265,75 @@ class VideoEncoderConfig:
|
||||
def camera_encoder_defaults() -> VideoEncoderConfig:
|
||||
"""Return a :class:`VideoEncoderConfig` with RGB-camera defaults."""
|
||||
return VideoEncoderConfig()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DepthEncoderConfig(VideoEncoderConfig):
|
||||
"""Encoder configuration for depth-map streams.
|
||||
|
||||
Inherits the full :class:`VideoEncoderConfig` surface (codec, GOP, CRF,
|
||||
preset, ``extra_options``…) and adds the four parameters of the depth
|
||||
quantizer.
|
||||
|
||||
Defaults flip ``vcodec`` to ``"hevc"`` (Main 12 profile) and ``pix_fmt``
|
||||
to ``"gray12le"``.
|
||||
|
||||
|
||||
Attributes:
|
||||
depth_min: Minimum depth in physical units (e.g. metres) represented
|
||||
by quantum ``0``.
|
||||
depth_max: Maximum depth represented by quantum :data:`DEPTH_QMAX`.
|
||||
shift: Pre-log offset for numerical stability near zero.
|
||||
use_log: ``True`` for logarithmic quantization (default; matches
|
||||
sensor error profile), ``False`` for linear.
|
||||
"""
|
||||
|
||||
vcodec: str = "hevc"
|
||||
pix_fmt: str = "gray12le"
|
||||
|
||||
depth_min: float = DEFAULT_DEPTH_MIN
|
||||
depth_max: float = DEFAULT_DEPTH_MAX
|
||||
shift: float = DEFAULT_DEPTH_SHIFT
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG
|
||||
|
||||
_DEFAULT_CHANNELS: ClassVar[int] = 1
|
||||
|
||||
@classmethod
|
||||
def _kwargs_from_video_info(cls, video_info: dict | None) -> dict[str, Any]:
|
||||
"""Layer the depth-specific tuning (``depth_min`` / ``depth_max`` /
|
||||
``shift`` / ``use_log``) on top of the base parser. Missing keys
|
||||
fall back to the class defaults.
|
||||
"""
|
||||
kwargs = super()._kwargs_from_video_info(video_info)
|
||||
video_info = video_info or {}
|
||||
for name in DEPTH_ENCODER_INFO_FIELD_NAMES:
|
||||
value = video_info.get(f"video.{name}")
|
||||
if value is not None:
|
||||
kwargs[name] = value
|
||||
return kwargs
|
||||
|
||||
|
||||
def depth_encoder_defaults() -> DepthEncoderConfig:
|
||||
"""Return a :class:`DepthEncoderConfig` with depth-camera defaults."""
|
||||
return DepthEncoderConfig()
|
||||
|
||||
|
||||
def encoder_config_from_video_info(video_info: dict | None) -> VideoEncoderConfig:
|
||||
"""Build the appropriate encoder config from a feature's ``info`` block.
|
||||
|
||||
Dispatches to :class:`DepthEncoderConfig` when the dict marks the feature
|
||||
as a depth map and to :class:`VideoEncoderConfig`
|
||||
otherwise.
|
||||
|
||||
Args:
|
||||
video_info: A feature's ``info`` dict as persisted in ``info.json``,
|
||||
or ``None`` (treated as an empty dict).
|
||||
|
||||
Returns:
|
||||
A :class:`DepthEncoderConfig` for depth features, otherwise a
|
||||
:class:`VideoEncoderConfig`.
|
||||
"""
|
||||
video_info = video_info or {}
|
||||
is_depth = bool(video_info.get("is_depth_map") or video_info.get("video.is_depth_map"))
|
||||
cls: type[VideoEncoderConfig] = DepthEncoderConfig if is_depth else VideoEncoderConfig
|
||||
return cls.from_video_info(video_info)
|
||||
|
||||
@@ -529,8 +529,9 @@ def compute_episode_stats(
|
||||
)
|
||||
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
normalization_factor = 255.0 if not (features[key].get("info") or {}).get("is_depth_map", False) else 1.0
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
|
||||
k: v if k == "count" else np.squeeze(v / normalization_factor, axis=0) for k, v in ep_stats[key].items()
|
||||
}
|
||||
|
||||
return ep_stats
|
||||
@@ -550,8 +551,10 @@ def _validate_stat_value(value: np.ndarray, key: str, feature_key: str) -> None:
|
||||
if key == "count" and value.shape != (1,):
|
||||
raise ValueError(f"Shape of 'count' must be (1), but is {value.shape} instead.")
|
||||
|
||||
if "image" in feature_key and key != "count" and value.shape != (3, 1, 1):
|
||||
raise ValueError(f"Shape of quantile '{key}' must be (3,1,1), but is {value.shape} instead.")
|
||||
if "image" in feature_key and key != "count" and value.shape not in ((3, 1, 1), (1, 1, 1)):
|
||||
raise ValueError(
|
||||
f"Shape of quantile '{key}' must be (3,1,1) or (1,1,1) but is {value.shape} instead."
|
||||
)
|
||||
|
||||
|
||||
def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Iterable
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@@ -337,6 +337,25 @@ class LeRobotDatasetMetadata:
|
||||
"""Keys to access visual modalities stored as videos."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
|
||||
|
||||
@property
|
||||
def depth_keys(self) -> list[str]:
|
||||
"""Keys to access depth-map modalities stored as videos or images.
|
||||
|
||||
A depth key is a feature whose ``info`` dict carries ``"is_depth_map": True``
|
||||
(or the legacy ``"video.is_depth_map"`` inside ``info`` or ``video_info``).
|
||||
"""
|
||||
|
||||
def _is_depth(ft: dict) -> bool:
|
||||
info = ft.get("info") or {}
|
||||
video_info = ft.get("video_info") or {}
|
||||
return (
|
||||
info.get("is_depth_map", False)
|
||||
or info.get("video.is_depth_map", False)
|
||||
or video_info.get("video.is_depth_map", False)
|
||||
)
|
||||
|
||||
return [key for key, ft in self.features.items() if _is_depth(ft)]
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
@@ -580,7 +599,8 @@ class LeRobotDatasetMetadata:
|
||||
def update_video_info(
|
||||
self,
|
||||
video_key: str | None = None,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
video_encoder: VideoEncoderConfig | None = None,
|
||||
preserve_keys: Iterable[str] | None = None,
|
||||
) -> None:
|
||||
"""Populate per-feature video info in ``info.json``.
|
||||
|
||||
@@ -590,19 +610,27 @@ class LeRobotDatasetMetadata:
|
||||
Args:
|
||||
video_key: If provided, only update this video key. Otherwise update
|
||||
all video keys in the dataset.
|
||||
camera_encoder: Encoder configuration used to produce the
|
||||
video_encoder: Encoder configuration used to produce the
|
||||
videos. When provided, its fields are recorded as
|
||||
``video.<field>`` entries alongside the stream-derived
|
||||
``video.*`` entries (see :func:`get_video_info`).
|
||||
preserve_keys: Optional iterable of ``info`` keys whose existing
|
||||
values must be kept as-is.
|
||||
"""
|
||||
if video_key is not None and video_key not in self.video_keys:
|
||||
raise ValueError(f"Video key {video_key} not found in dataset")
|
||||
|
||||
video_keys = [video_key] if video_key is not None else self.video_keys
|
||||
preserve_set = set(preserve_keys or ())
|
||||
for key in video_keys:
|
||||
if not self.features[key].get("info", None):
|
||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||
self.info.features[key]["info"] = get_video_info(video_path, camera_encoder=camera_encoder)
|
||||
existing = self.features[key].get("info") or {}
|
||||
# Skip only if real video info has already been written. The ``is_depth_map`` entry (created at feature creation) is not blocking.
|
||||
if set(existing.keys()) - {"is_depth_map"}:
|
||||
continue
|
||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||
new_info = get_video_info(video_path, video_encoder=video_encoder)
|
||||
new_info = {k: v for k, v in new_info.items() if k not in preserve_set}
|
||||
self.info.features[key]["info"] = {**existing, **new_info}
|
||||
|
||||
def update_chunk_settings(
|
||||
self,
|
||||
|
||||
@@ -22,7 +22,10 @@ from pathlib import Path
|
||||
import datasets
|
||||
import torch
|
||||
|
||||
from lerobot.configs.video import DepthEncoderConfig
|
||||
|
||||
from .dataset_metadata import LeRobotDatasetMetadata
|
||||
from .depth_utils import dequantize_depth
|
||||
from .feature_utils import (
|
||||
check_delta_timestamps,
|
||||
get_delta_indices,
|
||||
@@ -86,6 +89,12 @@ class DatasetReader:
|
||||
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
|
||||
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
|
||||
|
||||
##TODO(CarolinePascal): Should we rather use a more lightweight structure ?
|
||||
self._depth_encoder_configs: dict[str, DepthEncoderConfig] = {
|
||||
vid_key: DepthEncoderConfig.from_video_info(self._meta.features[vid_key].get("info"))
|
||||
for vid_key in self._meta.depth_keys
|
||||
}
|
||||
|
||||
def try_load(self) -> bool:
|
||||
"""Attempt to load from local cache. Returns True if data is sufficient."""
|
||||
try:
|
||||
@@ -247,7 +256,17 @@ class DatasetReader:
|
||||
self._tolerance_s,
|
||||
self._video_backend,
|
||||
return_uint8=self._return_uint8,
|
||||
is_depth=vid_key in self._meta.depth_keys,
|
||||
)
|
||||
if vid_key in self._meta.depth_keys:
|
||||
depth_encoder = self._depth_encoder_configs[vid_key]
|
||||
frames = dequantize_depth(
|
||||
frames,
|
||||
depth_min=depth_encoder.depth_min,
|
||||
depth_max=depth_encoder.depth_max,
|
||||
shift=depth_encoder.shift,
|
||||
use_log=depth_encoder.use_log,
|
||||
)
|
||||
return vid_key, frames.squeeze(0)
|
||||
|
||||
items = list(query_timestamps.items())
|
||||
|
||||
@@ -36,7 +36,8 @@ import pyarrow.parquet as pq
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults
|
||||
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults, DepthEncoderConfig, encoder_config_from_video_info, depth_encoder_defaults
|
||||
from lerobot.configs.video import DEPTH_ENCODER_INFO_FIELD_NAMES
|
||||
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.utils.utils import flatten_dict
|
||||
|
||||
@@ -726,7 +727,7 @@ def _copy_and_reindex_videos(
|
||||
|
||||
for video_key in src_dataset.meta.video_keys:
|
||||
logging.info(f"Processing videos for {video_key}")
|
||||
camera_encoder = VideoEncoderConfig.from_video_info(
|
||||
video_encoder = encoder_config_from_video_info(
|
||||
src_dataset.meta.info.features.get(video_key, {}).get("info")
|
||||
)
|
||||
|
||||
@@ -810,7 +811,7 @@ def _copy_and_reindex_videos(
|
||||
dst_video_path,
|
||||
episodes_to_keep_ranges,
|
||||
src_dataset.meta.fps,
|
||||
camera_encoder,
|
||||
video_encoder,
|
||||
)
|
||||
|
||||
cumulative_ts = 0.0
|
||||
@@ -1190,7 +1191,10 @@ def _save_batch_episodes_images(
|
||||
i, item = i_item_tuple
|
||||
img = item[img_key_param]
|
||||
# Use global frame index for naming
|
||||
img.save(str(imgs_dir / f"frame-{base_frame_idx + i:06d}.png"), quality=100)
|
||||
if img_key_param in dataset.meta.depth_keys:
|
||||
img.save(str(imgs_dir / f"frame-{base_frame_idx + i:06d}.tiff"), compression="raw")
|
||||
else:
|
||||
img.save(str(imgs_dir / f"frame-{base_frame_idx + i:06d}.png"), quality=100)
|
||||
return i
|
||||
|
||||
episode_durations = []
|
||||
@@ -1281,7 +1285,7 @@ def _estimate_frame_size_via_calibration(
|
||||
episode_indices: list[int],
|
||||
temp_dir: Path,
|
||||
fps: int,
|
||||
camera_encoder: VideoEncoderConfig,
|
||||
video_encoder: VideoEncoderConfig,
|
||||
num_calibration_frames: int = 30,
|
||||
) -> float:
|
||||
"""Estimate MB per frame by encoding a small calibration sample.
|
||||
@@ -1295,7 +1299,7 @@ def _estimate_frame_size_via_calibration(
|
||||
episode_indices: List of episode indices being processed.
|
||||
temp_dir: Temporary directory for calibration files.
|
||||
fps: Frames per second for video encoding.
|
||||
camera_encoder: Video encoder settings used for calibration encoding.
|
||||
video_encoder: Video encoder settings used for calibration encoding.
|
||||
num_calibration_frames: Number of frames to use for calibration (default: 30).
|
||||
|
||||
Returns:
|
||||
@@ -1331,7 +1335,7 @@ def _estimate_frame_size_via_calibration(
|
||||
imgs_dir=calibration_dir,
|
||||
video_path=calibration_video_path,
|
||||
fps=fps,
|
||||
camera_encoder=camera_encoder,
|
||||
video_encoder=video_encoder,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
@@ -1604,6 +1608,7 @@ def recompute_stats(
|
||||
raise ValueError(f"No parquet files found in {data_dir}")
|
||||
|
||||
all_episode_stats = []
|
||||
#TODO: enable image and video stats re-computation
|
||||
numeric_keys = [k for k, v in features_to_compute.items() if v["dtype"] not in ["image", "video"]]
|
||||
|
||||
for parquet_path in tqdm(parquet_files, desc="Computing stats from data files"):
|
||||
@@ -1650,6 +1655,7 @@ def convert_image_to_video_dataset(
|
||||
output_dir: Path | None = None,
|
||||
repo_id: str | None = None,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: VideoEncoderConfig | None = None,
|
||||
episode_indices: list[int] | None = None,
|
||||
num_workers: int = 4,
|
||||
max_episodes_per_batch: int | None = None,
|
||||
@@ -1676,6 +1682,8 @@ def convert_image_to_video_dataset(
|
||||
"""
|
||||
if camera_encoder is None:
|
||||
camera_encoder = camera_encoder_defaults()
|
||||
if depth_encoder is None:
|
||||
depth_encoder = depth_encoder_defaults()
|
||||
|
||||
# Check that it's an image dataset
|
||||
if len(dataset.meta.video_keys) > 0:
|
||||
@@ -1701,8 +1709,7 @@ def convert_image_to_video_dataset(
|
||||
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
|
||||
)
|
||||
logging.info(
|
||||
f"Video codec: {camera_encoder.vcodec}, pixel format: {camera_encoder.pix_fmt}, "
|
||||
f"GOP: {camera_encoder.g}, CRF: {camera_encoder.crf}"
|
||||
f"RGB video encoder: {camera_encoder}, depth video encoder: {depth_encoder}"
|
||||
)
|
||||
|
||||
# Create new features dict, converting image features to video features
|
||||
@@ -1765,6 +1772,8 @@ def convert_image_to_video_dataset(
|
||||
episode_lengths = {ep_idx: dataset.meta.episodes["length"][ep_idx] for ep_idx in episode_indices}
|
||||
|
||||
for img_key in tqdm(img_keys, desc="Processing cameras"):
|
||||
target_encoder = depth_encoder if img_key in dataset.meta.depth_keys else camera_encoder
|
||||
|
||||
# Estimate size per frame by encoding a small calibration sample
|
||||
# This provides accurate compression ratio for the specific codec parameters
|
||||
size_per_frame_mb = _estimate_frame_size_via_calibration(
|
||||
@@ -1773,7 +1782,7 @@ def convert_image_to_video_dataset(
|
||||
episode_indices=episode_indices,
|
||||
temp_dir=temp_dir,
|
||||
fps=fps,
|
||||
camera_encoder=camera_encoder,
|
||||
video_encoder=target_encoder,
|
||||
)
|
||||
|
||||
logging.info(f"Processing camera: {img_key}")
|
||||
@@ -1815,7 +1824,7 @@ def convert_image_to_video_dataset(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
camera_encoder=camera_encoder,
|
||||
video_encoder=target_encoder,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
@@ -1862,7 +1871,7 @@ def convert_image_to_video_dataset(
|
||||
video_key=img_key, chunk_index=0, file_index=0
|
||||
)
|
||||
new_meta.info.features[img_key]["info"] = get_video_info(
|
||||
video_path, camera_encoder=camera_encoder
|
||||
video_path, video_encoder=camera_encoder
|
||||
)
|
||||
|
||||
write_info(new_meta.info, new_meta.root)
|
||||
@@ -1890,11 +1899,11 @@ def convert_image_to_video_dataset(
|
||||
|
||||
def _reencode_video_worker(args: tuple) -> Path:
|
||||
"""Picklable worker for :func:`reencode_dataset`'s process pool."""
|
||||
video_path, camera_encoder, encoder_threads = args
|
||||
video_path, video_encoder, encoder_threads = args
|
||||
reencode_video(
|
||||
input_video_path=video_path,
|
||||
output_video_path=video_path,
|
||||
camera_encoder=camera_encoder,
|
||||
video_encoder=video_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
overwrite=True,
|
||||
)
|
||||
@@ -1903,7 +1912,8 @@ def _reencode_video_worker(args: tuple) -> Path:
|
||||
|
||||
def reencode_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
camera_encoder: VideoEncoderConfig,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: DepthEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
num_workers: int | None = None,
|
||||
) -> LeRobotDataset:
|
||||
@@ -1914,8 +1924,11 @@ def reencode_dataset(
|
||||
Args:
|
||||
dataset: An existing :class:`LeRobotDataset` whose videos will be
|
||||
re-encoded.
|
||||
camera_encoder: Target encoder configuration applied to every video
|
||||
file.
|
||||
camera_encoder: Target encoder configuration applied to every RGB video
|
||||
file. If ``None``, re-encoding is skipped for RGB videos.
|
||||
depth_encoder: Target encoder configuration applied to every depth video
|
||||
file. If ``None``, re-encoding is skipped for depth videos.
|
||||
Quantization parameters will not override the ones in the current dataset.
|
||||
encoder_threads: Per-encoder thread count forwarded to
|
||||
:func:`reencode_video`. ``None`` lets the codec decide.
|
||||
num_workers: Number of parallel processes. ``None`` or ``0`` means
|
||||
@@ -1927,23 +1940,31 @@ def reencode_dataset(
|
||||
on disk.
|
||||
"""
|
||||
meta = dataset.meta
|
||||
video_paths_list = []
|
||||
video_keys_encoders_dict = {}
|
||||
video_keys_paths_dict = {}
|
||||
|
||||
if camera_encoder is None and depth_encoder is None:
|
||||
raise ValueError("Either camera_encoder or depth_encoder must be provided")
|
||||
|
||||
# Only re-encode if the videos are not already encoded with the given video encoding parameters
|
||||
for video_key in meta.video_keys:
|
||||
current_info = meta.info.features[video_key].get("info", {})
|
||||
current_encoder = VideoEncoderConfig.from_video_info(current_info)
|
||||
if current_encoder != camera_encoder:
|
||||
video_paths_list.extend((meta.root / VIDEO_DIR / video_key).rglob("*.mp4"))
|
||||
current_encoder = encoder_config_from_video_info(current_info)
|
||||
target_encoder = depth_encoder if video_key in meta.depth_keys else camera_encoder
|
||||
if target_encoder is None:
|
||||
logging.info(f"No encoder provided for {video_key} video. Skipping re-encoding.")
|
||||
elif current_encoder != target_encoder:
|
||||
video_keys_paths_dict[video_key] = (meta.root / VIDEO_DIR / video_key).rglob("*.mp4")
|
||||
video_keys_encoders_dict[video_key] = target_encoder
|
||||
else:
|
||||
logging.info(f"{video_key} videos are already encoded with {camera_encoder}. Nothing to do.")
|
||||
logging.info(f"{video_key} videos are already encoded with {target_encoder}. Nothing to do.")
|
||||
|
||||
if len(video_paths_list) == 0:
|
||||
if len(video_keys_paths_dict) == 0:
|
||||
logging.warning("Dataset has no videos to re-encode.")
|
||||
return dataset
|
||||
logging.info(f"Re-encoding {len(video_paths_list)} video file(s) with {camera_encoder}")
|
||||
logging.info(f"Re-encoding {sum(len(paths) for paths in video_keys_paths_dict.values())} video file(s).")
|
||||
|
||||
worker_args = [(vp, camera_encoder, encoder_threads) for vp in video_paths_list]
|
||||
worker_args = [(path, encoder, encoder_threads) for video_key, encoder in video_keys_encoders_dict.items() for path in video_keys_paths_dict[video_key]]
|
||||
if num_workers and num_workers > 1:
|
||||
with ProcessPoolExecutor(max_workers=num_workers) as pool:
|
||||
futures = [pool.submit(_reencode_video_worker, args) for args in worker_args]
|
||||
@@ -1957,10 +1978,14 @@ def reencode_dataset(
|
||||
for args in tqdm(worker_args, desc="Re-encoding videos"):
|
||||
_reencode_video_worker(args)
|
||||
|
||||
# Refresh video info in metadata for every video key.
|
||||
for vid_key in meta.video_keys:
|
||||
video_path = meta.root / meta.get_video_file_path(0, vid_key)
|
||||
meta.info.features[vid_key]["info"] = get_video_info(video_path, camera_encoder=camera_encoder)
|
||||
# Refresh video info in metadata for every video key. For depth videos, preserve
|
||||
# ``is_depth_map`` and the depth quantization parameters.
|
||||
depth_preserve_keys = {"is_depth_map", *(f"video.{n}" for n in DEPTH_ENCODER_INFO_FIELD_NAMES)}
|
||||
for video_key, encoder in video_keys_encoders_dict.items():
|
||||
preserve_keys = depth_preserve_keys if video_key in meta.depth_keys else None
|
||||
meta.update_video_info(
|
||||
video_key=video_key, video_encoder=encoder, preserve_keys=preserve_keys
|
||||
)
|
||||
|
||||
write_info(meta.info, meta.root)
|
||||
logging.info("Dataset metadata updated.")
|
||||
|
||||
@@ -31,7 +31,12 @@ import PIL.Image
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults
|
||||
from lerobot.configs import (
|
||||
DepthEncoderConfig,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
depth_encoder_defaults,
|
||||
)
|
||||
|
||||
from .compute_stats import compute_episode_stats
|
||||
from .dataset_metadata import LeRobotDatasetMetadata
|
||||
@@ -48,6 +53,7 @@ from .io_utils import (
|
||||
write_info,
|
||||
)
|
||||
from .utils import (
|
||||
DEFAULT_DEPTH_PATH,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_IMAGE_PATH,
|
||||
update_chunk_file_indices,
|
||||
@@ -67,17 +73,22 @@ def _encode_video_worker(
|
||||
episode_index: int,
|
||||
root: Path,
|
||||
fps: int,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
video_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
) -> Path:
|
||||
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
|
||||
path_template = (
|
||||
DEFAULT_DEPTH_PATH
|
||||
if video_encoder is not None and isinstance(video_encoder, DepthEncoderConfig)
|
||||
else DEFAULT_IMAGE_PATH
|
||||
)
|
||||
fpath = path_template.format(image_key=video_key, episode_index=episode_index, frame_index=0)
|
||||
img_dir = (root / fpath).parent
|
||||
encode_video_frames(
|
||||
img_dir,
|
||||
temp_path,
|
||||
fps,
|
||||
camera_encoder=camera_encoder,
|
||||
video_encoder=video_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
overwrite=True,
|
||||
)
|
||||
@@ -97,6 +108,7 @@ class DatasetWriter:
|
||||
meta: LeRobotDatasetMetadata,
|
||||
root: Path,
|
||||
camera_encoder: VideoEncoderConfig | None,
|
||||
depth_encoder: DepthEncoderConfig | None,
|
||||
encoder_threads: int | None,
|
||||
batch_encoding_size: int,
|
||||
streaming_encoder: StreamingVideoEncoder | None = None,
|
||||
@@ -110,6 +122,8 @@ class DatasetWriter:
|
||||
root: Local dataset root directory.
|
||||
camera_encoder: Video encoder settings applied to all cameras.
|
||||
``None`` uses :func:`~lerobot.configs.camera_encoder_defaults`.
|
||||
depth_encoder: Video encoder settings applied to all **depth** cameras.
|
||||
``None`` uses :func:`~lerobot.configs.depth_encoder_defaults`.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
batch_encoding_size: Number of episodes to accumulate before
|
||||
@@ -121,6 +135,7 @@ class DatasetWriter:
|
||||
self._meta = meta
|
||||
self._root = root
|
||||
self._camera_encoder = camera_encoder or camera_encoder_defaults()
|
||||
self._depth_encoder = depth_encoder or depth_encoder_defaults()
|
||||
self._encoder_threads = encoder_threads
|
||||
self._batch_encoding_size = batch_encoding_size
|
||||
self._streaming_encoder = streaming_encoder
|
||||
@@ -145,7 +160,8 @@ class DatasetWriter:
|
||||
return ep_buffer
|
||||
|
||||
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||
fpath = DEFAULT_IMAGE_PATH.format(
|
||||
path_template = DEFAULT_DEPTH_PATH if image_key in self._meta.depth_keys else DEFAULT_IMAGE_PATH
|
||||
fpath = path_template.format(
|
||||
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
return self._root / fpath
|
||||
@@ -195,6 +211,7 @@ class DatasetWriter:
|
||||
if frame_index == 0 and self._streaming_encoder is not None:
|
||||
self._streaming_encoder.start_episode(
|
||||
video_keys=list(self._meta.video_keys),
|
||||
depth_video_keys=list(self._meta.depth_keys),
|
||||
temp_dir=self._root,
|
||||
)
|
||||
|
||||
@@ -282,10 +299,11 @@ class DatasetWriter:
|
||||
if use_streaming:
|
||||
streaming_results = self._streaming_encoder.finish_episode()
|
||||
for video_key in self._meta.video_keys:
|
||||
normalization_factor = 255.0 if video_key not in self._meta.depth_keys else 1.0
|
||||
temp_path, video_stats = streaming_results[video_key]
|
||||
if video_stats is not None:
|
||||
ep_stats[video_key] = {
|
||||
k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / 255.0, axis=0)
|
||||
k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / normalization_factor, axis=0)
|
||||
for k, v in video_stats.items()
|
||||
}
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path))
|
||||
@@ -300,7 +318,9 @@ class DatasetWriter:
|
||||
episode_index,
|
||||
self._root,
|
||||
self._meta.fps,
|
||||
self._camera_encoder,
|
||||
self._depth_encoder
|
||||
if video_key in self._meta.depth_keys
|
||||
else self._camera_encoder,
|
||||
self._encoder_threads,
|
||||
): video_key
|
||||
for video_key in self._meta.video_keys
|
||||
@@ -511,7 +531,12 @@ class DatasetWriter:
|
||||
|
||||
# Update video info (only needed when first episode is encoded)
|
||||
if episode_index == 0:
|
||||
self._meta.update_video_info(video_key, camera_encoder=self._camera_encoder)
|
||||
self._meta.update_video_info(
|
||||
video_key,
|
||||
video_encoder=self._depth_encoder
|
||||
if video_key in self._meta.depth_keys
|
||||
else self._camera_encoder,
|
||||
)
|
||||
write_info(self._meta.info, self._meta.root)
|
||||
|
||||
metadata = {
|
||||
@@ -578,13 +603,14 @@ class DatasetWriter:
|
||||
self.image_writer.wait_until_done()
|
||||
|
||||
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
|
||||
"""Use ffmpeg to convert frames stored as png into mp4 videos."""
|
||||
"""Use ffmpeg to convert frames stored as png/tiff into mp4 videos."""
|
||||
is_depth = video_key in self._meta.depth_keys
|
||||
return _encode_video_worker(
|
||||
video_key,
|
||||
episode_index,
|
||||
self._root,
|
||||
self._meta.fps,
|
||||
self._camera_encoder,
|
||||
self._depth_encoder if is_depth else self._camera_encoder,
|
||||
self._encoder_threads,
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,257 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Depth encoding/decoding helpers for :class:`VideoEncoderConfig`.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Literal
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import torch
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from lerobot.configs.video import (
|
||||
DEFAULT_DEPTH_MAX,
|
||||
DEFAULT_DEPTH_MIN,
|
||||
DEFAULT_DEPTH_PIX_FMT,
|
||||
DEFAULT_DEPTH_SHIFT,
|
||||
DEFAULT_DEPTH_USE_LOG,
|
||||
DEPTH_QMAX,
|
||||
)
|
||||
|
||||
from .pyav_utils import write_u16_plane
|
||||
|
||||
_MM_PER_METRE = 1000.0
|
||||
_UINT16_MAX = 65535
|
||||
|
||||
|
||||
def _validate_log_quant_params(depth_min: float, shift: float) -> None:
|
||||
"""Ensure ``log(depth_min + shift)`` is finite."""
|
||||
if depth_min + shift <= 0:
|
||||
raise ValueError(
|
||||
f"depth_min + shift must be positive for logarithmic quantization, "
|
||||
f"got depth_min={depth_min} + shift={shift} = {depth_min + shift}"
|
||||
)
|
||||
|
||||
|
||||
def _depth_input_to_float32_and_unit(
|
||||
depth: NDArray[np.integer] | NDArray[np.floating],
|
||||
input_unit: Literal["auto", "m", "mm"],
|
||||
) -> tuple[NDArray[np.float32], Literal["m", "mm"]]:
|
||||
"""Convert depth to float32 in the chosen unit, and return the resolved unit."""
|
||||
resolved_unit = (
|
||||
("m" if np.issubdtype(depth.dtype, np.floating) else "mm") if input_unit == "auto" else input_unit
|
||||
)
|
||||
return depth.astype(np.float32, order="K"), resolved_unit
|
||||
|
||||
|
||||
def quantize_depth(
|
||||
depth: NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor,
|
||||
depth_min: float = DEFAULT_DEPTH_MIN,
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
|
||||
video_backend: str | None = "pyav",
|
||||
input_unit: Literal["auto", "m", "mm"] = "auto",
|
||||
) -> NDArray[np.uint16] | av.VideoFrame:
|
||||
"""Quantize depth to 12-bit codes (``uint16``, values ``0…DEPTH_QMAX``).
|
||||
|
||||
Depth maps are packed into 12-bit integer frames so they fit in standard
|
||||
high-bit-depth pixel formats (e.g. ``yuv420p12le`` / ``gray12le``)
|
||||
and can be encoded by widely supported video codecs (HEVC Main 12, ffv1).
|
||||
Logarithmic quantization is the default because it allocates more quanta
|
||||
to near-range depth, which matches the (1/depth) error profile of typical
|
||||
depth sensors. Math is ported from BEHAVIOR-1K's ``obs_utils.py``.
|
||||
|
||||
**Input units**:
|
||||
|
||||
- ``input_unit="auto"`` (default): infer from dtype (floating = m, non-floating = mm).
|
||||
- ``input_unit="mm"``: interpret input values as millimetres.
|
||||
- ``input_unit="m"``: interpret input values as metres.
|
||||
|
||||
Quantization math runs in the **resolved input unit**.
|
||||
|
||||
``depth_min``, ``depth_max``, and ``shift`` are always in **metres**.
|
||||
|
||||
Args:
|
||||
depth: Depth map; ``torch.Tensor`` is moved to CPU for conversion.
|
||||
depth_min: Depth (metres) at quantum ``0``.
|
||||
depth_max: Depth (metres) at quantum :data:`DEPTH_QMAX`.
|
||||
shift: Depth shift (metres); used in log mode. Must satisfy ``depth_min + shift > 0``.
|
||||
use_log: If ``True`` (default), quantize in log space.
|
||||
video_backend: Video backend to use for encoding. Defaults to "pyav".
|
||||
input_unit: Input unit policy (``"auto"``, ``"mm"``, ``"m"``).
|
||||
|
||||
Returns:
|
||||
``numpy.ndarray``, ``dtype=uint16``, same shape as ``depth``, values in
|
||||
``[0, DEPTH_QMAX]``.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``input_unit`` is not ``"auto"``, ``"mm"``, or ``"m"``.
|
||||
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
|
||||
"""
|
||||
if input_unit not in ("auto", "m", "mm"):
|
||||
raise ValueError(f"input_unit must be 'auto', 'm', or 'mm', got {input_unit!r}")
|
||||
|
||||
if isinstance(depth, torch.Tensor):
|
||||
depth = depth.detach().cpu().numpy()
|
||||
|
||||
# Squeeze single-channel dim: (H, W, 1) or (1, H, W) → (H, W)
|
||||
if depth.ndim == 3 and (depth.shape[-1] == 1 or depth.shape[0] == 1):
|
||||
depth = depth.squeeze()
|
||||
|
||||
depth_f, resolved_unit = _depth_input_to_float32_and_unit(depth, input_unit=input_unit)
|
||||
|
||||
# Convert depth_min, depth_max, and shift to the resolved input unit.
|
||||
depth_min_u = np.float32(depth_min) if resolved_unit == "m" else np.float32(depth_min * _MM_PER_METRE)
|
||||
depth_max_u = np.float32(depth_max) if resolved_unit == "m" else np.float32(depth_max * _MM_PER_METRE)
|
||||
shift_u = np.float32(shift) if resolved_unit == "m" else np.float32(shift * _MM_PER_METRE)
|
||||
|
||||
# Normalization and quantization is performed in the resolved input unit.
|
||||
if use_log:
|
||||
_validate_log_quant_params(depth_min, shift)
|
||||
log_min = math.log(float(depth_min_u + shift_u))
|
||||
log_max = math.log(float(depth_max_u + shift_u))
|
||||
norm = (np.log(depth_f + shift_u) - log_min) / (log_max - log_min)
|
||||
else:
|
||||
norm = (depth_f - depth_min_u) / (depth_max_u - depth_min_u)
|
||||
|
||||
quantized = np.rint(norm * DEPTH_QMAX).clip(0, DEPTH_QMAX).astype(np.uint16, copy=False)
|
||||
|
||||
if video_backend == "pyav":
|
||||
frame = av.VideoFrame.from_ndarray(quantized, format=pix_fmt)
|
||||
write_u16_plane(frame.planes[0], quantized)
|
||||
return frame
|
||||
else:
|
||||
return quantized
|
||||
|
||||
|
||||
def dequantize_depth(
|
||||
quantized: NDArray[np.uint16] | av.VideoFrame | torch.Tensor,
|
||||
depth_min: float = DEFAULT_DEPTH_MIN,
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
|
||||
output_unit: Literal["m", "mm"] = "mm",
|
||||
output_tensor: bool = True,
|
||||
output_channel_last: bool = False,
|
||||
) -> NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor:
|
||||
"""Inverse of :func:`quantize_depth`.
|
||||
|
||||
Decoding inverts the same normalized code mapping as :func:`quantize_depth`
|
||||
using ``depth_min`` / ``depth_max`` / ``shift`` (in metres), then returns
|
||||
the requested output unit. Tuning arguments **must match** :func:`quantize_depth`.
|
||||
|
||||
Accepted input layouts :
|
||||
|
||||
- ``(H, W, 1)`` or ``(H, W)`` — single frame with channel-last.
|
||||
- ``(..., 1, H, W)`` — batched frames with channel-first.
|
||||
- ``(..., H, W, 1)`` — batched frames with channel-last.
|
||||
Output layout is determined by ``output_channel_last``.
|
||||
|
||||
Args:
|
||||
quantized: 12-bit codes in ``[0, DEPTH_QMAX]``. ``np.ndarray``,
|
||||
``av.VideoFrame``, or ``torch.Tensor`` (any integer or float dtype).
|
||||
depth_min, depth_max, shift, use_log: Same as :func:`quantize_depth` (metres).
|
||||
pix_fmt: Pixel format used to extract the plane from an ``av.VideoFrame``.
|
||||
output_unit: ``"mm"`` returns ``uint16`` millimetres (rint, clip
|
||||
``[0, 65535]``) when returning a numpy array, or ``float32`` mm when
|
||||
``output_tensor=True``. ``"m"`` returns ``float32`` metres in
|
||||
``[depth_min, depth_max]``.
|
||||
output_tensor: If True, return a ``torch.Tensor`` instead of a numpy array.
|
||||
|
||||
Returns:
|
||||
Depth map in the requested unit and dtype.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``output_unit`` is not ``"m"`` or ``"mm"``.
|
||||
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
|
||||
"""
|
||||
if output_unit not in ("m", "mm"):
|
||||
raise ValueError(f"output_unit must be 'm' or 'mm', got {output_unit!r}")
|
||||
if use_log:
|
||||
_validate_log_quant_params(depth_min, shift)
|
||||
|
||||
if isinstance(quantized, av.VideoFrame):
|
||||
quantized = quantized.to_ndarray(format=pix_fmt)
|
||||
|
||||
# Compute the scale and offset first.
|
||||
depth_min_m = float(depth_min)
|
||||
depth_max_m = float(depth_max)
|
||||
shift_m = float(shift)
|
||||
if use_log:
|
||||
log_min = math.log(depth_min_m + shift_m)
|
||||
log_max = math.log(depth_max_m + shift_m)
|
||||
scale = (log_max - log_min) / DEPTH_QMAX
|
||||
offset = log_min
|
||||
else:
|
||||
scale = (depth_max_m - depth_min_m) / DEPTH_QMAX
|
||||
offset = depth_min_m
|
||||
|
||||
# ── Torch path: stay on the input device, single fp32 allocation. ────────
|
||||
if isinstance(quantized, torch.Tensor):
|
||||
|
||||
if quantized.ndim >= 3:
|
||||
# Drop the single-channel dimension so the math runs on (..., H, W).
|
||||
quantized = quantized.squeeze(-3) if quantized.shape[-3] == 1 else quantized.squeeze(-1)
|
||||
|
||||
# Single allocation we own; everything else is in-place.
|
||||
buf = quantized.to(dtype=torch.float32, copy=True)
|
||||
buf.mul_(scale).add_(offset)
|
||||
if use_log:
|
||||
buf.exp_().sub_(shift_m)
|
||||
buf.clamp_(depth_min_m, depth_max_m)
|
||||
buf.unsqueeze_(-1) if output_channel_last else buf.unsqueeze_(-3)
|
||||
|
||||
if output_unit == "m":
|
||||
return buf if output_tensor else buf.cpu().numpy()
|
||||
|
||||
# mm path: round + clamp in float32, skipping the uint16 round-trip
|
||||
# when returning a tensor (torch.uint16 is poorly supported).
|
||||
buf.mul_(_MM_PER_METRE).round_().clamp_(0.0, _UINT16_MAX)
|
||||
if output_tensor:
|
||||
return buf
|
||||
return buf.cpu().numpy().astype(np.uint16, copy=False)
|
||||
|
||||
# ── NumPy path: single fp32 allocation, ``out=`` for in-place math. ─────
|
||||
arr = np.asarray(quantized)
|
||||
if arr.ndim >= 3:
|
||||
# Drop the single-channel dimension so the math runs on (..., H, W).
|
||||
arr = np.squeeze(arr, axis=-3) if arr.shape[-3] == 1 else np.squeeze(arr, axis=-1)
|
||||
|
||||
buf = np.empty(arr.shape, dtype=np.float32)
|
||||
np.multiply(arr, scale, out=buf)
|
||||
np.add(buf, offset, out=buf)
|
||||
if use_log:
|
||||
np.exp(buf, out=buf)
|
||||
np.subtract(buf, shift_m, out=buf)
|
||||
np.clip(buf, depth_min_m, depth_max_m, out=buf)
|
||||
buf = np.expand_dims(buf, axis=-1) if output_channel_last else np.expand_dims(buf, axis=-3)
|
||||
|
||||
if output_unit == "m":
|
||||
return torch.from_numpy(buf) if output_tensor else buf
|
||||
|
||||
np.multiply(buf, _MM_PER_METRE, out=buf)
|
||||
np.rint(buf, out=buf)
|
||||
np.clip(buf, 0.0, _UINT16_MAX, out=buf)
|
||||
if output_tensor:
|
||||
# torch.uint16 support is very limited; return float32 millimetres.
|
||||
return torch.from_numpy(buf)
|
||||
return buf.astype(np.uint16, copy=False)
|
||||
@@ -336,7 +336,7 @@ def validate_feature_image_or_video(
|
||||
|
||||
Args:
|
||||
name (str): The name of the feature.
|
||||
expected_shape (list[str]): The expected shape (C, H, W).
|
||||
expected_shape (list[str]): The expected shape, e.g. (C, H, W) or (H, W, C).
|
||||
value: The image data to validate.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -42,10 +42,41 @@ def safe_stop_image_writer(func):
|
||||
|
||||
|
||||
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
|
||||
# TODO(aliberts): handle 1 channel and 4 for depth images
|
||||
if image_array.ndim != 3:
|
||||
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
|
||||
"""Convert a NumPy array to a PIL Image, preserving precision for grayscale.
|
||||
|
||||
Behaviour by shape:
|
||||
|
||||
- ``(H, W)`` or ``(1, H, W)`` / ``(H, W, 1)``: single-channel grayscale.
|
||||
The native dtype is preserved using the matching PIL mode
|
||||
(``I;16`` / ``F``). This is the path used for raw depth maps (no rescaling, clamping, or downcasting)
|
||||
- ``(3, H, W)`` / ``(H, W, 3)``: RGB. Channels-first inputs are transposed
|
||||
to channels-last. Float inputs in ``[0, 1]`` are scaled to ``uint8``
|
||||
(existing behaviour, gated by ``range_check``).
|
||||
|
||||
Other shapes / channel counts raise ``NotImplementedError`` or
|
||||
``ValueError``.
|
||||
"""
|
||||
# TODO(CarolinePascal): 4 dimensions RGB-D images
|
||||
if image_array.ndim not in (2, 3):
|
||||
raise ValueError(f"The array has {image_array.ndim} dimensions, but 2 or 3 is expected for an image.")
|
||||
|
||||
# Squeeze 3D single-channel inputs to 2D so depth maps work whether the
|
||||
# caller emits (H, W), (1, H, W), or (H, W, 1).
|
||||
if image_array.ndim == 3:
|
||||
if image_array.shape[0] == 1:
|
||||
image_array = image_array[0]
|
||||
elif image_array.shape[-1] == 1:
|
||||
image_array = image_array[..., 0]
|
||||
|
||||
if image_array.ndim == 2:
|
||||
if image_array.dtype not in [np.uint16, np.float32]:
|
||||
raise ValueError(
|
||||
f"Unsupported single-channel image dtype: {image_array.dtype}. "
|
||||
f"Supported dtypes: {sorted(str(d) for d in [np.uint16, np.float32])}."
|
||||
)
|
||||
return PIL.Image.fromarray(np.ascontiguousarray(image_array))
|
||||
|
||||
# 3D path: must be RGB (3 channels), channels-first or channels-last.
|
||||
if image_array.shape[0] == 3:
|
||||
# Transpose from pytorch convention (C, H, W) to (H, W, C)
|
||||
image_array = image_array.transpose(1, 2, 0)
|
||||
@@ -71,13 +102,28 @@ def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True)
|
||||
return PIL.Image.fromarray(image_array)
|
||||
|
||||
|
||||
def save_kwargs_for_path(fpath: Path, compress_level: int) -> dict:
|
||||
"""Pick the right format-specific kwargs for :meth:`PIL.Image.Image.save`.
|
||||
|
||||
PNG uses ``compress_level`` (0-9, zlib). TIFF uses ``compression`` (raw) for lossless raw depth maps.
|
||||
"""
|
||||
suffix = Path(fpath).suffix.lower()
|
||||
if suffix == ".png":
|
||||
return {"compress_level": compress_level}
|
||||
if suffix in (".tif", ".tiff"):
|
||||
return {"compression": "raw"}
|
||||
return {}
|
||||
|
||||
|
||||
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1):
|
||||
"""
|
||||
Saves a NumPy array or PIL Image to a file.
|
||||
|
||||
This function handles both NumPy arrays and PIL Image objects, converting
|
||||
the former to a PIL Image before saving. It includes error handling for
|
||||
the save operation.
|
||||
the save operation. The output format is inferred from the *fpath*
|
||||
extension: ``.png`` → PNG with ``compress_level``, ``.tiff`` / ``.tif``
|
||||
→ lossless raw depth maps (TIFF).
|
||||
|
||||
Args:
|
||||
image (np.ndarray | PIL.Image.Image): The image data to save.
|
||||
@@ -101,7 +147,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level
|
||||
img = image
|
||||
else:
|
||||
raise TypeError(f"Unsupported image type: {type(image)}")
|
||||
img.save(fpath, compress_level=compress_level)
|
||||
img.save(fpath, **save_kwargs_for_path(fpath, compress_level))
|
||||
except Exception as e:
|
||||
logger.error("Error writing image %s: %s", fpath, e)
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ import torch.utils
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig
|
||||
from lerobot.configs import DepthEncoderConfig, VideoEncoderConfig
|
||||
from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE
|
||||
|
||||
from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
@@ -60,6 +60,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return_uint8: bool = False,
|
||||
batch_encoding_size: int = 1,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: DepthEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
@@ -186,6 +187,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
camera_encoder (VideoEncoderConfig | None, optional): Video encoder settings for cameras
|
||||
(codec, quality, etc.). When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults`
|
||||
is used by the writer.
|
||||
depth_encoder (DepthEncoderConfig | None, optional): Video encoder settings for depth cameras
|
||||
(codec, quality, etc.). When ``None``, :func:`~lerobot.configs.depth.depth_encoder_defaults`
|
||||
is used by the writer.
|
||||
encoder_threads (int | None, optional): Number of encoder threads (global). ``None`` lets the
|
||||
codec decide.
|
||||
streaming_encoding (bool, optional): If True, encode video frames in real-time during capture
|
||||
@@ -273,6 +277,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
streaming_enc = self._build_streaming_encoder(
|
||||
self.meta.fps,
|
||||
camera_encoder,
|
||||
depth_encoder,
|
||||
encoder_queue_maxsize,
|
||||
encoder_threads,
|
||||
)
|
||||
@@ -280,6 +285,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
meta=self.meta,
|
||||
root=self.root,
|
||||
camera_encoder=camera_encoder,
|
||||
depth_encoder=depth_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
@@ -322,12 +328,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def _build_streaming_encoder(
|
||||
fps: int,
|
||||
camera_encoder: VideoEncoderConfig | None,
|
||||
depth_encoder: DepthEncoderConfig | None,
|
||||
encoder_queue_maxsize: int,
|
||||
encoder_threads: int | None,
|
||||
) -> StreamingVideoEncoder:
|
||||
return StreamingVideoEncoder(
|
||||
fps=fps,
|
||||
camera_encoder=camera_encoder,
|
||||
depth_encoder=depth_encoder,
|
||||
queue_maxsize=encoder_queue_maxsize,
|
||||
encoder_threads=encoder_threads,
|
||||
)
|
||||
@@ -645,6 +653,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: DepthEncoderConfig | None = None,
|
||||
metadata_buffer_size: int = 10,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
@@ -677,6 +686,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
batch-encoding videos. ``1`` means encode immediately.
|
||||
camera_encoder: Video encoder settings for cameras (codec, quality, etc.).
|
||||
When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
|
||||
depth_encoder: Video encoder settings for depth cameras (codec, quality, etc.).
|
||||
When ``None``, :func:`~lerobot.configs.depth.depth_encoder_defaults` is used.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
metadata_buffer_size: Number of episode metadata records to buffer
|
||||
@@ -720,12 +731,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(obj.meta.video_keys) > 0:
|
||||
streaming_enc = cls._build_streaming_encoder(
|
||||
fps, camera_encoder, encoder_queue_maxsize, encoder_threads
|
||||
fps, camera_encoder, depth_encoder, encoder_queue_maxsize, encoder_threads
|
||||
)
|
||||
obj.writer = DatasetWriter(
|
||||
meta=obj.meta,
|
||||
root=obj.root,
|
||||
camera_encoder=camera_encoder,
|
||||
depth_encoder=depth_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
@@ -749,6 +761,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: DepthEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
image_writer_processes: int = 0,
|
||||
image_writer_threads: int = 0,
|
||||
@@ -778,6 +791,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
batch-encoding videos.
|
||||
camera_encoder: Video encoder settings for cameras (codec, quality, etc.).
|
||||
When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
|
||||
depth_encoder: Video encoder settings for depth cameras (codec, quality, etc.).
|
||||
When ``None``, :func:`~lerobot.configs.depth.depth_encoder_defaults` is used.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
image_writer_processes: Subprocesses for async image writing.
|
||||
@@ -824,12 +839,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(obj.meta.video_keys) > 0:
|
||||
streaming_enc = cls._build_streaming_encoder(
|
||||
obj.meta.fps, camera_encoder, encoder_queue_maxsize, encoder_threads
|
||||
obj.meta.fps, camera_encoder, depth_encoder, encoder_queue_maxsize, encoder_threads
|
||||
)
|
||||
obj.writer = DatasetWriter(
|
||||
meta=obj.meta,
|
||||
root=obj.root,
|
||||
camera_encoder=camera_encoder,
|
||||
depth_encoder=depth_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
|
||||
@@ -24,6 +24,7 @@ import logging
|
||||
from typing import Any
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,6 +32,22 @@ FFMPEG_NUMERIC_OPTION_TYPES = ("INT", "INT64", "UINT64", "FLOAT", "DOUBLE")
|
||||
FFMPEG_INTEGER_OPTION_TYPES = ("INT", "INT64", "UINT64")
|
||||
|
||||
|
||||
def write_u16_plane(plane: av.video.plane.VideoPlane, src: np.ndarray, fill_value: int | None = None) -> None:
|
||||
"""Copy ``src`` into a uint16 plane respecting FFmpeg line padding."""
|
||||
height, width = src.shape
|
||||
stride_u16 = plane.line_size // np.dtype(np.uint16).itemsize
|
||||
dst = np.frombuffer(plane, dtype=np.uint16).reshape(height, stride_u16)
|
||||
if fill_value is not None:
|
||||
dst.fill(fill_value)
|
||||
dst[:, :width] = src
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_pix_fmt_channels(pix_fmt: str) -> int:
|
||||
"""Return the number of components (channels) for *pix_fmt*."""
|
||||
return len(av.VideoFormat(pix_fmt).components)
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_codec(vcodec: str) -> av.codec.Codec | None:
|
||||
"""PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""
|
||||
@@ -142,6 +159,16 @@ def _check_pixel_format(vcodec: str, pix_fmt: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _check_pix_fmt_channels(pix_fmt: str, channels: int) -> None:
|
||||
"""Ensure *pix_fmt* can carry at least *channels* components."""
|
||||
pix_fmt_channels = get_pix_fmt_channels(pix_fmt)
|
||||
if pix_fmt_channels < channels:
|
||||
raise ValueError(
|
||||
f"pix_fmt={pix_fmt!r} carries only {pix_fmt_channels} component(s) "
|
||||
f"but the source data has {channels} channel(s)."
|
||||
)
|
||||
|
||||
|
||||
def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None:
|
||||
"""Validate merged encoder options (typed) against the codec's published AVOptions."""
|
||||
supported_options = _get_codec_options_by_name(vcodec)
|
||||
@@ -156,12 +183,18 @@ def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None:
|
||||
_check_option_value(vcodec, key, value, supported_options[key])
|
||||
|
||||
|
||||
def check_video_encoder_parameters_pyav(vcodec: str, pix_fmt: str, codec_options: dict[str, Any]) -> None:
|
||||
def check_video_encoder_parameters_pyav(
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
codec_options: dict[str, Any],
|
||||
channels: int | None = None,
|
||||
) -> None:
|
||||
"""Verify *config* is compatible with the bundled FFmpeg build.
|
||||
|
||||
Checks pixel format, abstract tuning-field compatibility, and each merged
|
||||
encoder option from :meth:`~lerobot.configs.video.VideoEncoderConfig.get_codec_options`
|
||||
against PyAV (including numeric ``extra_options`` present in that dict).
|
||||
When given, additionally verify that *pix_fmt* carries as many components as the source data channels.
|
||||
No-op when ``config.vcodec`` isn't in the local FFmpeg build.
|
||||
|
||||
Raises:
|
||||
@@ -171,4 +204,6 @@ def check_video_encoder_parameters_pyav(vcodec: str, pix_fmt: str, codec_options
|
||||
if not options:
|
||||
raise ValueError(f"Codec {vcodec!r} is not available in the bundled FFmpeg build")
|
||||
_check_pixel_format(vcodec, pix_fmt)
|
||||
if channels is not None:
|
||||
_check_pix_fmt_channels(pix_fmt, channels)
|
||||
_check_codec_options(vcodec, codec_options)
|
||||
|
||||
@@ -92,6 +92,7 @@ DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
|
||||
DEFAULT_DEPTH_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.tiff"
|
||||
|
||||
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
|
||||
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
||||
|
||||
@@ -39,11 +39,16 @@ from datasets.features.features import register_feature
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.configs import (
|
||||
DepthEncoderConfig,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
depth_encoder_defaults,
|
||||
)
|
||||
from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||
|
||||
from .depth_utils import quantize_depth
|
||||
from .pyav_utils import get_pix_fmt_channels
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -53,6 +58,7 @@ def decode_video_frames(
|
||||
tolerance_s: float,
|
||||
backend: str | None = None,
|
||||
return_uint8: bool = False,
|
||||
is_depth: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Decodes video frames using the specified backend.
|
||||
@@ -72,6 +78,11 @@ def decode_video_frames(
|
||||
|
||||
Currently supports torchcodec on cpu and pyav.
|
||||
"""
|
||||
if backend != "pyav" and is_depth:
|
||||
logger.warning("Decoding depth maps is only supported with the 'pyav' backend.")
|
||||
# We do not actually return uint8 here, but we avoid the 255 normalization step.
|
||||
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=True, is_depth=True)
|
||||
|
||||
if backend is None:
|
||||
backend = get_safe_default_video_backend()
|
||||
if backend == "torchcodec":
|
||||
@@ -91,6 +102,7 @@ def decode_video_frames_pyav(
|
||||
tolerance_s: float,
|
||||
log_loaded_timestamps: bool = False,
|
||||
return_uint8: bool = False,
|
||||
is_depth: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Loads frames associated to the requested timestamps of a video using PyAV.
|
||||
|
||||
@@ -140,9 +152,13 @@ def decode_video_frames_pyav(
|
||||
current_ts = float(frame.pts * stream.time_base)
|
||||
if log_loaded_timestamps:
|
||||
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
|
||||
# Convert to CHW uint8 to match torchcodec's output layout.
|
||||
arr = frame.to_ndarray(format="rgb24") # H, W, 3
|
||||
loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous())
|
||||
if is_depth:
|
||||
arr = frame.to_ndarray(format="gray12le") # (H, W) uint12
|
||||
loaded_frames.append(torch.from_numpy(arr).unsqueeze(0).contiguous())
|
||||
else:
|
||||
arr = frame.to_ndarray(format="rgb24") # (H, W, 3)
|
||||
# Convert to CHW uint8 to match torchcodec's output layout.
|
||||
loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous())
|
||||
loaded_ts.append(current_ts)
|
||||
if current_ts >= last_ts:
|
||||
break
|
||||
@@ -406,17 +422,17 @@ def encode_video_frames(
|
||||
imgs_dir: Path | str,
|
||||
video_path: Path | str,
|
||||
fps: int,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
video_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
*,
|
||||
log_level: int | None = av.logging.WARNING,
|
||||
overwrite: bool = False,
|
||||
) -> None:
|
||||
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
||||
if camera_encoder is None:
|
||||
camera_encoder = camera_encoder_defaults()
|
||||
vcodec = camera_encoder.vcodec
|
||||
pix_fmt = camera_encoder.pix_fmt
|
||||
if video_encoder is None:
|
||||
video_encoder = camera_encoder_defaults()
|
||||
vcodec = video_encoder.vcodec
|
||||
pix_fmt = video_encoder.pix_fmt
|
||||
|
||||
video_path = Path(video_path)
|
||||
imgs_dir = Path(imgs_dir)
|
||||
@@ -428,7 +444,9 @@ def encode_video_frames(
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get input frames
|
||||
template = "frame-" + ("[0-9]" * 6) + ".png"
|
||||
is_depth = isinstance(video_encoder, DepthEncoderConfig)
|
||||
suffix = ".png" if not is_depth else ".tiff"
|
||||
template = "frame-" + ("[0-9]" * 6) + suffix
|
||||
input_list = sorted(
|
||||
glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("-")[-1].split(".")[0])
|
||||
)
|
||||
@@ -438,7 +456,7 @@ def encode_video_frames(
|
||||
with Image.open(input_list[0]) as dummy_image:
|
||||
width, height = dummy_image.size
|
||||
|
||||
video_options = camera_encoder.get_codec_options(encoder_threads, as_strings=True)
|
||||
video_options = video_encoder.get_codec_options(encoder_threads, as_strings=True)
|
||||
|
||||
# Set logging level
|
||||
if log_level is not None:
|
||||
@@ -455,8 +473,19 @@ def encode_video_frames(
|
||||
# Loop through input frames and encode them
|
||||
for input_data in input_list:
|
||||
with Image.open(input_data) as input_image:
|
||||
input_image = input_image.convert("RGB")
|
||||
input_frame = av.VideoFrame.from_image(input_image)
|
||||
if is_depth:
|
||||
input_frame = quantize_depth(
|
||||
np.array(input_image),
|
||||
depth_min=video_encoder.depth_min,
|
||||
depth_max=video_encoder.depth_max,
|
||||
shift=video_encoder.shift,
|
||||
use_log=video_encoder.use_log,
|
||||
pix_fmt=video_encoder.pix_fmt,
|
||||
video_backend="pyav"
|
||||
)
|
||||
else:
|
||||
input_image = input_image.convert("RGB")
|
||||
input_frame = av.VideoFrame.from_image(input_image)
|
||||
packet = output_stream.encode(input_frame)
|
||||
if packet:
|
||||
output.mux(packet)
|
||||
@@ -477,7 +506,7 @@ def encode_video_frames(
|
||||
def reencode_video(
|
||||
input_video_path: Path | str,
|
||||
output_video_path: Path | str,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
video_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
log_level: int | None = av.logging.WARNING,
|
||||
overwrite: bool = False,
|
||||
@@ -487,13 +516,13 @@ def reencode_video(
|
||||
Args:
|
||||
input_video_path: Existing video file to read.
|
||||
output_video_path: Path for the re-encoded file.
|
||||
camera_encoder: Encoder configuration. Defaults to :func:`camera_encoder_defaults`.
|
||||
video_encoder: Encoder configuration. Defaults to :func:`camera_encoder_defaults`.
|
||||
encoder_threads: Optional thread count forwarded to :meth:`VideoEncoderConfig.get_codec_options`.
|
||||
log_level: libav log level while encoding, or ``None`` to leave logging unchanged. Defaults to WARNING.
|
||||
overwrite: When ``False`` and ``output_video_path`` already exists, skip and log a warning.
|
||||
"""
|
||||
|
||||
camera_encoder = camera_encoder or camera_encoder_defaults()
|
||||
video_encoder = video_encoder or camera_encoder_defaults()
|
||||
|
||||
output_video_path = Path(output_video_path)
|
||||
|
||||
@@ -503,9 +532,9 @@ def reencode_video(
|
||||
|
||||
output_video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
video_options = camera_encoder.get_codec_options(encoder_threads, as_strings=True)
|
||||
vcodec = camera_encoder.vcodec
|
||||
pix_fmt = camera_encoder.pix_fmt
|
||||
video_options = video_encoder.get_codec_options(encoder_threads, as_strings=True)
|
||||
vcodec = video_encoder.vcodec
|
||||
pix_fmt = video_encoder.pix_fmt
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
|
||||
tmp_output_video_path = tmp_named_file.name
|
||||
@@ -676,22 +705,21 @@ class _CameraEncoderThread(threading.Thread):
|
||||
self,
|
||||
video_path: Path,
|
||||
fps: int,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
codec_options: dict[str, str],
|
||||
video_encoder: VideoEncoderConfig,
|
||||
frame_queue: queue.Queue,
|
||||
result_queue: queue.Queue,
|
||||
stop_event: threading.Event,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
super().__init__(daemon=True)
|
||||
self.video_path = video_path
|
||||
self.fps = fps
|
||||
self.vcodec = vcodec
|
||||
self.pix_fmt = pix_fmt
|
||||
self.codec_options = codec_options
|
||||
self.video_encoder = video_encoder
|
||||
self.is_depth = isinstance(video_encoder, DepthEncoderConfig)
|
||||
self.frame_queue = frame_queue
|
||||
self.result_queue = result_queue
|
||||
self.stop_event = stop_event
|
||||
self.encoder_threads = encoder_threads
|
||||
|
||||
def run(self) -> None:
|
||||
from .compute_stats import RunningQuantileStats, auto_downsample_height_width
|
||||
@@ -716,12 +744,12 @@ class _CameraEncoderThread(threading.Thread):
|
||||
# Sentinel: flush and close
|
||||
break
|
||||
|
||||
# Ensure HWC uint8 numpy array
|
||||
# Ensure HWC (RGB or depth) uint8 (RGB only) numpy array
|
||||
if isinstance(frame_data, np.ndarray):
|
||||
if frame_data.ndim == 3 and frame_data.shape[0] == 3:
|
||||
if frame_data.ndim == 3 and frame_data.shape[0] in (1, 3):
|
||||
# CHW -> HWC
|
||||
frame_data = frame_data.transpose(1, 2, 0)
|
||||
if frame_data.dtype != np.uint8:
|
||||
if not self.is_depth and frame_data.dtype != np.uint8:
|
||||
frame_data = (frame_data * 255).astype(np.uint8)
|
||||
|
||||
# Open container on first frame (to get width/height)
|
||||
@@ -729,15 +757,29 @@ class _CameraEncoderThread(threading.Thread):
|
||||
height, width = frame_data.shape[:2]
|
||||
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
container = av.open(str(self.video_path), "w")
|
||||
output_stream = container.add_stream(self.vcodec, self.fps, options=self.codec_options)
|
||||
output_stream.pix_fmt = self.pix_fmt
|
||||
output_stream = container.add_stream(
|
||||
self.video_encoder.vcodec,
|
||||
self.fps,
|
||||
options=self.video_encoder.get_codec_options(self.encoder_threads, as_strings=True),
|
||||
)
|
||||
output_stream.pix_fmt = self.video_encoder.pix_fmt
|
||||
output_stream.width = width
|
||||
output_stream.height = height
|
||||
output_stream.time_base = Fraction(1, self.fps)
|
||||
|
||||
# Encode frame with explicit timestamps
|
||||
pil_img = Image.fromarray(frame_data)
|
||||
video_frame = av.VideoFrame.from_image(pil_img)
|
||||
if not self.is_depth:
|
||||
pil_img = Image.fromarray(frame_data)
|
||||
video_frame = av.VideoFrame.from_image(pil_img)
|
||||
else:
|
||||
video_frame = quantize_depth(
|
||||
frame_data,
|
||||
depth_min=self.video_encoder.depth_min,
|
||||
depth_max=self.video_encoder.depth_max,
|
||||
shift=self.video_encoder.shift,
|
||||
use_log=self.video_encoder.use_log,
|
||||
video_backend=self.video_encoder.video_backend,
|
||||
)
|
||||
video_frame.pts = frame_count
|
||||
video_frame.time_base = Fraction(1, self.fps)
|
||||
packet = output_stream.encode(video_frame)
|
||||
@@ -796,6 +838,7 @@ class StreamingVideoEncoder:
|
||||
self,
|
||||
fps: int,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: DepthEncoderConfig | None = None,
|
||||
queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
@@ -811,6 +854,7 @@ class StreamingVideoEncoder:
|
||||
"""
|
||||
self.fps = fps
|
||||
self._camera_encoder = camera_encoder or camera_encoder_defaults()
|
||||
self._depth_encoder = depth_encoder or depth_encoder_defaults()
|
||||
self._encoder_threads = encoder_threads
|
||||
self.queue_maxsize = queue_maxsize
|
||||
|
||||
@@ -823,18 +867,25 @@ class StreamingVideoEncoder:
|
||||
self._episode_active = False
|
||||
self._closed = False
|
||||
|
||||
def start_episode(self, video_keys: list[str], temp_dir: Path) -> None:
|
||||
def start_episode(
|
||||
self, video_keys: list[str], temp_dir: Path, depth_video_keys: list[str] | None = None
|
||||
) -> None:
|
||||
"""Start encoder threads for a new episode.
|
||||
|
||||
Args:
|
||||
video_keys: List of video feature keys (e.g. ["observation.images.laptop"])
|
||||
temp_dir: Base directory for temporary MP4 files
|
||||
depth_video_keys: List of video or image feature keys that carry depth maps (e.g.
|
||||
["observation.images.laptop_depth"]). Defaults to ``[]`` (no depth keys).
|
||||
"""
|
||||
if self._episode_active:
|
||||
self.cancel_episode()
|
||||
|
||||
self._dropped_frames.clear()
|
||||
|
||||
if depth_video_keys is None:
|
||||
depth_video_keys = []
|
||||
|
||||
for video_key in video_keys:
|
||||
frame_queue: queue.Queue = queue.Queue(maxsize=self.queue_maxsize)
|
||||
result_queue: queue.Queue = queue.Queue(maxsize=1)
|
||||
@@ -843,17 +894,15 @@ class StreamingVideoEncoder:
|
||||
temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir))
|
||||
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
|
||||
|
||||
vcodec = self._camera_encoder.vcodec
|
||||
codec_options = self._camera_encoder.get_codec_options(self._encoder_threads, as_strings=True)
|
||||
encoder = self._depth_encoder if video_key in depth_video_keys else self._camera_encoder
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=self.fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=self._camera_encoder.pix_fmt,
|
||||
codec_options=codec_options,
|
||||
video_encoder=encoder,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
encoder_threads=self._encoder_threads,
|
||||
)
|
||||
encoder_thread.start()
|
||||
|
||||
@@ -1060,13 +1109,13 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||
|
||||
def get_video_info(
|
||||
video_path: Path | str,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
video_encoder: VideoEncoderConfig | None = None,
|
||||
) -> dict:
|
||||
"""Build the ``video.*`` / ``audio.*`` info dict persisted in ``info.json``.
|
||||
|
||||
Args:
|
||||
video_path: Path to the encoded video file to probe.
|
||||
camera_encoder: If provided, record the exact encoder settings used to encode this
|
||||
video_encoder: If provided, record the exact encoder settings used to encode this
|
||||
video. Stream-derived values take precedence — encoder fields are only written for keys
|
||||
not already populated from the video file itself.
|
||||
"""
|
||||
@@ -1086,13 +1135,10 @@ def get_video_info(
|
||||
video_info["video.width"] = video_stream.width
|
||||
video_info["video.codec"] = video_stream.codec.canonical_name
|
||||
video_info["video.pix_fmt"] = video_stream.pix_fmt
|
||||
video_info["video.is_depth_map"] = False
|
||||
|
||||
# Calculate fps from r_frame_rate
|
||||
video_info["video.fps"] = int(video_stream.base_rate)
|
||||
|
||||
pixel_channels = get_video_pixel_channels(video_stream.pix_fmt)
|
||||
video_info["video.channels"] = pixel_channels
|
||||
video_info["video.channels"] = get_pix_fmt_channels(video_stream.pix_fmt)
|
||||
|
||||
# Reset logging level
|
||||
av.logging.restore_default_callback()
|
||||
@@ -1101,27 +1147,18 @@ def get_video_info(
|
||||
video_info.update(**get_audio_info(video_path))
|
||||
|
||||
# Add additional encoder configuration if provided
|
||||
if camera_encoder is not None:
|
||||
for field_name, field_value in asdict(camera_encoder).items():
|
||||
if video_encoder is not None:
|
||||
for field_name, field_value in asdict(video_encoder).items():
|
||||
# vcodec is already populated from the video stream
|
||||
if field_name == "vcodec":
|
||||
continue
|
||||
video_info.setdefault(f"video.{field_name}", field_value)
|
||||
|
||||
video_info["is_depth_map"] = isinstance(video_encoder, DepthEncoderConfig)
|
||||
|
||||
return video_info
|
||||
|
||||
|
||||
def get_video_pixel_channels(pix_fmt: str) -> int:
|
||||
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
|
||||
return 1
|
||||
elif "rgba" in pix_fmt or "yuva" in pix_fmt:
|
||||
return 4
|
||||
elif "rgb" in pix_fmt or "yuv" in pix_fmt:
|
||||
return 3
|
||||
else:
|
||||
raise ValueError("Unknown format")
|
||||
|
||||
|
||||
def get_video_duration_in_s(video_path: Path | str) -> float:
|
||||
"""
|
||||
Get the duration of a video file in seconds using PyAV.
|
||||
@@ -1182,7 +1219,8 @@ class VideoEncodingManager:
|
||||
img_dir = self.dataset.root / "images"
|
||||
if img_dir.exists():
|
||||
png_files = list(img_dir.rglob("*.png"))
|
||||
if len(png_files) == 0:
|
||||
tiff_files = list(img_dir.rglob("*.tiff"))
|
||||
if len(png_files) == 0 and len(tiff_files) == 0:
|
||||
shutil.rmtree(img_dir)
|
||||
logger.debug("Cleaned up empty images directory")
|
||||
else:
|
||||
|
||||
@@ -126,7 +126,8 @@ def prepare_observation_for_inference(
|
||||
for name in observation:
|
||||
observation[name] = torch.from_numpy(observation[name])
|
||||
if "image" in name:
|
||||
observation[name] = observation[name].type(torch.float32) / 255
|
||||
if observation[name].dtype == torch.uint8:
|
||||
observation[name] = observation[name].type(torch.float32) / 255
|
||||
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
||||
observation[name] = observation[name].unsqueeze(0)
|
||||
observation[name] = observation[name].to(device)
|
||||
|
||||
@@ -68,9 +68,12 @@ class SOFollower(Robot):
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
return {
|
||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||
}
|
||||
features: dict[str, tuple] = {}
|
||||
for cam in self.cameras:
|
||||
features[cam] = (self.cameras[cam].height, self.cameras[cam].width, 3)
|
||||
if getattr(self.cameras[cam], "use_depth", False):
|
||||
features[f"{cam}_depth"] = (self.cameras[cam].height, self.cameras[cam].width, 1)
|
||||
return features
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
@@ -190,6 +193,12 @@ class SOFollower(Robot):
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
if getattr(cam, "use_depth", False):
|
||||
start = time.perf_counter()
|
||||
obs_dict[f"{cam_key}_depth"] = cam.read_latest_depth()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key} depth: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
|
||||
@@ -333,6 +333,7 @@ def build_rollout_context(
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
camera_encoder=cfg.dataset.camera_encoder,
|
||||
depth_encoder=cfg.dataset.depth_encoder,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
@@ -368,6 +369,7 @@ def build_rollout_context(
|
||||
* len(robot.cameras if hasattr(robot, "cameras") else []),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
camera_encoder=cfg.dataset.camera_encoder,
|
||||
depth_encoder=cfg.dataset.depth_encoder,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
|
||||
@@ -403,6 +403,7 @@ def record(
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
camera_encoder=cfg.dataset.camera_encoder,
|
||||
depth_encoder=cfg.dataset.depth_encoder,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
@@ -432,6 +433,7 @@ def record(
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
camera_encoder=cfg.dataset.camera_encoder,
|
||||
depth_encoder=cfg.dataset.depth_encoder,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
|
||||
@@ -69,6 +69,7 @@ def hw_to_dataset_features(
|
||||
for key, ftype in hw_features.items()
|
||||
if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL)
|
||||
}
|
||||
# TODO(CarolinePascal): we should not rely on the shape to determine if a feature is a camera !
|
||||
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
|
||||
|
||||
if joint_fts and prefix == ACTION:
|
||||
@@ -86,11 +87,19 @@ def hw_to_dataset_features(
|
||||
}
|
||||
|
||||
for key, shape in cam_fts.items():
|
||||
features[f"{prefix}.images.{key}"] = {
|
||||
"dtype": "video" if use_video else "image",
|
||||
"shape": shape,
|
||||
"names": ["height", "width", "channels"],
|
||||
}
|
||||
dtype = "video" if use_video else "image"
|
||||
if len(shape) == 3 and shape[2] in (1, 3):
|
||||
features[f"{prefix}.images.{key}"] = {
|
||||
"dtype": dtype,
|
||||
"shape": shape,
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": {"is_depth_map": shape[2] == 1},
|
||||
}
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Camera feature '{key}' has shape {shape}. "
|
||||
f"Expected a 3-tuple (H, W, C), e.g. (480, 640, 3) for RGB or (480, 640, 1) for depth."
|
||||
)
|
||||
|
||||
_validate_feature_names(features)
|
||||
return features
|
||||
@@ -149,11 +158,11 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
||||
type = FeatureType.VISUAL
|
||||
if len(shape) != 3:
|
||||
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
|
||||
|
||||
names = ft["names"]
|
||||
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||
shape = (shape[2], shape[0], shape[1])
|
||||
else:
|
||||
names = ft["names"]
|
||||
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||
shape = (shape[2], shape[0], shape[1])
|
||||
elif key == OBS_ENV_STATE:
|
||||
type = FeatureType.ENV
|
||||
elif key.startswith(OBS_STR):
|
||||
|
||||
@@ -107,8 +107,15 @@ def log_rerun_data(
|
||||
for i, vi in enumerate(arr):
|
||||
rr.log(f"{key}_{i}", rr.Scalars(float(vi)))
|
||||
else:
|
||||
img_entity = rr.Image(arr).compress() if compress_images else rr.Image(arr)
|
||||
rr.log(key, entity=img_entity, static=True)
|
||||
if arr.shape[-1] == 1:
|
||||
img_entity = (
|
||||
rr.DepthImage(arr, colormap=rr.components.Colormap.Viridis).compress()
|
||||
if compress_images
|
||||
else rr.DepthImage(arr, colormap=rr.components.Colormap.Viridis)
|
||||
)
|
||||
else:
|
||||
img_entity = rr.Image(arr).compress() if compress_images else rr.Image(arr)
|
||||
rr.log(key, entity=img_entity)
|
||||
|
||||
if action:
|
||||
for k, v in action.items():
|
||||
|
||||
@@ -29,8 +29,12 @@ from lerobot.configs import VIDEO_ENCODER_INFO_KEYS
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.datasets.feature_utils import features_equal_for_merge
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
from tests.fixtures.constants import (
|
||||
DUMMY_CAMERA_FEATURES,
|
||||
DUMMY_DEPTH_CAMERA_FEATURES,
|
||||
DUMMY_REPO_ID,
|
||||
DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
|
||||
def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames):
|
||||
"""Test that total number of episodes and frames are correctly aggregated."""
|
||||
@@ -191,6 +195,28 @@ def assert_dataset_iteration_works(aggr_ds):
|
||||
pass
|
||||
|
||||
|
||||
def assert_depth_keys_preserved(aggr_ds, ds_0, ds_1):
|
||||
"""Test that depth keys are correctly preserved after aggregation.
|
||||
|
||||
Ensures that the ``is_depth_map`` marker on visual features survives
|
||||
aggregation, so that downstream consumers (e.g. the dataset reader's
|
||||
depth decoding path) keep working on the merged dataset.
|
||||
"""
|
||||
expected_depth_keys = set(ds_0.meta.depth_keys)
|
||||
assert expected_depth_keys == set(ds_1.meta.depth_keys), (
|
||||
"Source datasets disagree on depth_keys; test setup is inconsistent"
|
||||
)
|
||||
actual_depth_keys = set(aggr_ds.meta.depth_keys)
|
||||
assert actual_depth_keys == expected_depth_keys, (
|
||||
f"Expected depth_keys {expected_depth_keys}, got {actual_depth_keys}"
|
||||
)
|
||||
for key in expected_depth_keys:
|
||||
info = aggr_ds.meta.info.features[key].get("info") or {}
|
||||
assert info.get("is_depth_map") is True, (
|
||||
f"Depth marker lost on feature {key!r} after aggregation"
|
||||
)
|
||||
|
||||
|
||||
def assert_video_timestamps_within_bounds(aggr_ds):
|
||||
"""Test that all video timestamps are within valid bounds for their respective video files.
|
||||
|
||||
@@ -240,7 +266,11 @@ def assert_video_timestamps_within_bounds(aggr_ds):
|
||||
|
||||
|
||||
def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
||||
"""Test basic aggregation functionality with standard parameters."""
|
||||
"""Test basic aggregation functionality with standard parameters.
|
||||
|
||||
Source datasets include both RGB and depth video features so the same
|
||||
aggregation flow is exercised on the ``is_depth_map`` branch.
|
||||
"""
|
||||
ds_0_num_frames = 400
|
||||
ds_1_num_frames = 800
|
||||
ds_0_num_episodes = 10
|
||||
@@ -252,14 +282,21 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
||||
repo_id=f"{DUMMY_REPO_ID}_0",
|
||||
total_episodes=ds_0_num_episodes,
|
||||
total_frames=ds_0_num_frames,
|
||||
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
ds_1 = lerobot_dataset_factory(
|
||||
root=tmp_path / "test_1",
|
||||
repo_id=f"{DUMMY_REPO_ID}_1",
|
||||
total_episodes=ds_1_num_episodes,
|
||||
total_frames=ds_1_num_frames,
|
||||
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
|
||||
# Confirm depth was actually wired into the source datasets so the
|
||||
# rest of the assertions exercise the depth aggregation path.
|
||||
assert len(ds_0.meta.depth_keys) > 0, "ds_0 should expose at least one depth key"
|
||||
assert len(ds_1.meta.depth_keys) > 0, "ds_1 should expose at least one depth key"
|
||||
|
||||
aggregate_datasets(
|
||||
repo_ids=[ds_0.repo_id, ds_1.repo_id],
|
||||
roots=[ds_0.root, ds_1.root],
|
||||
@@ -286,6 +323,7 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
||||
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
|
||||
assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
|
||||
assert_video_timestamps_within_bounds(aggr_ds)
|
||||
assert_depth_keys_preserved(aggr_ds, ds_0, ds_1)
|
||||
assert_dataset_iteration_works(aggr_ds)
|
||||
|
||||
|
||||
@@ -357,7 +395,11 @@ def test_aggregate_incomplete_video_encoder_info_warns_and_nuls_encoders(
|
||||
|
||||
|
||||
def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
|
||||
"""Test aggregation with small file size limits to force file rotation/sharding."""
|
||||
"""Test aggregation with small file size limits to force file rotation/sharding.
|
||||
|
||||
Depth video features are included to verify that file rotation/concat
|
||||
correctly handles depth-marked features alongside regular RGB ones.
|
||||
"""
|
||||
ds_0_num_episodes = ds_1_num_episodes = 10
|
||||
ds_0_num_frames = ds_1_num_frames = 400
|
||||
|
||||
@@ -366,14 +408,19 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
|
||||
repo_id=f"{DUMMY_REPO_ID}_small_0",
|
||||
total_episodes=ds_0_num_episodes,
|
||||
total_frames=ds_0_num_frames,
|
||||
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
ds_1 = lerobot_dataset_factory(
|
||||
root=tmp_path / "small_1",
|
||||
repo_id=f"{DUMMY_REPO_ID}_small_1",
|
||||
total_episodes=ds_1_num_episodes,
|
||||
total_frames=ds_1_num_frames,
|
||||
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
|
||||
assert len(ds_0.meta.depth_keys) > 0, "ds_0 should expose at least one depth key"
|
||||
assert len(ds_1.meta.depth_keys) > 0, "ds_1 should expose at least one depth key"
|
||||
|
||||
# Use the new configurable parameters to force file rotation
|
||||
aggregate_datasets(
|
||||
repo_ids=[ds_0.repo_id, ds_1.repo_id],
|
||||
@@ -404,6 +451,7 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
|
||||
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
|
||||
assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
|
||||
assert_video_timestamps_within_bounds(aggr_ds)
|
||||
assert_depth_keys_preserved(aggr_ds, ds_0, ds_1)
|
||||
assert_dataset_iteration_works(aggr_ds)
|
||||
|
||||
# Check that multiple files were actually created due to small size limits
|
||||
@@ -423,7 +471,8 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
|
||||
"""Regression test for video timestamp bug when merging datasets.
|
||||
|
||||
This test specifically checks that video timestamps are correctly calculated
|
||||
and accumulated when merging multiple datasets.
|
||||
and accumulated when merging multiple datasets. Depth video features are
|
||||
included so depth timestamps are also covered by the regression.
|
||||
"""
|
||||
datasets = []
|
||||
for i in range(3):
|
||||
@@ -432,9 +481,13 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
|
||||
repo_id=f"{DUMMY_REPO_ID}_regression_{i}",
|
||||
total_episodes=2,
|
||||
total_frames=100,
|
||||
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
datasets.append(ds)
|
||||
|
||||
for i, ds in enumerate(datasets):
|
||||
assert len(ds.meta.depth_keys) > 0, f"Dataset {i} should expose at least one depth key"
|
||||
|
||||
aggregate_datasets(
|
||||
repo_ids=[ds.repo_id for ds in datasets],
|
||||
roots=[ds.root for ds in datasets],
|
||||
@@ -451,12 +504,21 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
|
||||
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_regression_aggr", root=tmp_path / "regression_aggr")
|
||||
|
||||
assert_video_timestamps_within_bounds(aggr_ds)
|
||||
# Depth keys must survive the merge for the regression to cover the
|
||||
# ``is_depth_map`` decoding branch.
|
||||
assert set(aggr_ds.meta.depth_keys) == set(datasets[0].meta.depth_keys)
|
||||
|
||||
depth_keys = set(aggr_ds.meta.depth_keys)
|
||||
for i in range(len(aggr_ds)):
|
||||
item = aggr_ds[i]
|
||||
for key in aggr_ds.meta.video_keys:
|
||||
assert key in item, f"Video key {key} missing from item {i}"
|
||||
assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}"
|
||||
# Depth frames are single-channel (1, H, W) after dequantization;
|
||||
# standard RGB frames keep the 3-channel layout.
|
||||
expected_channels = 1 if key in depth_keys else 3
|
||||
assert item[key].shape[0] == expected_channels, (
|
||||
f"Expected {expected_channels} channels for video key {key}, got {item[key].shape}"
|
||||
)
|
||||
|
||||
|
||||
def assert_image_schema_preserved(aggr_ds):
|
||||
@@ -538,25 +600,31 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
|
||||
ds_0_num_episodes = 2
|
||||
ds_1_num_episodes = 3
|
||||
|
||||
# Create two image-based datasets (use_videos=False)
|
||||
# Create two image-based datasets (use_videos=False) with a mix of RGB
|
||||
# and depth-marked cameras so the depth path is exercised in image mode.
|
||||
ds_0 = lerobot_dataset_factory(
|
||||
root=tmp_path / "image_0",
|
||||
repo_id=f"{DUMMY_REPO_ID}_image_0",
|
||||
total_episodes=ds_0_num_episodes,
|
||||
total_frames=ds_0_num_frames,
|
||||
use_videos=False, # Image-based dataset
|
||||
use_videos=False,
|
||||
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
ds_1 = lerobot_dataset_factory(
|
||||
root=tmp_path / "image_1",
|
||||
repo_id=f"{DUMMY_REPO_ID}_image_1",
|
||||
total_episodes=ds_1_num_episodes,
|
||||
total_frames=ds_1_num_frames,
|
||||
use_videos=False, # Image-based dataset
|
||||
use_videos=False,
|
||||
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
)
|
||||
|
||||
# Verify source datasets have image keys
|
||||
assert len(ds_0.meta.image_keys) > 0, "ds_0 should have image keys"
|
||||
assert len(ds_1.meta.image_keys) > 0, "ds_1 should have image keys"
|
||||
# And that the depth marker actually made it onto an image feature.
|
||||
assert len(ds_0.meta.depth_keys) > 0, "ds_0 should expose at least one depth key"
|
||||
assert len(ds_1.meta.depth_keys) > 0, "ds_1 should expose at least one depth key"
|
||||
|
||||
# Aggregate the datasets
|
||||
aggregate_datasets(
|
||||
@@ -591,6 +659,7 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
|
||||
# Image-specific assertions
|
||||
assert_image_schema_preserved(aggr_ds)
|
||||
assert_image_frames_integrity(aggr_ds, ds_0, ds_1)
|
||||
assert_depth_keys_preserved(aggr_ds, ds_0, ds_1)
|
||||
|
||||
# Verify images can be accessed and have correct shape
|
||||
sample_item = aggr_ds[0]
|
||||
|
||||
@@ -59,11 +59,13 @@ def _make_dummy_stats(features: dict) -> dict:
|
||||
stats = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] in ("image", "video"):
|
||||
channels = ft["shape"][-1]
|
||||
stat_shape = (channels, 1, 1)
|
||||
stats[key] = {
|
||||
"max": np.ones((3, 1, 1), dtype=np.float32),
|
||||
"mean": np.full((3, 1, 1), 0.5, dtype=np.float32),
|
||||
"min": np.zeros((3, 1, 1), dtype=np.float32),
|
||||
"std": np.full((3, 1, 1), 0.25, dtype=np.float32),
|
||||
"max": np.ones(stat_shape, dtype=np.float32),
|
||||
"mean": np.full(stat_shape, 0.5, dtype=np.float32),
|
||||
"min": np.zeros(stat_shape, dtype=np.float32),
|
||||
"std": np.full(stat_shape, 0.25, dtype=np.float32),
|
||||
"count": np.array([5]),
|
||||
}
|
||||
elif ft["dtype"] in ("float32", "float64", "int64"):
|
||||
@@ -142,6 +144,45 @@ def test_create_without_videos_has_no_video_path(tmp_path):
|
||||
assert meta.video_keys == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("marker_field", "marker_key"),
|
||||
[
|
||||
("info", "is_depth_map"),
|
||||
("info", "video.is_depth_map"),
|
||||
("video_info", "video.is_depth_map"),
|
||||
],
|
||||
ids=["info.is_depth_map", "info.video.is_depth_map_legacy", "video_info.video.is_depth_map_legacy"],
|
||||
)
|
||||
def test_depth_keys_property_filters_by_marker(tmp_path, marker_field, marker_key):
|
||||
"""``depth_keys`` recognises the canonical and the two legacy marker variants."""
|
||||
depth_feature = {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96, 1),
|
||||
"names": ["height", "width", "channels"],
|
||||
marker_field: {marker_key: True},
|
||||
}
|
||||
features = {
|
||||
**VIDEO_FEATURES,
|
||||
"observation.images.laptop_depth": depth_feature,
|
||||
}
|
||||
meta = LeRobotDatasetMetadata.create(
|
||||
repo_id="test/depth_keys",
|
||||
fps=DEFAULT_FPS,
|
||||
features=features,
|
||||
root=tmp_path / f"depth_keys_{marker_field}_{marker_key.replace('.', '_')}",
|
||||
)
|
||||
|
||||
assert set(meta.video_keys) == {"observation.images.laptop", "observation.images.laptop_depth"}
|
||||
assert meta.depth_keys == ["observation.images.laptop_depth"]
|
||||
|
||||
|
||||
def test_depth_keys_empty_when_no_marker(tmp_path):
|
||||
meta = LeRobotDatasetMetadata.create(
|
||||
repo_id="test/no_depth", fps=DEFAULT_FPS, features=VIDEO_FEATURES, root=tmp_path / "no_depth"
|
||||
)
|
||||
assert meta.depth_keys == []
|
||||
|
||||
|
||||
def test_create_raises_on_existing_directory(tmp_path):
|
||||
"""create() raises if root directory already exists."""
|
||||
root = tmp_path / "existing"
|
||||
|
||||
@@ -37,7 +37,8 @@ from lerobot.datasets.dataset_tools import (
|
||||
split_dataset,
|
||||
)
|
||||
from lerobot.datasets.io_utils import load_info
|
||||
from tests.datasets.test_video_encoding import _add_frames, require_h264, require_libsvtav1
|
||||
from tests.datasets.test_video_encoding import require_h264, require_libsvtav1
|
||||
from tests.fixtures.dataset_factories import add_frames
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -1350,9 +1351,9 @@ def test_reencode_dataset_multi_key_multiprocessing(
|
||||
camera_encoder=initial_cfg,
|
||||
)
|
||||
|
||||
_add_frames(dataset, num_frames=4)
|
||||
add_frames(dataset, num_frames=4)
|
||||
dataset.save_episode()
|
||||
_add_frames(dataset, num_frames=4)
|
||||
add_frames(dataset, num_frames=4)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
|
||||
@@ -53,8 +53,8 @@ def _make_frame(features: dict, task: str = "Dummy task") -> dict:
|
||||
# ── Existing encode_video_worker tests ───────────────────────────────
|
||||
|
||||
|
||||
def test_encode_video_worker_forwards_camera_encoder(tmp_path):
|
||||
"""_encode_video_worker forwards camera_encoder to encode_video_frames."""
|
||||
def test_encode_video_worker_forwards_video_encoder(tmp_path):
|
||||
"""_encode_video_worker forwards video_encoder to encode_video_frames."""
|
||||
video_key = "observation.images.laptop"
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0)
|
||||
img_dir = tmp_path / Path(fpath).parent
|
||||
@@ -74,16 +74,16 @@ def test_encode_video_worker_forwards_camera_encoder(tmp_path):
|
||||
0,
|
||||
tmp_path,
|
||||
fps=30,
|
||||
camera_encoder=VideoEncoderConfig(vcodec="h264", preset=None),
|
||||
video_encoder=VideoEncoderConfig(vcodec="h264", preset=None),
|
||||
encoder_threads=4,
|
||||
)
|
||||
|
||||
assert captured_kwargs["camera_encoder"].vcodec == "h264"
|
||||
assert captured_kwargs["video_encoder"].vcodec == "h264"
|
||||
assert captured_kwargs["encoder_threads"] == 4
|
||||
|
||||
|
||||
def test_encode_video_worker_default_camera_encoder(tmp_path):
|
||||
"""_encode_video_worker passes None camera_encoder which encode_video_frames defaults."""
|
||||
def test_encode_video_worker_default_video_encoder(tmp_path):
|
||||
"""_encode_video_worker passes None video_encoder which encode_video_frames defaults."""
|
||||
video_key = "observation.images.laptop"
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0)
|
||||
img_dir = tmp_path / Path(fpath).parent
|
||||
@@ -100,7 +100,7 @@ def test_encode_video_worker_default_camera_encoder(tmp_path):
|
||||
with patch("lerobot.datasets.dataset_writer.encode_video_frames", side_effect=mock_encode):
|
||||
_encode_video_worker(video_key, 0, tmp_path, fps=30)
|
||||
|
||||
assert captured_kwargs["camera_encoder"] is None
|
||||
assert captured_kwargs["video_encoder"] is None
|
||||
assert captured_kwargs["encoder_threads"] is None
|
||||
|
||||
|
||||
|
||||
@@ -1516,10 +1516,15 @@ def test_valid_video_codecs_constant():
|
||||
assert "h264" in VALID_VIDEO_CODECS
|
||||
assert "hevc" in VALID_VIDEO_CODECS
|
||||
assert "libsvtav1" in VALID_VIDEO_CODECS
|
||||
assert "ffv1" in VALID_VIDEO_CODECS
|
||||
assert "auto" in VALID_VIDEO_CODECS
|
||||
assert "h264_videotoolbox" in VALID_VIDEO_CODECS
|
||||
assert "h264_nvenc" in VALID_VIDEO_CODECS
|
||||
assert len(VALID_VIDEO_CODECS) == 10
|
||||
assert "h264_vaapi" in VALID_VIDEO_CODECS
|
||||
assert "h264_qsv" in VALID_VIDEO_CODECS
|
||||
assert "hevc_videotoolbox" in VALID_VIDEO_CODECS
|
||||
assert "hevc_nvenc" in VALID_VIDEO_CODECS
|
||||
assert len(VALID_VIDEO_CODECS) == 11
|
||||
|
||||
|
||||
def test_delta_timestamps_with_episodes_filter(tmp_path, empty_lerobot_dataset_factory):
|
||||
|
||||
@@ -0,0 +1,304 @@
|
||||
"""Tests for the depth-integration feature.
|
||||
|
||||
Covers quantization/dequantization round-trips (depth_utils), image writer
|
||||
depth support (image_writer), hardware→dataset feature routing
|
||||
(feature_utils), video info helpers (video_utils / configs.video), and
|
||||
feature-to-file-format routing through the dataset writer.
|
||||
|
||||
Depth metadata detection on ``LeRobotDatasetMetadata.depth_keys`` (canonical
|
||||
and legacy marker variants) lives in ``test_dataset_metadata.py``.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs import DepthEncoderConfig
|
||||
from lerobot.configs.video import DEPTH_QMAX, VALID_VIDEO_CODECS
|
||||
from lerobot.datasets.depth_utils import dequantize_depth, quantize_depth
|
||||
from lerobot.datasets.image_writer import (
|
||||
image_array_to_pil_image,
|
||||
save_kwargs_for_path,
|
||||
write_image,
|
||||
)
|
||||
from lerobot.datasets.pyav_utils import get_pix_fmt_channels
|
||||
from tests.fixtures.constants import (
|
||||
DEFAULT_FPS,
|
||||
DUMMY_CAMERA_FEATURES,
|
||||
DUMMY_DEPTH_CAMERA_FEATURES,
|
||||
DUMMY_MOTOR_FEATURES,
|
||||
DUMMY_REPO_ID,
|
||||
)
|
||||
|
||||
H, W = 48, 64
|
||||
DEPTH_MIN = 0.01
|
||||
DEPTH_MAX = 10.0
|
||||
|
||||
|
||||
# ── 1. Quantize / Dequantize round-trips ────────────────────────────
|
||||
|
||||
|
||||
class TestQuantizeDequantize:
|
||||
"""Core numerical tests for depth_utils.quantize_depth / dequantize_depth."""
|
||||
|
||||
def _make_depth_metres(self) -> np.ndarray:
|
||||
"""Linearly-spaced float32 depth in metres covering the default range."""
|
||||
return np.linspace(DEPTH_MIN, DEPTH_MAX, H * W, dtype=np.float32).reshape(H, W)
|
||||
|
||||
def test_roundtrip_linear_metres(self):
|
||||
depth = self._make_depth_metres()
|
||||
quantized = quantize_depth(depth, use_log=False, video_backend=None)
|
||||
recovered = dequantize_depth(quantized, use_log=False, output_unit="m")
|
||||
|
||||
assert recovered.shape == (H, W, 1), f"Expected (H,W,1), got {recovered.shape}"
|
||||
assert recovered.dtype == np.float32
|
||||
tol = (DEPTH_MAX - DEPTH_MIN) / DEPTH_QMAX
|
||||
np.testing.assert_allclose(recovered[..., 0], depth, atol=tol + 1e-6)
|
||||
|
||||
def test_roundtrip_log_metres(self):
|
||||
depth = self._make_depth_metres()
|
||||
quantized = quantize_depth(depth, use_log=True, video_backend=None)
|
||||
recovered = dequantize_depth(quantized, use_log=True, output_unit="m")
|
||||
|
||||
assert recovered.shape == (H, W, 1)
|
||||
near = depth < 1.0
|
||||
far = depth > 8.0
|
||||
err_near = np.abs(recovered[..., 0][near] - depth[near])
|
||||
err_far = np.abs(recovered[..., 0][far] - depth[far])
|
||||
assert err_near.mean() < err_far.mean(), "Log quant should be more precise at close range"
|
||||
|
||||
def test_roundtrip_mm_uint16_input(self):
|
||||
depth_mm = np.linspace(10, 10000, H * W, dtype=np.float64).reshape(H, W).astype(np.uint16)
|
||||
quantized = quantize_depth(depth_mm, use_log=False, video_backend=None, input_unit="mm")
|
||||
recovered = dequantize_depth(quantized, use_log=False, output_unit="mm")
|
||||
|
||||
assert recovered.dtype == np.uint16
|
||||
tol_mm = (DEPTH_MAX - DEPTH_MIN) * 1000.0 / DEPTH_QMAX
|
||||
np.testing.assert_allclose(
|
||||
recovered[..., 0].astype(np.float64), depth_mm.astype(np.float64), atol=tol_mm + 1.0
|
||||
)
|
||||
|
||||
def test_quantize_clamps_out_of_range(self):
|
||||
depth = np.array([[0.001, 99.0]], dtype=np.float32)
|
||||
quantized = quantize_depth(depth, use_log=False, video_backend=None)
|
||||
assert quantized[0, 0] == 0
|
||||
assert quantized[0, 1] == DEPTH_QMAX
|
||||
|
||||
def test_quantize_accepts_torch_tensor(self):
|
||||
t = torch.rand(H, W, dtype=torch.float32) * (DEPTH_MAX - DEPTH_MIN) + DEPTH_MIN
|
||||
result = quantize_depth(t, video_backend=None)
|
||||
assert isinstance(result, np.ndarray)
|
||||
assert result.dtype == np.uint16
|
||||
|
||||
def test_quantize_squeezes_channel_dim(self):
|
||||
depth = self._make_depth_metres()
|
||||
for shape in [(H, W, 1), (1, H, W)]:
|
||||
reshaped = depth.reshape(shape)
|
||||
quantized = quantize_depth(reshaped, video_backend=None)
|
||||
assert quantized.ndim == 2, f"Input shape {shape} should be squeezed to 2D"
|
||||
|
||||
def test_quantize_returns_pyav_frame(self):
|
||||
depth = self._make_depth_metres()
|
||||
result = quantize_depth(depth, video_backend="pyav")
|
||||
assert isinstance(result, av.VideoFrame)
|
||||
|
||||
def test_dequantize_output_tensor(self):
|
||||
quantized = np.full((H, W), DEPTH_QMAX // 2, dtype=np.uint16)
|
||||
result = dequantize_depth(quantized, output_unit="m", output_tensor=True)
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert result.shape == (H, W, 1)
|
||||
|
||||
def test_invalid_log_params_raises(self):
|
||||
depth = np.ones((4, 4), dtype=np.float32)
|
||||
with pytest.raises(ValueError, match="depth_min \\+ shift must be positive"):
|
||||
quantize_depth(depth, depth_min=1.0, shift=-2.0, use_log=True, video_backend=None)
|
||||
|
||||
|
||||
# ── 2. Image writer depth support ───────────────────────────────────
|
||||
|
||||
|
||||
class TestImageWriterDepth:
|
||||
"""image_array_to_pil_image and write_image for single-channel depth maps."""
|
||||
|
||||
def test_pil_uint16_grayscale(self):
|
||||
arr = np.arange(H * W, dtype=np.uint16).reshape(H, W)
|
||||
img = image_array_to_pil_image(arr)
|
||||
assert isinstance(img, PIL.Image.Image)
|
||||
assert img.mode == "I;16"
|
||||
assert img.size == (W, H)
|
||||
|
||||
def test_pil_float32_grayscale(self):
|
||||
arr = np.random.rand(H, W).astype(np.float32)
|
||||
img = image_array_to_pil_image(arr)
|
||||
assert img.mode == "F"
|
||||
|
||||
def test_pil_squeeze_hwc1_and_1hw(self):
|
||||
arr_uint16 = np.zeros((H, W), dtype=np.uint16)
|
||||
for input_arr in [arr_uint16.reshape(H, W, 1), arr_uint16.reshape(1, H, W)]:
|
||||
img = image_array_to_pil_image(input_arr)
|
||||
assert img.size == (W, H)
|
||||
|
||||
def test_save_kwargs_png_vs_tiff(self):
|
||||
png_kw = save_kwargs_for_path(Path("frame.png"), compress_level=5)
|
||||
assert png_kw == {"compress_level": 5}
|
||||
|
||||
tiff_kw = save_kwargs_for_path(Path("frame.tiff"), compress_level=5)
|
||||
assert tiff_kw == {"compression": "raw"}
|
||||
|
||||
assert save_kwargs_for_path(Path("frame.jpg"), compress_level=5) == {}
|
||||
|
||||
def test_write_image_tiff_roundtrip(self, tmp_path):
|
||||
arr = np.arange(H * W, dtype=np.uint16).reshape(H, W)
|
||||
fpath = tmp_path / "depth.tiff"
|
||||
write_image(arr, fpath)
|
||||
|
||||
assert fpath.exists()
|
||||
with PIL.Image.open(fpath) as loaded:
|
||||
recovered = np.array(loaded)
|
||||
np.testing.assert_array_equal(recovered, arr)
|
||||
|
||||
|
||||
# ── 3. Feature routing ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestHwToDatasetFeaturesDepth:
|
||||
"""hw_to_dataset_features marks single-channel cameras as depth."""
|
||||
|
||||
def test_single_channel_cam_marked_depth(self):
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
|
||||
features = hw_to_dataset_features({"cam": (480, 640, 1)}, prefix="observation")
|
||||
ft = features["observation.images.cam"]
|
||||
assert ft["info"]["is_depth_map"] is True
|
||||
|
||||
def test_three_channel_cam_not_depth(self):
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
|
||||
features = hw_to_dataset_features({"cam": (480, 640, 3)}, prefix="observation")
|
||||
ft = features["observation.images.cam"]
|
||||
assert ft["info"]["is_depth_map"] is False
|
||||
|
||||
def test_invalid_channel_count_raises(self):
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
|
||||
with pytest.raises(ValueError, match="Expected a 3-tuple"):
|
||||
hw_to_dataset_features({"cam": (480, 640, 2)}, prefix="observation")
|
||||
|
||||
|
||||
# ── 4. Video info depth flag ────────────────────────────────────────
|
||||
|
||||
|
||||
class TestVideoInfoDepthFlag:
|
||||
"""Misc depth-related constants and helpers in video_utils / configs."""
|
||||
|
||||
def test_get_pix_fmt_channels_gray(self):
|
||||
assert get_pix_fmt_channels("gray12le") == 1
|
||||
assert get_pix_fmt_channels("gray8") == 1
|
||||
|
||||
def test_ffv1_in_valid_codecs(self):
|
||||
assert "ffv1" in VALID_VIDEO_CODECS
|
||||
|
||||
|
||||
# ── 5. Feature-to-file-format routing ───────────────────────────────
|
||||
|
||||
|
||||
def _build_mixed_features(dtype: str) -> dict:
|
||||
"""Build a feature dict with one RGB camera and one depth camera.
|
||||
|
||||
Uses shapes from ``DUMMY_CAMERA_FEATURES`` and ``DUMMY_DEPTH_CAMERA_FEATURES``
|
||||
defined in ``tests.fixtures.constants``.
|
||||
"""
|
||||
rgb_cam = next(iter(DUMMY_CAMERA_FEATURES.values()))
|
||||
depth_cam = next(iter(DUMMY_DEPTH_CAMERA_FEATURES.values()))
|
||||
return {
|
||||
"observation.images.rgb": {"dtype": dtype, **rgb_cam},
|
||||
"observation.images.depth": {"dtype": dtype, **depth_cam},
|
||||
**{k: {"dtype": v["dtype"], **v} for k, v in DUMMY_MOTOR_FEATURES.items()},
|
||||
}
|
||||
|
||||
|
||||
def _make_mixed_frame(features: dict) -> dict:
|
||||
"""Build a valid frame dict matching the given feature schema."""
|
||||
frame: dict = {"task": "test task"}
|
||||
for key, ft in features.items():
|
||||
shape = ft["shape"]
|
||||
if ft["dtype"] in ("image", "video"):
|
||||
channels = shape[-1]
|
||||
if channels == 1:
|
||||
frame[key] = np.random.randint(0, 4095, shape, dtype=np.uint16)
|
||||
else:
|
||||
frame[key] = np.random.randint(0, 255, shape, dtype=np.uint8)
|
||||
else:
|
||||
frame[key] = np.random.randn(*shape).astype(ft["dtype"])
|
||||
return frame
|
||||
|
||||
|
||||
class TestFeatureFileRouting:
|
||||
"""Verify that depth vs RGB features are routed to the correct file format."""
|
||||
|
||||
NUM_FRAMES = 5
|
||||
|
||||
def test_no_video_depth_tiff_rgb_png(self, tmp_path):
|
||||
"""Without video encoding: depth -> .tiff, RGB -> .png."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
features = _build_mixed_features(dtype="image")
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=DUMMY_REPO_ID,
|
||||
fps=DEFAULT_FPS,
|
||||
features=features,
|
||||
root=tmp_path / "ds",
|
||||
use_videos=False,
|
||||
)
|
||||
|
||||
for _ in range(self.NUM_FRAMES):
|
||||
dataset.add_frame(_make_mixed_frame(features))
|
||||
|
||||
buf = dataset.writer.episode_buffer
|
||||
depth_paths = [Path(p) for p in buf["observation.images.depth"]]
|
||||
rgb_paths = [Path(p) for p in buf["observation.images.rgb"]]
|
||||
|
||||
assert all(p.suffix == ".tiff" for p in depth_paths), "Depth frames should be .tiff"
|
||||
assert all(p.suffix == ".png" for p in rgb_paths), "RGB frames should be .png"
|
||||
assert all(p.exists() for p in depth_paths), "Depth TIFF files should exist on disk"
|
||||
assert all(p.exists() for p in rgb_paths), "RGB PNG files should exist on disk"
|
||||
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
def test_video_depth_uses_depth_encoder(self, tmp_path):
|
||||
"""With streaming video encoding: depth keys use DepthEncoderConfig, RGB keys do not."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
features = _build_mixed_features(dtype="video")
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=DUMMY_REPO_ID,
|
||||
fps=DEFAULT_FPS,
|
||||
features=features,
|
||||
root=tmp_path / "ds",
|
||||
use_videos=True,
|
||||
streaming_encoding=True,
|
||||
)
|
||||
|
||||
assert dataset.writer._streaming_encoder is not None
|
||||
encoder = dataset.writer._streaming_encoder
|
||||
|
||||
for _ in range(self.NUM_FRAMES):
|
||||
dataset.add_frame(_make_mixed_frame(features))
|
||||
|
||||
rgb_thread = encoder._threads["observation.images.rgb"]
|
||||
depth_thread = encoder._threads["observation.images.depth"]
|
||||
|
||||
assert not isinstance(rgb_thread.video_encoder, DepthEncoderConfig)
|
||||
assert isinstance(depth_thread.video_encoder, DepthEncoderConfig)
|
||||
assert depth_thread.is_depth is True
|
||||
assert rgb_thread.is_depth is False
|
||||
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
@@ -94,7 +94,7 @@ def test_image_array_to_pil_image_pytorch_format(img_array_factory):
|
||||
|
||||
def test_image_array_to_pil_image_single_channel(img_array_factory):
|
||||
img_array = img_array_factory(channels=1)
|
||||
with pytest.raises(NotImplementedError):
|
||||
with pytest.raises(ValueError, match="Unsupported single-channel image dtype"):
|
||||
image_array_to_pil_image(img_array)
|
||||
|
||||
|
||||
|
||||
@@ -61,9 +61,7 @@ class TestCameraEncoderThread:
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=enc_cfg.vcodec,
|
||||
pix_fmt=enc_cfg.pix_fmt,
|
||||
codec_options=enc_cfg.get_codec_options(as_strings=True),
|
||||
video_encoder=enc_cfg,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
@@ -112,9 +110,7 @@ class TestCameraEncoderThread:
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=enc_cfg.vcodec,
|
||||
pix_fmt=enc_cfg.pix_fmt,
|
||||
codec_options=enc_cfg.get_codec_options(as_strings=True),
|
||||
video_encoder=enc_cfg,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
@@ -146,9 +142,7 @@ class TestCameraEncoderThread:
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=enc_cfg.vcodec,
|
||||
pix_fmt=enc_cfg.pix_fmt,
|
||||
codec_options=enc_cfg.get_codec_options(as_strings=True),
|
||||
video_encoder=enc_cfg,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
@@ -391,7 +385,8 @@ class TestStreamingVideoEncoder:
|
||||
|
||||
# Verify codec options include thread tuning for libsvtav1 (lp=…)
|
||||
thread = encoder._threads[f"{OBS_IMAGES}.cam"]
|
||||
assert "svtav1-params" in thread.codec_options or "threads" in thread.codec_options
|
||||
codec_opts = thread.video_encoder.get_codec_options(encoder_threads=thread.encoder_threads)
|
||||
assert "svtav1-params" in codec_opts or "threads" in codec_opts
|
||||
|
||||
# Feed some frames and finish to ensure it works end-to-end
|
||||
num_frames = 10
|
||||
|
||||
@@ -26,7 +26,7 @@ pytest.importorskip("av", reason="av is required (install lerobot[dataset])")
|
||||
|
||||
import av # noqa: E402
|
||||
|
||||
from lerobot.configs import VALID_VIDEO_CODECS, VideoEncoderConfig
|
||||
from lerobot.configs import VALID_VIDEO_CODECS, DepthEncoderConfig, VideoEncoderConfig
|
||||
from lerobot.datasets.image_writer import write_image
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pyav_utils import get_codec
|
||||
@@ -339,7 +339,7 @@ def _encode_video(
|
||||
) -> Path:
|
||||
imgs_dir = path.parent / f"imgs_{path.stem}"
|
||||
_write_frames(imgs_dir, num_frames=num_frames)
|
||||
encode_video_frames(imgs_dir, path, fps=fps, camera_encoder=cfg, overwrite=True)
|
||||
encode_video_frames(imgs_dir, path, fps=fps, video_encoder=cfg, overwrite=True)
|
||||
return path
|
||||
|
||||
|
||||
@@ -375,7 +375,7 @@ class TestGetVideoInfo:
|
||||
assert info["video.pix_fmt"] == "yuv420p"
|
||||
assert info["video.fps"] == 30
|
||||
assert info["video.channels"] == 3
|
||||
assert info["video.is_depth_map"] is False
|
||||
assert info["is_depth_map"] is False
|
||||
assert info["has_audio"] is False
|
||||
assert "video.g" not in info
|
||||
assert "video.crf" not in info
|
||||
@@ -385,7 +385,7 @@ class TestGetVideoInfo:
|
||||
def test_merges_encoder_config_as_video_prefixed_entries(self):
|
||||
cfg = VideoEncoderConfig(vcodec="libsvtav1", g=2, crf=30, preset=12)
|
||||
|
||||
info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4", camera_encoder=cfg)
|
||||
info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4", video_encoder=cfg)
|
||||
|
||||
assert info["video.g"] == 2
|
||||
assert info["video.crf"] == 30
|
||||
@@ -398,11 +398,16 @@ class TestGetVideoInfo:
|
||||
def test_stream_derived_keys_take_precedence_over_config(self):
|
||||
cfg = VideoEncoderConfig(vcodec="libsvtav1", pix_fmt="yuv420p")
|
||||
|
||||
info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4", camera_encoder=cfg)
|
||||
info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4", video_encoder=cfg)
|
||||
|
||||
assert info["video.codec"] # populated from stream, not from config's vcodec
|
||||
assert info["video.pix_fmt"] == "yuv420p"
|
||||
|
||||
def test_depth_encoder_config_sets_is_depth_map_true(self):
|
||||
"""A ``DepthEncoderConfig`` causes ``get_video_info`` to mark the stream as depth."""
|
||||
info = get_video_info(TEST_ARTIFACTS_DIR / "clip_4frames.mp4", video_encoder=DepthEncoderConfig())
|
||||
assert info["is_depth_map"] is True
|
||||
|
||||
|
||||
class TestEncodeVideoFrames:
|
||||
@require_libsvtav1
|
||||
@@ -461,7 +466,7 @@ class TestEncodeVideoFrames:
|
||||
cfg = VideoEncoderConfig(vcodec="libsvtav1", g=4, crf=25, preset=10)
|
||||
video_path = _encode_video(tmp_path / "out.mp4", num_frames=4, fps=30, cfg=cfg)
|
||||
|
||||
info = get_video_info(video_path, camera_encoder=cfg)
|
||||
info = get_video_info(video_path, video_encoder=cfg)
|
||||
|
||||
# Stream-derived
|
||||
assert info["video.height"] == 64
|
||||
@@ -470,7 +475,7 @@ class TestEncodeVideoFrames:
|
||||
assert info["video.codec"] == "av1"
|
||||
assert info["video.pix_fmt"] == "yuv420p"
|
||||
assert info["video.fps"] == 30
|
||||
assert info["video.is_depth_map"] is False
|
||||
assert info["is_depth_map"] is False
|
||||
assert info["has_audio"] is False
|
||||
# Encoder config
|
||||
assert info["video.g"] == 4
|
||||
@@ -488,14 +493,14 @@ class TestReencodeVideo:
|
||||
src = TEST_ARTIFACTS_DIR / "clip_4frames.mp4"
|
||||
out = tmp_path / "reencoded.mp4"
|
||||
cfg = VideoEncoderConfig(vcodec="h264", g=6, crf=23, pix_fmt="yuv444p")
|
||||
reencode_video(src, out, camera_encoder=cfg, overwrite=True)
|
||||
reencode_video(src, out, video_encoder=cfg, overwrite=True)
|
||||
|
||||
assert out.exists()
|
||||
with av.open(str(out)) as container:
|
||||
n_frames = sum(1 for _ in container.decode(video=0))
|
||||
assert n_frames == 4
|
||||
|
||||
info = get_video_info(out, camera_encoder=cfg)
|
||||
info = get_video_info(out, video_encoder=cfg)
|
||||
assert info["video.codec"] == "h264"
|
||||
assert info["video.pix_fmt"] == "yuv444p"
|
||||
assert info["video.height"] == 64
|
||||
|
||||
Vendored
+13
-1
@@ -39,12 +39,24 @@ DUMMY_VIDEO_INFO = {
|
||||
"video.crf": 30,
|
||||
"video.preset": 12,
|
||||
"video.fast_decode": 0,
|
||||
"video.is_depth_map": False,
|
||||
"is_depth_map": False,
|
||||
"has_audio": False,
|
||||
}
|
||||
DUMMY_CAMERA_FEATURES = {
|
||||
"laptop": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": DUMMY_VIDEO_INFO},
|
||||
"phone": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": DUMMY_VIDEO_INFO},
|
||||
}
|
||||
DUMMY_DEPTH_VIDEO_INFO = {
|
||||
**DUMMY_VIDEO_INFO,
|
||||
"is_depth_map": True,
|
||||
}
|
||||
DUMMY_DEPTH_CAMERA_FEATURES = {
|
||||
"laptop_depth": {
|
||||
"shape": (64, 96, 1),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": DUMMY_DEPTH_VIDEO_INFO,
|
||||
},
|
||||
}
|
||||
DUMMY_CAMERA_FEATURES_WITH_DEPTH = {**DUMMY_CAMERA_FEATURES, **DUMMY_DEPTH_CAMERA_FEATURES}
|
||||
DUMMY_CHW = (3, 96, 128)
|
||||
DUMMY_HWC = (96, 128, 3)
|
||||
|
||||
Vendored
+43
-2
@@ -38,8 +38,6 @@ from lerobot.datasets.utils import (
|
||||
DEFAULT_VIDEO_PATH,
|
||||
DatasetInfo,
|
||||
)
|
||||
from lerobot.datasets.video_utils import encode_video_frames
|
||||
from lerobot.utils.constants import DEFAULT_FEATURES
|
||||
from tests.fixtures.constants import (
|
||||
DEFAULT_FPS,
|
||||
DUMMY_CAMERA_FEATURES,
|
||||
@@ -47,6 +45,44 @@ from tests.fixtures.constants import (
|
||||
DUMMY_REPO_ID,
|
||||
DUMMY_ROBOT_TYPE,
|
||||
)
|
||||
from lerobot.datasets.video_utils import encode_video_frames
|
||||
from lerobot.utils.constants import DEFAULT_FEATURES
|
||||
|
||||
|
||||
def add_frames(
|
||||
dataset: LeRobotDataset, num_frames: int
|
||||
) -> None:
|
||||
"""Append ``num_frames`` synthetic frames to ``dataset``.
|
||||
|
||||
Generates per-feature payloads from ``dataset.meta``: uint16 depth ramps for
|
||||
keys in ``dataset.meta.depth_keys``, uint8 random noise for video/image keys,
|
||||
and float32 zeros for everything else. ``DEFAULT_FEATURES`` (timestamp,
|
||||
frame_index, ...) are auto-populated by ``add_frame`` and skipped here.
|
||||
"""
|
||||
if video_keys is None:
|
||||
video_keys = dataset.meta.video_keys
|
||||
depth_keys = set(dataset.meta.depth_keys)
|
||||
# Smooth gradient base reused per (H, W) to keep depth frames cheap to
|
||||
# encode (HEVC Main 12 hates white noise).
|
||||
_depth_base_cache: dict[tuple[int, int], np.ndarray] = {}
|
||||
for i in range(num_frames):
|
||||
frame: dict = {"task": "test"}
|
||||
for key, ft in dataset.meta.features.items():
|
||||
if key in DEFAULT_FEATURES:
|
||||
continue
|
||||
shape = ft["shape"]
|
||||
if key in depth_keys:
|
||||
h, w, _ = shape
|
||||
base = _depth_base_cache.setdefault(
|
||||
(h, w),
|
||||
np.linspace(100.0, 10_000.0, h * w, dtype=np.float32).reshape(h, w, 1),
|
||||
)
|
||||
frame[key] = (base + 50.0 * i).clip(0, 65535).astype(np.uint16)
|
||||
elif key in video_keys:
|
||||
frame[key] = np.random.randint(0, 256, shape, dtype=np.uint8)
|
||||
else:
|
||||
frame[key] = np.zeros(shape, dtype=np.float32)
|
||||
dataset.add_frame(frame)
|
||||
|
||||
|
||||
class LeRobotDatasetFactory(Protocol):
|
||||
@@ -485,10 +521,14 @@ def lerobot_dataset_factory(
|
||||
hf_dataset: datasets.Dataset | None = None,
|
||||
data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
||||
camera_features: dict | None = None,
|
||||
**kwargs,
|
||||
) -> LeRobotDataset:
|
||||
# Instantiate objects
|
||||
if info is None:
|
||||
info_kwargs = {}
|
||||
if camera_features is not None:
|
||||
info_kwargs["camera_features"] = camera_features
|
||||
info = info_factory(
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
@@ -496,6 +536,7 @@ def lerobot_dataset_factory(
|
||||
use_videos=use_videos,
|
||||
data_files_size_in_mb=data_files_size_in_mb,
|
||||
chunks_size=chunks_size,
|
||||
**info_kwargs,
|
||||
)
|
||||
if stats is None:
|
||||
stats = stats_factory(features=info.features)
|
||||
|
||||
Reference in New Issue
Block a user