feat(depth): plumb DepthEncoderConfig through LeRobotDataset and DatasetWriter

This commit is contained in:
CarolinePascal
2026-05-19 22:50:19 +02:00
parent 0cc5162078
commit b4c31f0f67
5 changed files with 47 additions and 25 deletions
+7 -7
View File
@@ -53,8 +53,8 @@ def _make_frame(features: dict, task: str = "Dummy task") -> dict:
# ── Existing encode_video_worker tests ───────────────────────────────
def test_encode_video_worker_forwards_camera_encoder(tmp_path):
"""_encode_video_worker forwards camera_encoder to encode_video_frames."""
def test_encode_video_worker_forwards_video_encoder(tmp_path):
"""_encode_video_worker forwards video_encoder to encode_video_frames."""
video_key = "observation.images.laptop"
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0)
img_dir = tmp_path / Path(fpath).parent
@@ -74,16 +74,16 @@ def test_encode_video_worker_forwards_camera_encoder(tmp_path):
0,
tmp_path,
fps=30,
camera_encoder=VideoEncoderConfig(vcodec="h264", preset=None),
video_encoder=VideoEncoderConfig(vcodec="h264", preset=None),
encoder_threads=4,
)
assert captured_kwargs["camera_encoder"].vcodec == "h264"
assert captured_kwargs["video_encoder"].vcodec == "h264"
assert captured_kwargs["encoder_threads"] == 4
def test_encode_video_worker_default_camera_encoder(tmp_path):
"""_encode_video_worker passes None camera_encoder which encode_video_frames defaults."""
def test_encode_video_worker_default_video_encoder(tmp_path):
"""_encode_video_worker passes None video_encoder which encode_video_frames defaults."""
video_key = "observation.images.laptop"
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0)
img_dir = tmp_path / Path(fpath).parent
@@ -100,7 +100,7 @@ def test_encode_video_worker_default_camera_encoder(tmp_path):
with patch("lerobot.datasets.dataset_writer.encode_video_frames", side_effect=mock_encode):
_encode_video_worker(video_key, 0, tmp_path, fps=30)
assert captured_kwargs["camera_encoder"] is None
assert captured_kwargs["video_encoder"] is None
assert captured_kwargs["encoder_threads"] is None