mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 13:40:00 +00:00
test(existing): adapting existing tests
This commit is contained in:
@@ -25,6 +25,7 @@ pytest.importorskip("datasets", reason="datasets is required (install lerobot[da
|
|||||||
|
|
||||||
from lerobot.datasets.dataset_tools import (
|
from lerobot.datasets.dataset_tools import (
|
||||||
add_features,
|
add_features,
|
||||||
|
convert_image_to_video_dataset,
|
||||||
delete_episodes,
|
delete_episodes,
|
||||||
merge_datasets,
|
merge_datasets,
|
||||||
modify_features,
|
modify_features,
|
||||||
@@ -32,7 +33,7 @@ from lerobot.datasets.dataset_tools import (
|
|||||||
remove_feature,
|
remove_feature,
|
||||||
split_dataset,
|
split_dataset,
|
||||||
)
|
)
|
||||||
from lerobot.scripts.lerobot_edit_dataset import convert_image_to_video_dataset
|
from lerobot.datasets.video_utils import VideoEncoderConfig
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -1246,10 +1247,12 @@ def test_convert_image_to_video_dataset(tmp_path):
|
|||||||
dataset=source_dataset,
|
dataset=source_dataset,
|
||||||
output_dir=output_dir,
|
output_dir=output_dir,
|
||||||
repo_id="lerobot/pusht_video",
|
repo_id="lerobot/pusht_video",
|
||||||
vcodec="libsvtav1",
|
camera_encoder_config=VideoEncoderConfig(
|
||||||
pix_fmt="yuv420p",
|
vcodec="libsvtav1",
|
||||||
g=2,
|
pix_fmt="yuv420p",
|
||||||
crf=30,
|
g=2,
|
||||||
|
crf=30,
|
||||||
|
),
|
||||||
episode_indices=[0, 1],
|
episode_indices=[0, 1],
|
||||||
num_workers=2,
|
num_workers=2,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ pytest.importorskip("datasets", reason="datasets is required (install lerobot[da
|
|||||||
from lerobot.datasets.dataset_writer import _encode_video_worker
|
from lerobot.datasets.dataset_writer import _encode_video_worker
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.datasets.utils import DEFAULT_IMAGE_PATH
|
from lerobot.datasets.utils import DEFAULT_IMAGE_PATH
|
||||||
|
from lerobot.datasets.video_utils import VideoEncoderConfig
|
||||||
from tests.fixtures.constants import DEFAULT_FPS, DUMMY_REPO_ID
|
from tests.fixtures.constants import DEFAULT_FPS, DUMMY_REPO_ID
|
||||||
|
|
||||||
SIMPLE_FEATURES = {
|
SIMPLE_FEATURES = {
|
||||||
@@ -52,8 +53,8 @@ def _make_frame(features: dict, task: str = "Dummy task") -> dict:
|
|||||||
# ── Existing encode_video_worker tests ───────────────────────────────
|
# ── Existing encode_video_worker tests ───────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def test_encode_video_worker_forwards_vcodec(tmp_path):
|
def test_encode_video_worker_forwards_camera_encoder_config(tmp_path):
|
||||||
"""_encode_video_worker correctly forwards the vcodec parameter."""
|
"""_encode_video_worker forwards camera_encoder_config to encode_video_frames."""
|
||||||
video_key = "observation.images.laptop"
|
video_key = "observation.images.laptop"
|
||||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0)
|
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0)
|
||||||
img_dir = tmp_path / Path(fpath).parent
|
img_dir = tmp_path / Path(fpath).parent
|
||||||
@@ -68,13 +69,21 @@ def test_encode_video_worker_forwards_vcodec(tmp_path):
|
|||||||
Path(video_path).touch()
|
Path(video_path).touch()
|
||||||
|
|
||||||
with patch("lerobot.datasets.dataset_writer.encode_video_frames", side_effect=mock_encode):
|
with patch("lerobot.datasets.dataset_writer.encode_video_frames", side_effect=mock_encode):
|
||||||
_encode_video_worker(video_key, 0, tmp_path, fps=30, vcodec="h264")
|
_encode_video_worker(
|
||||||
|
video_key,
|
||||||
|
0,
|
||||||
|
tmp_path,
|
||||||
|
fps=30,
|
||||||
|
camera_encoder_config=VideoEncoderConfig(vcodec="h264", preset=None),
|
||||||
|
encoder_threads=4,
|
||||||
|
)
|
||||||
|
|
||||||
assert captured_kwargs["vcodec"] == "h264"
|
assert captured_kwargs["camera_encoder_config"].vcodec == "h264"
|
||||||
|
assert captured_kwargs["encoder_threads"] == 4
|
||||||
|
|
||||||
|
|
||||||
def test_encode_video_worker_default_vcodec(tmp_path):
|
def test_encode_video_worker_default_camera_encoder_config(tmp_path):
|
||||||
"""_encode_video_worker uses libsvtav1 as the default codec."""
|
"""_encode_video_worker passes None camera_encoder_config which encode_video_frames defaults."""
|
||||||
video_key = "observation.images.laptop"
|
video_key = "observation.images.laptop"
|
||||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0)
|
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0)
|
||||||
img_dir = tmp_path / Path(fpath).parent
|
img_dir = tmp_path / Path(fpath).parent
|
||||||
@@ -91,7 +100,8 @@ def test_encode_video_worker_default_vcodec(tmp_path):
|
|||||||
with patch("lerobot.datasets.dataset_writer.encode_video_frames", side_effect=mock_encode):
|
with patch("lerobot.datasets.dataset_writer.encode_video_frames", side_effect=mock_encode):
|
||||||
_encode_video_worker(video_key, 0, tmp_path, fps=30)
|
_encode_video_worker(video_key, 0, tmp_path, fps=30)
|
||||||
|
|
||||||
assert captured_kwargs["vcodec"] == "libsvtav1"
|
assert captured_kwargs["camera_encoder_config"] is None
|
||||||
|
assert captured_kwargs["encoder_threads"] is None
|
||||||
|
|
||||||
|
|
||||||
# ── add_frame contracts ──────────────────────────────────────────────
|
# ── add_frame contracts ──────────────────────────────────────────────
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ from lerobot.datasets.utils import (
|
|||||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||||
create_branch,
|
create_branch,
|
||||||
)
|
)
|
||||||
from lerobot.datasets.video_utils import VALID_VIDEO_CODECS
|
from lerobot.datasets.video_utils import VALID_VIDEO_CODECS, VideoEncoderConfig
|
||||||
from lerobot.envs.factory import make_env_config
|
from lerobot.envs.factory import make_env_config
|
||||||
from lerobot.policies.factory import make_policy_config
|
from lerobot.policies.factory import make_policy_config
|
||||||
from lerobot.robots import make_robot_from_config
|
from lerobot.robots import make_robot_from_config
|
||||||
@@ -1470,17 +1470,9 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
|
|||||||
|
|
||||||
|
|
||||||
def test_lerobot_dataset_vcodec_validation():
|
def test_lerobot_dataset_vcodec_validation():
|
||||||
"""Test that LeRobotDataset validates the vcodec parameter."""
|
"""Invalid vcodec in encoder config is rejected at construction time."""
|
||||||
# Test that invalid vcodec raises ValueError
|
|
||||||
with pytest.raises(ValueError, match="Invalid vcodec"):
|
with pytest.raises(ValueError, match="Invalid vcodec"):
|
||||||
LeRobotDataset.__new__(LeRobotDataset) # bypass __init__ to test validation directly
|
VideoEncoderConfig(vcodec="invalid_codec")
|
||||||
# Actually test via create since it's easier
|
|
||||||
LeRobotDataset.create(
|
|
||||||
repo_id="test/invalid_codec",
|
|
||||||
fps=30,
|
|
||||||
features={"observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}},
|
|
||||||
vcodec="invalid_codec",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_valid_video_codecs_constant():
|
def test_valid_video_codecs_constant():
|
||||||
|
|||||||
Reference in New Issue
Block a user