test(existing): adapting existing tests

This commit is contained in:
CarolinePascal
2026-04-24 17:15:24 +02:00
parent 7f624adcc5
commit 57a619ab02
3 changed files with 28 additions and 23 deletions
+8 -5
View File
@@ -25,6 +25,7 @@ pytest.importorskip("datasets", reason="datasets is required (install lerobot[da
from lerobot.datasets.dataset_tools import (
add_features,
convert_image_to_video_dataset,
delete_episodes,
merge_datasets,
modify_features,
@@ -32,7 +33,7 @@ from lerobot.datasets.dataset_tools import (
remove_feature,
split_dataset,
)
from lerobot.scripts.lerobot_edit_dataset import convert_image_to_video_dataset
from lerobot.datasets.video_utils import VideoEncoderConfig
@pytest.fixture
@@ -1246,10 +1247,12 @@ def test_convert_image_to_video_dataset(tmp_path):
dataset=source_dataset,
output_dir=output_dir,
repo_id="lerobot/pusht_video",
vcodec="libsvtav1",
pix_fmt="yuv420p",
g=2,
crf=30,
camera_encoder_config=VideoEncoderConfig(
vcodec="libsvtav1",
pix_fmt="yuv420p",
g=2,
crf=30,
),
episode_indices=[0, 1],
num_workers=2,
)
+17 -7
View File
@@ -28,6 +28,7 @@ pytest.importorskip("datasets", reason="datasets is required (install lerobot[da
from lerobot.datasets.dataset_writer import _encode_video_worker
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import DEFAULT_IMAGE_PATH
from lerobot.datasets.video_utils import VideoEncoderConfig
from tests.fixtures.constants import DEFAULT_FPS, DUMMY_REPO_ID
SIMPLE_FEATURES = {
@@ -52,8 +53,8 @@ def _make_frame(features: dict, task: str = "Dummy task") -> dict:
# ── Existing encode_video_worker tests ───────────────────────────────
def test_encode_video_worker_forwards_vcodec(tmp_path):
"""_encode_video_worker correctly forwards the vcodec parameter."""
def test_encode_video_worker_forwards_camera_encoder_config(tmp_path):
"""_encode_video_worker forwards camera_encoder_config to encode_video_frames."""
video_key = "observation.images.laptop"
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0)
img_dir = tmp_path / Path(fpath).parent
@@ -68,13 +69,21 @@ def test_encode_video_worker_forwards_vcodec(tmp_path):
Path(video_path).touch()
with patch("lerobot.datasets.dataset_writer.encode_video_frames", side_effect=mock_encode):
_encode_video_worker(video_key, 0, tmp_path, fps=30, vcodec="h264")
_encode_video_worker(
video_key,
0,
tmp_path,
fps=30,
camera_encoder_config=VideoEncoderConfig(vcodec="h264", preset=None),
encoder_threads=4,
)
assert captured_kwargs["vcodec"] == "h264"
assert captured_kwargs["camera_encoder_config"].vcodec == "h264"
assert captured_kwargs["encoder_threads"] == 4
def test_encode_video_worker_default_vcodec(tmp_path):
"""_encode_video_worker uses libsvtav1 as the default codec."""
def test_encode_video_worker_default_camera_encoder_config(tmp_path):
"""_encode_video_worker passes None camera_encoder_config which encode_video_frames defaults."""
video_key = "observation.images.laptop"
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0)
img_dir = tmp_path / Path(fpath).parent
@@ -91,7 +100,8 @@ def test_encode_video_worker_default_vcodec(tmp_path):
with patch("lerobot.datasets.dataset_writer.encode_video_frames", side_effect=mock_encode):
_encode_video_worker(video_key, 0, tmp_path, fps=30)
assert captured_kwargs["vcodec"] == "libsvtav1"
assert captured_kwargs["camera_encoder_config"] is None
assert captured_kwargs["encoder_threads"] is None
# ── add_frame contracts ──────────────────────────────────────────────
+3 -11
View File
@@ -43,7 +43,7 @@ from lerobot.datasets.utils import (
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
create_branch,
)
from lerobot.datasets.video_utils import VALID_VIDEO_CODECS
from lerobot.datasets.video_utils import VALID_VIDEO_CODECS, VideoEncoderConfig
from lerobot.envs.factory import make_env_config
from lerobot.policies.factory import make_policy_config
from lerobot.robots import make_robot_from_config
@@ -1470,17 +1470,9 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
def test_lerobot_dataset_vcodec_validation():
"""Test that LeRobotDataset validates the vcodec parameter."""
# Test that invalid vcodec raises ValueError
"""Invalid vcodec in encoder config is rejected at construction time."""
with pytest.raises(ValueError, match="Invalid vcodec"):
LeRobotDataset.__new__(LeRobotDataset) # bypass __init__ to test validation directly
# Actually test via create since it's easier
LeRobotDataset.create(
repo_id="test/invalid_codec",
fps=30,
features={"observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}},
vcodec="invalid_codec",
)
VideoEncoderConfig(vcodec="invalid_codec")
def test_valid_video_codecs_constant():