test(reencode dataset): adding missing test for reencode dataset

This commit is contained in:
CarolinePascal
2026-05-17 23:23:10 +02:00
parent 5547757cea
commit 6e01006d94
2 changed files with 57 additions and 9 deletions
+42
View File
@@ -23,6 +23,7 @@ import torch
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from lerobot.configs import VideoEncoderConfig
from lerobot.datasets.dataset_tools import (
add_features,
@@ -31,9 +32,12 @@ from lerobot.datasets.dataset_tools import (
merge_datasets,
modify_features,
modify_tasks,
reencode_dataset,
remove_feature,
split_dataset,
)
from lerobot.datasets.io_utils import load_info
from tests.datasets.test_video_encoding import _add_frames, require_h264, require_libsvtav1
@pytest.fixture
@@ -1326,3 +1330,41 @@ def test_convert_image_to_video_dataset_subset_episodes(tmp_path):
if output_dir.exists():
shutil.rmtree(output_dir)
# ─── reencode_dataset ─────────────────────────────────────────────────
@require_libsvtav1
@require_h264
def test_reencode_dataset_multi_key_multiprocessing(
tmp_path, empty_lerobot_dataset_factory, features_factory
):
"""Re-encode a two-camera dataset with num_workers=2 and verify metadata refresh."""
features = features_factory(use_videos=True)
initial_cfg = VideoEncoderConfig(vcodec="libsvtav1", g=2, crf=30, preset=12)
dataset = empty_lerobot_dataset_factory(
root=tmp_path / "ds",
features=features,
use_videos=True,
camera_encoder=initial_cfg,
)
_add_frames(dataset, num_frames=4)
dataset.save_episode()
_add_frames(dataset, num_frames=4)
dataset.save_episode()
dataset.finalize()
assert len(dataset.meta.video_keys) == 2
target_cfg = VideoEncoderConfig(vcodec="h264", g=6, crf=23, pix_fmt="yuv420p")
result = reencode_dataset(dataset, camera_encoder=target_cfg, num_workers=2)
assert result is dataset
persisted_info = load_info(dataset.root)
for vk in dataset.meta.video_keys:
persisted_encoder = VideoEncoderConfig.from_video_info(persisted_info.features[vk].get("info", {}))
assert persisted_encoder == target_cfg
+15 -9
View File
@@ -348,16 +348,22 @@ def _read_feature_info(dataset: LeRobotDataset) -> dict:
return info["features"][VIDEO_KEY]["info"]
def _add_frames(dataset: LeRobotDataset, num_frames: int) -> None:
shape = dataset.meta.features[VIDEO_KEY]["shape"]
def _add_frames(dataset: LeRobotDataset, num_frames: int, video_keys: list[str] | None = None) -> None:
from lerobot.utils.constants import DEFAULT_FEATURES
if video_keys is None:
video_keys = dataset.meta.video_keys
for _ in range(num_frames):
dataset.add_frame(
{
VIDEO_KEY: np.random.randint(0, 256, shape, dtype=np.uint8),
"action": np.zeros(2, dtype=np.float32),
"task": "test",
}
)
frame: dict = {"task": "test"}
for key, ft in dataset.meta.features.items():
if key in DEFAULT_FEATURES:
continue
shape = ft["shape"]
if key in video_keys:
frame[key] = np.random.randint(0, 256, shape, dtype=np.uint8)
else:
frame[key] = np.zeros(shape, dtype=np.float32)
dataset.add_frame(frame)
class TestGetVideoInfo: