mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
feat(depth): plumb DepthEncoderConfig through LeRobotDataset and DatasetWriter
This commit is contained in:
@@ -51,6 +51,7 @@ from .utils import (
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
from .video_utils import (
|
||||
DepthEncoderConfig,
|
||||
StreamingVideoEncoder,
|
||||
VideoEncoderConfig,
|
||||
concatenate_video_files,
|
||||
@@ -100,6 +101,7 @@ class DatasetWriter:
|
||||
batch_encoding_size: int,
|
||||
streaming_encoder: StreamingVideoEncoder | None = None,
|
||||
initial_frames: int = 0,
|
||||
depth_encoder_config: DepthEncoderConfig | None = None,
|
||||
):
|
||||
"""Initialize the writer with metadata, codec, and encoder config.
|
||||
|
||||
@@ -115,14 +117,19 @@ class DatasetWriter:
|
||||
streaming_encoder: Optional pre-built :class:`StreamingVideoEncoder`
|
||||
for real-time encoding. ``None`` disables streaming mode.
|
||||
initial_frames: Starting frame count (non-zero when resuming).
|
||||
depth_encoder_config: Optional depth-map encoder config used in
|
||||
place of ``camera_encoder_config`` for keys present in
|
||||
``meta.depth_keys``.
|
||||
"""
|
||||
self._meta = meta
|
||||
self._root = root
|
||||
self._camera_encoder_config = camera_encoder_config
|
||||
self._depth_encoder_config = depth_encoder_config
|
||||
self._encoder_threads = encoder_threads
|
||||
self._batch_encoding_size = batch_encoding_size
|
||||
self._streaming_encoder = streaming_encoder
|
||||
|
||||
|
||||
# Writer state
|
||||
self.image_writer: AsyncImageWriter | None = None
|
||||
self.episode_buffer: dict = self._create_episode_buffer()
|
||||
|
||||
@@ -35,6 +35,7 @@ from .utils import (
|
||||
is_valid_version,
|
||||
)
|
||||
from .video_utils import (
|
||||
DepthEncoderConfig,
|
||||
StreamingVideoEncoder,
|
||||
VideoEncoderConfig,
|
||||
get_safe_default_video_backend,
|
||||
@@ -59,6 +60,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return_uint8: bool = False,
|
||||
batch_encoding_size: int = 1,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
depth_encoder_config: DepthEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
@@ -207,6 +209,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if camera_encoder_config is None:
|
||||
camera_encoder_config = VideoEncoderConfig()
|
||||
self._camera_encoder_config = camera_encoder_config
|
||||
self._depth_encoder_config = depth_encoder_config
|
||||
self._encoder_threads = encoder_threads
|
||||
|
||||
if self._requested_root is not None:
|
||||
@@ -261,6 +264,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
meta=self.meta,
|
||||
root=self.root,
|
||||
camera_encoder_config=self._camera_encoder_config,
|
||||
depth_encoder_config=self._depth_encoder_config,
|
||||
encoder_threads=self._encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
@@ -626,6 +630,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
depth_encoder_config: DepthEncoderConfig | None = None,
|
||||
metadata_buffer_size: int = 10,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
@@ -697,6 +702,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj._return_uint8 = False
|
||||
obj._batch_encoding_size = batch_encoding_size
|
||||
obj._camera_encoder_config = camera_encoder_config
|
||||
obj._depth_encoder_config = depth_encoder_config
|
||||
obj._encoder_threads = encoder_threads
|
||||
|
||||
# Reader is lazily created on first access (write-only mode)
|
||||
@@ -711,6 +717,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
meta=obj.meta,
|
||||
root=obj.root,
|
||||
camera_encoder_config=camera_encoder_config,
|
||||
depth_encoder_config=depth_encoder_config,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
@@ -734,6 +741,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
depth_encoder_config: DepthEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
image_writer_processes: int = 0,
|
||||
image_writer_threads: int = 0,
|
||||
@@ -804,6 +812,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if camera_encoder_config is None:
|
||||
camera_encoder_config = VideoEncoderConfig()
|
||||
obj._camera_encoder_config = camera_encoder_config
|
||||
obj._depth_encoder_config = depth_encoder_config
|
||||
obj._encoder_threads = encoder_threads
|
||||
obj.root = obj.meta.root
|
||||
|
||||
@@ -819,6 +828,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
meta=obj.meta,
|
||||
root=obj.root,
|
||||
camera_encoder_config=camera_encoder_config,
|
||||
depth_encoder_config=depth_encoder_config,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
|
||||
@@ -436,6 +436,37 @@ def test_add_frame_works_in_write_mode(tmp_path):
|
||||
dataset.add_frame(_make_frame()) # should not raise
|
||||
|
||||
|
||||
# ── Depth-feature plumbing ───────────────────────────────────────────
|
||||
|
||||
|
||||
_DEPTH_FEATURES = {
|
||||
**SIMPLE_FEATURES,
|
||||
"observation.depth": {
|
||||
"dtype": "video",
|
||||
"shape": (32, 32),
|
||||
"names": ["height", "width"],
|
||||
"info": {"video.is_depth_map": True},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_create_with_depth_streaming_succeeds(tmp_path):
|
||||
"""A depth dataset with streaming_encoding=True is created in write mode."""
|
||||
from lerobot.datasets.video_utils import DepthEncoderConfig
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=DUMMY_REPO_ID,
|
||||
fps=DEFAULT_FPS,
|
||||
features=_DEPTH_FEATURES,
|
||||
root=tmp_path / "depth_ds",
|
||||
depth_encoder_config=DepthEncoderConfig(),
|
||||
streaming_encoding=True,
|
||||
)
|
||||
assert isinstance(dataset.writer, DatasetWriter)
|
||||
assert dataset.meta.depth_keys == ["observation.depth"]
|
||||
assert dataset._depth_encoder_config is not None
|
||||
|
||||
|
||||
# ── Resume mode ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user