mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
feat(datasets): expose video codec option for dataset recording (#2771)
* expose codec options + add tests * pre-commit run -a
This commit is contained in:
@@ -78,6 +78,7 @@ from lerobot.datasets.video_utils import (
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
CODEBASE_VERSION = "v3.0"
|
||||
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1"}
|
||||
|
||||
|
||||
class LeRobotDatasetMetadata:
|
||||
@@ -540,11 +541,13 @@ class LeRobotDatasetMetadata:
|
||||
return obj
|
||||
|
||||
|
||||
def _encode_video_worker(video_key: str, episode_index: int, root: Path, fps: int) -> Path:
|
||||
def _encode_video_worker(
|
||||
video_key: str, episode_index: int, root: Path, fps: int, vcodec: str = "libsvtav1"
|
||||
) -> Path:
|
||||
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
|
||||
img_dir = (root / fpath).parent
|
||||
encode_video_frames(img_dir, temp_path, fps, overwrite=True)
|
||||
encode_video_frames(img_dir, temp_path, fps, vcodec=vcodec, overwrite=True)
|
||||
shutil.rmtree(img_dir)
|
||||
return temp_path
|
||||
|
||||
@@ -563,6 +566,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
download_videos: bool = True,
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
):
|
||||
"""
|
||||
2 modes are available for instantiating this class, depending on 2 different use cases:
|
||||
@@ -675,8 +679,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
||||
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
|
||||
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
|
||||
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
|
||||
'libsvtav1'. Defaults to 'libsvtav1'. Use 'h264' for faster encoding on systems where AV1
|
||||
encoding is CPU-heavy.
|
||||
"""
|
||||
super().__init__()
|
||||
if vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
self.repo_id = repo_id
|
||||
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
|
||||
self.image_transforms = image_transforms
|
||||
@@ -688,6 +697,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.delta_indices = None
|
||||
self.batch_encoding_size = batch_encoding_size
|
||||
self.episodes_since_last_encoding = 0
|
||||
self.vcodec = vcodec
|
||||
|
||||
# Unused attributes
|
||||
self.image_writer = None
|
||||
@@ -1211,6 +1221,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
episode_index,
|
||||
self.root,
|
||||
self.fps,
|
||||
self.vcodec,
|
||||
): video_key
|
||||
for video_key in self.meta.video_keys
|
||||
}
|
||||
@@ -1526,7 +1537,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
since video encoding with ffmpeg is already using multithreading.
|
||||
"""
|
||||
return _encode_video_worker(video_key, episode_index, self.root, self.fps)
|
||||
return _encode_video_worker(video_key, episode_index, self.root, self.fps, self.vcodec)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@@ -1542,8 +1553,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
image_writer_threads: int = 0,
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
) -> "LeRobotDataset":
|
||||
"""Create a LeRobot Dataset from scratch in order to record data."""
|
||||
if vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
obj = cls.__new__(cls)
|
||||
obj.meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
@@ -1560,6 +1574,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.image_writer = None
|
||||
obj.batch_encoding_size = batch_encoding_size
|
||||
obj.episodes_since_last_encoding = 0
|
||||
obj.vcodec = vcodec
|
||||
|
||||
if image_writer_processes or image_writer_threads:
|
||||
obj.start_image_writer(image_writer_processes, image_writer_threads)
|
||||
|
||||
@@ -27,6 +27,8 @@ lerobot-record \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.single_task="Grab the cube" \
|
||||
--display_data=true
|
||||
# <- Optional: specify video codec (h264, hevc, libsvtav1). Default is libsvtav1. \
|
||||
# --dataset.vcodec=h264 \
|
||||
# <- Teleop optional if you want to teleoperate to record or in between episodes with a policy \
|
||||
# --teleop.type=so100_leader \
|
||||
# --teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
@@ -165,6 +167,9 @@ class DatasetRecordConfig:
|
||||
# Number of episodes to record before batch encoding videos
|
||||
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
|
||||
video_encoding_batch_size: int = 1
|
||||
# Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1'.
|
||||
# Use 'h264' for faster encoding on systems where AV1 encoding is CPU-heavy.
|
||||
vcodec: str = "libsvtav1"
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
@@ -427,6 +432,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
vcodec=cfg.dataset.vcodec,
|
||||
)
|
||||
|
||||
if hasattr(robot, "cameras") and len(robot.cameras) > 0:
|
||||
@@ -448,6 +454,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
vcodec=cfg.dataset.vcodec,
|
||||
)
|
||||
|
||||
# Load pretrained policy
|
||||
|
||||
@@ -31,8 +31,10 @@ from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.image_writer import image_array_to_pil_image
|
||||
from lerobot.datasets.lerobot_dataset import (
|
||||
VALID_VIDEO_CODECS,
|
||||
LeRobotDataset,
|
||||
MultiLeRobotDataset,
|
||||
_encode_video_worker,
|
||||
)
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
@@ -1292,3 +1294,101 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
|
||||
frame = loaded_dataset[idx]
|
||||
expected_ep = idx // frames_per_episode
|
||||
assert frame["episode_index"].item() == expected_ep
|
||||
|
||||
|
||||
def test_encode_video_worker_forwards_vcodec(tmp_path):
|
||||
"""Test that _encode_video_worker correctly forwards the vcodec parameter to encode_video_frames."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from lerobot.datasets.utils import DEFAULT_IMAGE_PATH
|
||||
|
||||
# Create the expected directory structure
|
||||
video_key = "observation.images.laptop"
|
||||
episode_index = 0
|
||||
frame_index = 0
|
||||
|
||||
fpath = DEFAULT_IMAGE_PATH.format(
|
||||
image_key=video_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
img_dir = tmp_path / Path(fpath).parent
|
||||
img_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create a dummy image file
|
||||
dummy_img = Image.new("RGB", (64, 64), color="red")
|
||||
dummy_img.save(img_dir / "frame-000000.png")
|
||||
|
||||
# Track what vcodec was passed to encode_video_frames
|
||||
captured_kwargs = {}
|
||||
|
||||
def mock_encode_video_frames(imgs_dir, video_path, fps, **kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
# Create a dummy output file so the worker doesn't fail
|
||||
Path(video_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
Path(video_path).touch()
|
||||
|
||||
with patch("lerobot.datasets.lerobot_dataset.encode_video_frames", side_effect=mock_encode_video_frames):
|
||||
# Test with h264 codec
|
||||
_encode_video_worker(video_key, episode_index, tmp_path, fps=30, vcodec="h264")
|
||||
|
||||
assert "vcodec" in captured_kwargs
|
||||
assert captured_kwargs["vcodec"] == "h264"
|
||||
|
||||
|
||||
def test_encode_video_worker_default_vcodec(tmp_path):
|
||||
"""Test that _encode_video_worker uses libsvtav1 as the default codec."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from lerobot.datasets.utils import DEFAULT_IMAGE_PATH
|
||||
|
||||
# Create the expected directory structure
|
||||
video_key = "observation.images.laptop"
|
||||
episode_index = 0
|
||||
frame_index = 0
|
||||
|
||||
fpath = DEFAULT_IMAGE_PATH.format(
|
||||
image_key=video_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
img_dir = tmp_path / Path(fpath).parent
|
||||
img_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create a dummy image file
|
||||
dummy_img = Image.new("RGB", (64, 64), color="red")
|
||||
dummy_img.save(img_dir / "frame-000000.png")
|
||||
|
||||
# Track what vcodec was passed to encode_video_frames
|
||||
captured_kwargs = {}
|
||||
|
||||
def mock_encode_video_frames(imgs_dir, video_path, fps, **kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
# Create a dummy output file so the worker doesn't fail
|
||||
Path(video_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
Path(video_path).touch()
|
||||
|
||||
with patch("lerobot.datasets.lerobot_dataset.encode_video_frames", side_effect=mock_encode_video_frames):
|
||||
# Test with default codec (no vcodec specified)
|
||||
_encode_video_worker(video_key, episode_index, tmp_path, fps=30)
|
||||
|
||||
assert "vcodec" in captured_kwargs
|
||||
assert captured_kwargs["vcodec"] == "libsvtav1"
|
||||
|
||||
|
||||
def test_lerobot_dataset_vcodec_validation():
|
||||
"""Test that LeRobotDataset validates the vcodec parameter."""
|
||||
# Test that invalid vcodec raises ValueError
|
||||
with pytest.raises(ValueError, match="Invalid vcodec"):
|
||||
LeRobotDataset.__new__(LeRobotDataset) # bypass __init__ to test validation directly
|
||||
# Actually test via create since it's easier
|
||||
LeRobotDataset.create(
|
||||
repo_id="test/invalid_codec",
|
||||
fps=30,
|
||||
features={"observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}},
|
||||
vcodec="invalid_codec",
|
||||
)
|
||||
|
||||
|
||||
def test_valid_video_codecs_constant():
|
||||
"""Test that VALID_VIDEO_CODECS contains the expected codecs."""
|
||||
assert "h264" in VALID_VIDEO_CODECS
|
||||
assert "hevc" in VALID_VIDEO_CODECS
|
||||
assert "libsvtav1" in VALID_VIDEO_CODECS
|
||||
assert len(VALID_VIDEO_CODECS) == 3
|
||||
|
||||
Reference in New Issue
Block a user