feat(datasets): expose video codec option for dataset recording (#2771)

* expose codec options + add tests

* pre-commit run -a
This commit is contained in:
Leo Tronchon
2026-01-08 18:06:39 +01:00
committed by GitHub
parent 242b65d2df
commit 8b6fc0ae05
3 changed files with 125 additions and 3 deletions
+100
View File
@@ -31,8 +31,10 @@ from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.image_writer import image_array_to_pil_image
from lerobot.datasets.lerobot_dataset import (
VALID_VIDEO_CODECS,
LeRobotDataset,
MultiLeRobotDataset,
_encode_video_worker,
)
from lerobot.datasets.utils import (
DEFAULT_CHUNK_SIZE,
@@ -1292,3 +1294,101 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
frame = loaded_dataset[idx]
expected_ep = idx // frames_per_episode
assert frame["episode_index"].item() == expected_ep
def test_encode_video_worker_forwards_vcodec(tmp_path):
"""Test that _encode_video_worker correctly forwards the vcodec parameter to encode_video_frames."""
from unittest.mock import patch
from lerobot.datasets.utils import DEFAULT_IMAGE_PATH
# Create the expected directory structure
video_key = "observation.images.laptop"
episode_index = 0
frame_index = 0
fpath = DEFAULT_IMAGE_PATH.format(
image_key=video_key, episode_index=episode_index, frame_index=frame_index
)
img_dir = tmp_path / Path(fpath).parent
img_dir.mkdir(parents=True, exist_ok=True)
# Create a dummy image file
dummy_img = Image.new("RGB", (64, 64), color="red")
dummy_img.save(img_dir / "frame-000000.png")
# Track what vcodec was passed to encode_video_frames
captured_kwargs = {}
def mock_encode_video_frames(imgs_dir, video_path, fps, **kwargs):
captured_kwargs.update(kwargs)
# Create a dummy output file so the worker doesn't fail
Path(video_path).parent.mkdir(parents=True, exist_ok=True)
Path(video_path).touch()
with patch("lerobot.datasets.lerobot_dataset.encode_video_frames", side_effect=mock_encode_video_frames):
# Test with h264 codec
_encode_video_worker(video_key, episode_index, tmp_path, fps=30, vcodec="h264")
assert "vcodec" in captured_kwargs
assert captured_kwargs["vcodec"] == "h264"
def test_encode_video_worker_default_vcodec(tmp_path):
"""Test that _encode_video_worker uses libsvtav1 as the default codec."""
from unittest.mock import patch
from lerobot.datasets.utils import DEFAULT_IMAGE_PATH
# Create the expected directory structure
video_key = "observation.images.laptop"
episode_index = 0
frame_index = 0
fpath = DEFAULT_IMAGE_PATH.format(
image_key=video_key, episode_index=episode_index, frame_index=frame_index
)
img_dir = tmp_path / Path(fpath).parent
img_dir.mkdir(parents=True, exist_ok=True)
# Create a dummy image file
dummy_img = Image.new("RGB", (64, 64), color="red")
dummy_img.save(img_dir / "frame-000000.png")
# Track what vcodec was passed to encode_video_frames
captured_kwargs = {}
def mock_encode_video_frames(imgs_dir, video_path, fps, **kwargs):
captured_kwargs.update(kwargs)
# Create a dummy output file so the worker doesn't fail
Path(video_path).parent.mkdir(parents=True, exist_ok=True)
Path(video_path).touch()
with patch("lerobot.datasets.lerobot_dataset.encode_video_frames", side_effect=mock_encode_video_frames):
# Test with default codec (no vcodec specified)
_encode_video_worker(video_key, episode_index, tmp_path, fps=30)
assert "vcodec" in captured_kwargs
assert captured_kwargs["vcodec"] == "libsvtav1"
def test_lerobot_dataset_vcodec_validation():
"""Test that LeRobotDataset validates the vcodec parameter."""
# Test that invalid vcodec raises ValueError
with pytest.raises(ValueError, match="Invalid vcodec"):
LeRobotDataset.__new__(LeRobotDataset) # bypass __init__ to test validation directly
# Actually test via create since it's easier
LeRobotDataset.create(
repo_id="test/invalid_codec",
fps=30,
features={"observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}},
vcodec="invalid_codec",
)
def test_valid_video_codecs_constant():
"""Test that VALID_VIDEO_CODECS contains the expected codecs."""
assert "h264" in VALID_VIDEO_CODECS
assert "hevc" in VALID_VIDEO_CODECS
assert "libsvtav1" in VALID_VIDEO_CODECS
assert len(VALID_VIDEO_CODECS) == 3