feat(depth): plumb DepthEncoderConfig through LeRobotDataset and DatasetWriter

This commit is contained in:
CarolinePascal
2026-04-26 14:01:25 +02:00
parent 5d0a20bd9c
commit d777359662
3 changed files with 48 additions and 0 deletions
+7
View File
@@ -51,6 +51,7 @@ from .utils import (
update_chunk_file_indices,
)
from .video_utils import (
DepthEncoderConfig,
StreamingVideoEncoder,
VideoEncoderConfig,
concatenate_video_files,
@@ -100,6 +101,7 @@ class DatasetWriter:
batch_encoding_size: int,
streaming_encoder: StreamingVideoEncoder | None = None,
initial_frames: int = 0,
depth_encoder_config: DepthEncoderConfig | None = None,
):
"""Initialize the writer with metadata, codec, and encoder config.
@@ -115,14 +117,19 @@ class DatasetWriter:
streaming_encoder: Optional pre-built :class:`StreamingVideoEncoder`
for real-time encoding. ``None`` disables streaming mode.
initial_frames: Starting frame count (non-zero when resuming).
depth_encoder_config: Optional depth-map encoder config used in
place of ``camera_encoder_config`` for keys present in
``meta.depth_keys``.
"""
self._meta = meta
self._root = root
self._camera_encoder_config = camera_encoder_config
self._depth_encoder_config = depth_encoder_config
self._encoder_threads = encoder_threads
self._batch_encoding_size = batch_encoding_size
self._streaming_encoder = streaming_encoder
# Writer state
self.image_writer: AsyncImageWriter | None = None
self.episode_buffer: dict = self._create_episode_buffer()
+10
View File
@@ -35,6 +35,7 @@ from .utils import (
is_valid_version,
)
from .video_utils import (
DepthEncoderConfig,
StreamingVideoEncoder,
VideoEncoderConfig,
get_safe_default_video_backend,
@@ -59,6 +60,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
return_uint8: bool = False,
batch_encoding_size: int = 1,
camera_encoder_config: VideoEncoderConfig | None = None,
depth_encoder_config: DepthEncoderConfig | None = None,
encoder_threads: int | None = None,
streaming_encoding: bool = False,
encoder_queue_maxsize: int = 30,
@@ -207,6 +209,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
if camera_encoder_config is None:
camera_encoder_config = VideoEncoderConfig()
self._camera_encoder_config = camera_encoder_config
self._depth_encoder_config = depth_encoder_config
self._encoder_threads = encoder_threads
if self._requested_root is not None:
@@ -261,6 +264,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
meta=self.meta,
root=self.root,
camera_encoder_config=self._camera_encoder_config,
depth_encoder_config=self._depth_encoder_config,
encoder_threads=self._encoder_threads,
batch_encoding_size=batch_encoding_size,
streaming_encoder=streaming_enc,
@@ -626,6 +630,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend: str | None = None,
batch_encoding_size: int = 1,
camera_encoder_config: VideoEncoderConfig | None = None,
depth_encoder_config: DepthEncoderConfig | None = None,
metadata_buffer_size: int = 10,
streaming_encoding: bool = False,
encoder_queue_maxsize: int = 30,
@@ -697,6 +702,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj._return_uint8 = False
obj._batch_encoding_size = batch_encoding_size
obj._camera_encoder_config = camera_encoder_config
obj._depth_encoder_config = depth_encoder_config
obj._encoder_threads = encoder_threads
# Reader is lazily created on first access (write-only mode)
@@ -711,6 +717,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
meta=obj.meta,
root=obj.root,
camera_encoder_config=camera_encoder_config,
depth_encoder_config=depth_encoder_config,
encoder_threads=encoder_threads,
batch_encoding_size=batch_encoding_size,
streaming_encoder=streaming_enc,
@@ -734,6 +741,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend: str | None = None,
batch_encoding_size: int = 1,
camera_encoder_config: VideoEncoderConfig | None = None,
depth_encoder_config: DepthEncoderConfig | None = None,
encoder_threads: int | None = None,
image_writer_processes: int = 0,
image_writer_threads: int = 0,
@@ -804,6 +812,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
if camera_encoder_config is None:
camera_encoder_config = VideoEncoderConfig()
obj._camera_encoder_config = camera_encoder_config
obj._depth_encoder_config = depth_encoder_config
obj._encoder_threads = encoder_threads
obj.root = obj.meta.root
@@ -819,6 +828,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
meta=obj.meta,
root=obj.root,
camera_encoder_config=camera_encoder_config,
depth_encoder_config=depth_encoder_config,
encoder_threads=encoder_threads,
batch_encoding_size=batch_encoding_size,
streaming_encoder=streaming_enc,
+31
View File
@@ -436,6 +436,37 @@ def test_add_frame_works_in_write_mode(tmp_path):
dataset.add_frame(_make_frame()) # should not raise
# ── Depth-feature plumbing ───────────────────────────────────────────
# Feature spec for the depth tests: every entry of SIMPLE_FEATURES plus one
# video feature whose ``info`` marks it as a depth map.
_DEPTH_FEATURES = dict(SIMPLE_FEATURES)
_DEPTH_FEATURES["observation.depth"] = {
    "dtype": "video",
    "shape": (32, 32),
    "names": ["height", "width"],
    "info": {"video.is_depth_map": True},
}
def test_create_with_depth_streaming_succeeds(tmp_path):
    """A depth dataset with streaming_encoding=True is created in write mode.

    Beyond checking that creation succeeds, this verifies the plumbing itself:
    the ``DepthEncoderConfig`` passed to ``LeRobotDataset.create`` must be the
    very object stored on the dataset and forwarded to its ``DatasetWriter``.
    """
    from lerobot.datasets.video_utils import DepthEncoderConfig

    # Keep a handle on the config so identity (not just non-None) can be checked.
    depth_config = DepthEncoderConfig()
    dataset = LeRobotDataset.create(
        repo_id=DUMMY_REPO_ID,
        fps=DEFAULT_FPS,
        features=_DEPTH_FEATURES,
        root=tmp_path / "depth_ds",
        depth_encoder_config=depth_config,
        streaming_encoding=True,
    )

    assert isinstance(dataset.writer, DatasetWriter)
    assert dataset.meta.depth_keys == ["observation.depth"]
    # The dataset keeps the exact config instance it was given ...
    assert dataset._depth_encoder_config is depth_config
    # ... and the writer receives that same instance.
    assert dataset.writer._depth_encoder_config is depth_config
# ── Resume mode ──────────────────────────────────────────────────────