mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
feat(datasets): expose video codec option for dataset recording (#2771)
* expose codec options + add tests * pre-commit run -a
This commit is contained in:
@@ -31,8 +31,10 @@ from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.image_writer import image_array_to_pil_image
|
||||
from lerobot.datasets.lerobot_dataset import (
|
||||
VALID_VIDEO_CODECS,
|
||||
LeRobotDataset,
|
||||
MultiLeRobotDataset,
|
||||
_encode_video_worker,
|
||||
)
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
@@ -1292,3 +1294,101 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
|
||||
frame = loaded_dataset[idx]
|
||||
expected_ep = idx // frames_per_episode
|
||||
assert frame["episode_index"].item() == expected_ep
|
||||
|
||||
|
||||
def test_encode_video_worker_forwards_vcodec(tmp_path):
|
||||
"""Test that _encode_video_worker correctly forwards the vcodec parameter to encode_video_frames."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from lerobot.datasets.utils import DEFAULT_IMAGE_PATH
|
||||
|
||||
# Create the expected directory structure
|
||||
video_key = "observation.images.laptop"
|
||||
episode_index = 0
|
||||
frame_index = 0
|
||||
|
||||
fpath = DEFAULT_IMAGE_PATH.format(
|
||||
image_key=video_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
img_dir = tmp_path / Path(fpath).parent
|
||||
img_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create a dummy image file
|
||||
dummy_img = Image.new("RGB", (64, 64), color="red")
|
||||
dummy_img.save(img_dir / "frame-000000.png")
|
||||
|
||||
# Track what vcodec was passed to encode_video_frames
|
||||
captured_kwargs = {}
|
||||
|
||||
def mock_encode_video_frames(imgs_dir, video_path, fps, **kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
# Create a dummy output file so the worker doesn't fail
|
||||
Path(video_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
Path(video_path).touch()
|
||||
|
||||
with patch("lerobot.datasets.lerobot_dataset.encode_video_frames", side_effect=mock_encode_video_frames):
|
||||
# Test with h264 codec
|
||||
_encode_video_worker(video_key, episode_index, tmp_path, fps=30, vcodec="h264")
|
||||
|
||||
assert "vcodec" in captured_kwargs
|
||||
assert captured_kwargs["vcodec"] == "h264"
|
||||
|
||||
|
||||
def test_encode_video_worker_default_vcodec(tmp_path):
|
||||
"""Test that _encode_video_worker uses libsvtav1 as the default codec."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from lerobot.datasets.utils import DEFAULT_IMAGE_PATH
|
||||
|
||||
# Create the expected directory structure
|
||||
video_key = "observation.images.laptop"
|
||||
episode_index = 0
|
||||
frame_index = 0
|
||||
|
||||
fpath = DEFAULT_IMAGE_PATH.format(
|
||||
image_key=video_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
img_dir = tmp_path / Path(fpath).parent
|
||||
img_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create a dummy image file
|
||||
dummy_img = Image.new("RGB", (64, 64), color="red")
|
||||
dummy_img.save(img_dir / "frame-000000.png")
|
||||
|
||||
# Track what vcodec was passed to encode_video_frames
|
||||
captured_kwargs = {}
|
||||
|
||||
def mock_encode_video_frames(imgs_dir, video_path, fps, **kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
# Create a dummy output file so the worker doesn't fail
|
||||
Path(video_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
Path(video_path).touch()
|
||||
|
||||
with patch("lerobot.datasets.lerobot_dataset.encode_video_frames", side_effect=mock_encode_video_frames):
|
||||
# Test with default codec (no vcodec specified)
|
||||
_encode_video_worker(video_key, episode_index, tmp_path, fps=30)
|
||||
|
||||
assert "vcodec" in captured_kwargs
|
||||
assert captured_kwargs["vcodec"] == "libsvtav1"
|
||||
|
||||
|
||||
def test_lerobot_dataset_vcodec_validation():
|
||||
"""Test that LeRobotDataset validates the vcodec parameter."""
|
||||
# Test that invalid vcodec raises ValueError
|
||||
with pytest.raises(ValueError, match="Invalid vcodec"):
|
||||
LeRobotDataset.__new__(LeRobotDataset) # bypass __init__ to test validation directly
|
||||
# Actually test via create since it's easier
|
||||
LeRobotDataset.create(
|
||||
repo_id="test/invalid_codec",
|
||||
fps=30,
|
||||
features={"observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}},
|
||||
vcodec="invalid_codec",
|
||||
)
|
||||
|
||||
|
||||
def test_valid_video_codecs_constant():
|
||||
"""Test that VALID_VIDEO_CODECS contains the expected codecs."""
|
||||
assert "h264" in VALID_VIDEO_CODECS
|
||||
assert "hevc" in VALID_VIDEO_CODECS
|
||||
assert "libsvtav1" in VALID_VIDEO_CODECS
|
||||
assert len(VALID_VIDEO_CODECS) == 3
|
||||
|
||||
Reference in New Issue
Block a user