From d777359662285feae06b5f71164ad2be01991445 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Sun, 26 Apr 2026 14:01:25 +0200 Subject: [PATCH] feat(depth): plumb DepthEncoderConfig through LeRobotDataset and DatasetWriter --- src/lerobot/datasets/dataset_writer.py | 7 ++++++ src/lerobot/datasets/lerobot_dataset.py | 10 ++++++++ tests/datasets/test_lerobot_dataset.py | 31 +++++++++++++++++++++++++ 3 files changed, 48 insertions(+) diff --git a/src/lerobot/datasets/dataset_writer.py b/src/lerobot/datasets/dataset_writer.py index 4841f7d3b..48defab85 100644 --- a/src/lerobot/datasets/dataset_writer.py +++ b/src/lerobot/datasets/dataset_writer.py @@ -51,6 +51,7 @@ from .utils import ( update_chunk_file_indices, ) from .video_utils import ( + DepthEncoderConfig, StreamingVideoEncoder, VideoEncoderConfig, concatenate_video_files, @@ -100,6 +101,7 @@ class DatasetWriter: batch_encoding_size: int, streaming_encoder: StreamingVideoEncoder | None = None, initial_frames: int = 0, + depth_encoder_config: DepthEncoderConfig | None = None, ): """Initialize the writer with metadata, codec, and encoder config. @@ -115,14 +117,19 @@ class DatasetWriter: streaming_encoder: Optional pre-built :class:`StreamingVideoEncoder` for real-time encoding. ``None`` disables streaming mode. initial_frames: Starting frame count (non-zero when resuming). + depth_encoder_config: Optional depth-map encoder config used in + place of ``camera_encoder_config`` for keys present in + ``meta.depth_keys``. """ self._meta = meta self._root = root self._camera_encoder_config = camera_encoder_config + self._depth_encoder_config = depth_encoder_config self._encoder_threads = encoder_threads self._batch_encoding_size = batch_encoding_size self._streaming_encoder = streaming_encoder + # Writer state self.image_writer: AsyncImageWriter | None = None self.episode_buffer: dict = self._create_episode_buffer() diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index ea1cfc424..ce132557c 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -35,6 +35,7 @@ from .utils import ( is_valid_version, ) from .video_utils import ( + DepthEncoderConfig, StreamingVideoEncoder, VideoEncoderConfig, get_safe_default_video_backend, @@ -59,6 +60,7 @@ class LeRobotDataset(torch.utils.data.Dataset): return_uint8: bool = False, batch_encoding_size: int = 1, camera_encoder_config: VideoEncoderConfig | None = None, + depth_encoder_config: DepthEncoderConfig | None = None, encoder_threads: int | None = None, streaming_encoding: bool = False, encoder_queue_maxsize: int = 30, @@ -207,6 +209,7 @@ class LeRobotDataset(torch.utils.data.Dataset): if camera_encoder_config is None: camera_encoder_config = VideoEncoderConfig() self._camera_encoder_config = camera_encoder_config + self._depth_encoder_config = depth_encoder_config self._encoder_threads = encoder_threads if self._requested_root is not None: @@ -261,6 +264,7 @@ class LeRobotDataset(torch.utils.data.Dataset): meta=self.meta, root=self.root, camera_encoder_config=self._camera_encoder_config, + depth_encoder_config=self._depth_encoder_config, encoder_threads=self._encoder_threads, batch_encoding_size=batch_encoding_size, streaming_encoder=streaming_enc, @@ -626,6 +630,7 @@ class LeRobotDataset(torch.utils.data.Dataset): video_backend: str | None = None, batch_encoding_size: int = 1, camera_encoder_config: VideoEncoderConfig | None = None, + depth_encoder_config: DepthEncoderConfig | None = None, metadata_buffer_size: int = 10, streaming_encoding: bool = False, encoder_queue_maxsize: int = 30, @@ -697,6 +702,7 @@ class LeRobotDataset(torch.utils.data.Dataset): obj._return_uint8 = False obj._batch_encoding_size = batch_encoding_size obj._camera_encoder_config = camera_encoder_config + obj._depth_encoder_config = depth_encoder_config obj._encoder_threads = encoder_threads # Reader is lazily created on first access (write-only mode) @@ -711,6 +717,7 @@ class LeRobotDataset(torch.utils.data.Dataset): meta=obj.meta, root=obj.root, camera_encoder_config=camera_encoder_config, + depth_encoder_config=depth_encoder_config, encoder_threads=encoder_threads, batch_encoding_size=batch_encoding_size, streaming_encoder=streaming_enc, @@ -734,6 +741,7 @@ class LeRobotDataset(torch.utils.data.Dataset): video_backend: str | None = None, batch_encoding_size: int = 1, camera_encoder_config: VideoEncoderConfig | None = None, + depth_encoder_config: DepthEncoderConfig | None = None, encoder_threads: int | None = None, image_writer_processes: int = 0, image_writer_threads: int = 0, @@ -804,6 +812,7 @@ class LeRobotDataset(torch.utils.data.Dataset): if camera_encoder_config is None: camera_encoder_config = VideoEncoderConfig() obj._camera_encoder_config = camera_encoder_config + obj._depth_encoder_config = depth_encoder_config obj._encoder_threads = encoder_threads obj.root = obj.meta.root @@ -819,6 +828,7 @@ class LeRobotDataset(torch.utils.data.Dataset): meta=obj.meta, root=obj.root, camera_encoder_config=camera_encoder_config, + depth_encoder_config=depth_encoder_config, encoder_threads=encoder_threads, batch_encoding_size=batch_encoding_size, streaming_encoder=streaming_enc, diff --git a/tests/datasets/test_lerobot_dataset.py b/tests/datasets/test_lerobot_dataset.py index f3bda037f..92348b269 100644 --- a/tests/datasets/test_lerobot_dataset.py +++ b/tests/datasets/test_lerobot_dataset.py @@ -436,6 +436,37 @@ def test_add_frame_works_in_write_mode(tmp_path): dataset.add_frame(_make_frame()) # should not raise +# ── Depth-feature plumbing ─────────────────────────────────────────── + + +_DEPTH_FEATURES = { + **SIMPLE_FEATURES, + "observation.depth": { + "dtype": "video", + "shape": (32, 32), + "names": ["height", "width"], + "info": {"video.is_depth_map": True}, + }, +} + + +def test_create_with_depth_streaming_succeeds(tmp_path): + """A depth dataset with streaming_encoding=True is created in write mode.""" + from lerobot.datasets.video_utils import DepthEncoderConfig + + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, + fps=DEFAULT_FPS, + features=_DEPTH_FEATURES, + root=tmp_path / "depth_ds", + depth_encoder_config=DepthEncoderConfig(), + streaming_encoding=True, + ) + assert isinstance(dataset.writer, DatasetWriter) + assert dataset.meta.depth_keys == ["observation.depth"] + assert dataset._depth_encoder_config is not None + + # ── Resume mode ──────────────────────────────────────────────────────