mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
feat(datasets): expose video codec option for dataset recording (#2771)
* expose codec options + add tests
* pre-commit run -a
This commit is contained in:
@@ -78,6 +78,7 @@ from lerobot.datasets.video_utils import (
|
|||||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||||
|
|
||||||
CODEBASE_VERSION = "v3.0"
|
CODEBASE_VERSION = "v3.0"
|
||||||
|
# Codecs accepted for the `vcodec` option when encoding recorded videos.
# Frozen so this shared validation set cannot be mutated by accident at runtime.
VALID_VIDEO_CODECS = frozenset({"h264", "hevc", "libsvtav1"})
|
||||||
|
|
||||||
|
|
||||||
class LeRobotDatasetMetadata:
|
class LeRobotDatasetMetadata:
|
||||||
@@ -540,11 +541,13 @@ class LeRobotDatasetMetadata:
|
|||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
def _encode_video_worker(
    video_key: str, episode_index: int, root: Path, fps: int, vcodec: str = "libsvtav1"
) -> Path:
    """Encode the saved frames of one camera for one episode into a temporary MP4 file.

    The frame directory is derived from ``DEFAULT_IMAGE_PATH`` under ``root``. After a
    successful encode the frame directory is deleted and the path of the new video is
    returned.

    Args:
        video_key: Camera/feature key whose frames are encoded.
        episode_index: Index of the episode the frames belong to.
        root: Dataset root; both the frames and the temporary output directory live under it.
        fps: Frame rate forwarded to ``encode_video_frames``.
        vcodec: Video codec forwarded to ``encode_video_frames``. Defaults to ``"libsvtav1"``.

    Returns:
        Path to the encoded ``.mp4`` file inside a freshly created temporary directory under ``root``.

    Raises:
        Whatever ``encode_video_frames`` raises; in that case the temporary directory is
        removed first and the source frames are left untouched.
    """
    temp_dir = Path(tempfile.mkdtemp(dir=root))
    temp_path = temp_dir / f"{video_key}_{episode_index:03d}.mp4"
    fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
    img_dir = (root / fpath).parent
    try:
        encode_video_frames(img_dir, temp_path, fps, vcodec=vcodec, overwrite=True)
    except BaseException:
        # The temporary directory path is only known here; if encoding fails nobody else
        # can clean it up, so remove it before propagating the error.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
    # Frames are only discarded once the video has been written successfully.
    shutil.rmtree(img_dir)
    return temp_path
|
||||||
|
|
||||||
@@ -563,6 +566,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
download_videos: bool = True,
|
download_videos: bool = True,
|
||||||
video_backend: str | None = None,
|
video_backend: str | None = None,
|
||||||
batch_encoding_size: int = 1,
|
batch_encoding_size: int = 1,
|
||||||
|
vcodec: str = "libsvtav1",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
2 modes are available for instantiating this class, depending on 2 different use cases:
|
2 modes are available for instantiating this class, depending on 2 different use cases:
|
||||||
@@ -675,8 +679,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
||||||
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
|
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
|
||||||
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
|
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
|
||||||
|
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
|
||||||
|
'libsvtav1'. Defaults to 'libsvtav1'. Use 'h264' for faster encoding on systems where AV1
|
||||||
|
encoding is CPU-heavy.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
if vcodec not in VALID_VIDEO_CODECS:
|
||||||
|
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||||
self.repo_id = repo_id
|
self.repo_id = repo_id
|
||||||
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
|
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
|
||||||
self.image_transforms = image_transforms
|
self.image_transforms = image_transforms
|
||||||
@@ -688,6 +697,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
self.delta_indices = None
|
self.delta_indices = None
|
||||||
self.batch_encoding_size = batch_encoding_size
|
self.batch_encoding_size = batch_encoding_size
|
||||||
self.episodes_since_last_encoding = 0
|
self.episodes_since_last_encoding = 0
|
||||||
|
self.vcodec = vcodec
|
||||||
|
|
||||||
# Unused attributes
|
# Unused attributes
|
||||||
self.image_writer = None
|
self.image_writer = None
|
||||||
@@ -1211,6 +1221,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
episode_index,
|
episode_index,
|
||||||
self.root,
|
self.root,
|
||||||
self.fps,
|
self.fps,
|
||||||
|
self.vcodec,
|
||||||
): video_key
|
): video_key
|
||||||
for video_key in self.meta.video_keys
|
for video_key in self.meta.video_keys
|
||||||
}
|
}
|
||||||
@@ -1526,7 +1537,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||||
since video encoding with ffmpeg is already using multithreading.
|
since video encoding with ffmpeg is already using multithreading.
|
||||||
"""
|
"""
|
||||||
return _encode_video_worker(video_key, episode_index, self.root, self.fps)
|
return _encode_video_worker(video_key, episode_index, self.root, self.fps, self.vcodec)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(
|
def create(
|
||||||
@@ -1542,8 +1553,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
image_writer_threads: int = 0,
|
image_writer_threads: int = 0,
|
||||||
video_backend: str | None = None,
|
video_backend: str | None = None,
|
||||||
batch_encoding_size: int = 1,
|
batch_encoding_size: int = 1,
|
||||||
|
vcodec: str = "libsvtav1",
|
||||||
) -> "LeRobotDataset":
|
) -> "LeRobotDataset":
|
||||||
"""Create a LeRobot Dataset from scratch in order to record data."""
|
"""Create a LeRobot Dataset from scratch in order to record data."""
|
||||||
|
if vcodec not in VALID_VIDEO_CODECS:
|
||||||
|
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||||
obj = cls.__new__(cls)
|
obj = cls.__new__(cls)
|
||||||
obj.meta = LeRobotDatasetMetadata.create(
|
obj.meta = LeRobotDatasetMetadata.create(
|
||||||
repo_id=repo_id,
|
repo_id=repo_id,
|
||||||
@@ -1560,6 +1574,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
obj.image_writer = None
|
obj.image_writer = None
|
||||||
obj.batch_encoding_size = batch_encoding_size
|
obj.batch_encoding_size = batch_encoding_size
|
||||||
obj.episodes_since_last_encoding = 0
|
obj.episodes_since_last_encoding = 0
|
||||||
|
obj.vcodec = vcodec
|
||||||
|
|
||||||
if image_writer_processes or image_writer_threads:
|
if image_writer_processes or image_writer_threads:
|
||||||
obj.start_image_writer(image_writer_processes, image_writer_threads)
|
obj.start_image_writer(image_writer_processes, image_writer_threads)
|
||||||
|
|||||||
@@ -27,6 +27,8 @@ lerobot-record \
|
|||||||
--dataset.num_episodes=2 \
|
--dataset.num_episodes=2 \
|
||||||
--dataset.single_task="Grab the cube" \
|
--dataset.single_task="Grab the cube" \
|
||||||
--display_data=true
|
--display_data=true
|
||||||
|
# <- Optional: specify video codec (h264, hevc, libsvtav1). Default is libsvtav1. \
|
||||||
|
# --dataset.vcodec=h264 \
|
||||||
# <- Teleop optional if you want to teleoperate to record or in between episodes with a policy \
|
# <- Teleop optional if you want to teleoperate to record or in between episodes with a policy \
|
||||||
# --teleop.type=so100_leader \
|
# --teleop.type=so100_leader \
|
||||||
# --teleop.port=/dev/tty.usbmodem58760431551 \
|
# --teleop.port=/dev/tty.usbmodem58760431551 \
|
||||||
@@ -165,6 +167,9 @@ class DatasetRecordConfig:
|
|||||||
# Number of episodes to record before batch encoding videos
|
# Number of episodes to record before batch encoding videos
|
||||||
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
|
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
|
||||||
video_encoding_batch_size: int = 1
|
video_encoding_batch_size: int = 1
|
||||||
|
# Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1'.
|
||||||
|
# Use 'h264' for faster encoding on systems where AV1 encoding is CPU-heavy.
|
||||||
|
vcodec: str = "libsvtav1"
|
||||||
# Rename map for the observation to override the image and state keys
|
# Rename map for the observation to override the image and state keys
|
||||||
rename_map: dict[str, str] = field(default_factory=dict)
|
rename_map: dict[str, str] = field(default_factory=dict)
|
||||||
|
|
||||||
@@ -427,6 +432,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
|||||||
cfg.dataset.repo_id,
|
cfg.dataset.repo_id,
|
||||||
root=cfg.dataset.root,
|
root=cfg.dataset.root,
|
||||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||||
|
vcodec=cfg.dataset.vcodec,
|
||||||
)
|
)
|
||||||
|
|
||||||
if hasattr(robot, "cameras") and len(robot.cameras) > 0:
|
if hasattr(robot, "cameras") and len(robot.cameras) > 0:
|
||||||
@@ -448,6 +454,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
|||||||
image_writer_processes=cfg.dataset.num_image_writer_processes,
|
image_writer_processes=cfg.dataset.num_image_writer_processes,
|
||||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
|
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||||
|
vcodec=cfg.dataset.vcodec,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load pretrained policy
|
# Load pretrained policy
|
||||||
|
|||||||
@@ -31,8 +31,10 @@ from lerobot.configs.train import TrainPipelineConfig
|
|||||||
from lerobot.datasets.factory import make_dataset
|
from lerobot.datasets.factory import make_dataset
|
||||||
from lerobot.datasets.image_writer import image_array_to_pil_image
|
from lerobot.datasets.image_writer import image_array_to_pil_image
|
||||||
from lerobot.datasets.lerobot_dataset import (
|
from lerobot.datasets.lerobot_dataset import (
|
||||||
|
VALID_VIDEO_CODECS,
|
||||||
LeRobotDataset,
|
LeRobotDataset,
|
||||||
MultiLeRobotDataset,
|
MultiLeRobotDataset,
|
||||||
|
_encode_video_worker,
|
||||||
)
|
)
|
||||||
from lerobot.datasets.utils import (
|
from lerobot.datasets.utils import (
|
||||||
DEFAULT_CHUNK_SIZE,
|
DEFAULT_CHUNK_SIZE,
|
||||||
@@ -1292,3 +1294,101 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
|
|||||||
frame = loaded_dataset[idx]
|
frame = loaded_dataset[idx]
|
||||||
expected_ep = idx // frames_per_episode
|
expected_ep = idx // frames_per_episode
|
||||||
assert frame["episode_index"].item() == expected_ep
|
assert frame["episode_index"].item() == expected_ep
|
||||||
|
|
||||||
|
|
||||||
|
def test_encode_video_worker_forwards_vcodec(tmp_path):
    """Check that the codec given to _encode_video_worker reaches encode_video_frames."""
    from unittest.mock import patch

    from lerobot.datasets.utils import DEFAULT_IMAGE_PATH

    camera_key = "observation.images.laptop"
    ep_idx = 0

    # Lay out a single red frame where the worker looks for this episode's images.
    first_frame = DEFAULT_IMAGE_PATH.format(image_key=camera_key, episode_index=ep_idx, frame_index=0)
    frames_dir = tmp_path / Path(first_frame).parent
    frames_dir.mkdir(parents=True, exist_ok=True)
    Image.new("RGB", (64, 64), color="red").save(frames_dir / "frame-000000.png")

    # Keyword arguments received by the patched encoder end up here.
    seen_kwargs = {}

    def fake_encode(imgs_dir, video_path, fps, **kwargs):
        seen_kwargs.update(kwargs)
        # Produce an empty output file so the worker has something to return.
        output = Path(video_path)
        output.parent.mkdir(parents=True, exist_ok=True)
        output.touch()

    encoder_target = "lerobot.datasets.lerobot_dataset.encode_video_frames"
    with patch(encoder_target, side_effect=fake_encode):
        # Explicitly request h264
        _encode_video_worker(camera_key, ep_idx, tmp_path, fps=30, vcodec="h264")

    assert "vcodec" in seen_kwargs
    assert seen_kwargs["vcodec"] == "h264"
|
||||||
|
|
||||||
|
|
||||||
|
def test_encode_video_worker_default_vcodec(tmp_path):
    """Check that _encode_video_worker falls back to libsvtav1 when no codec is given."""
    from unittest.mock import patch

    from lerobot.datasets.utils import DEFAULT_IMAGE_PATH

    camera_key = "observation.images.laptop"
    ep_idx = 0

    # Lay out a single red frame where the worker looks for this episode's images.
    first_frame = DEFAULT_IMAGE_PATH.format(image_key=camera_key, episode_index=ep_idx, frame_index=0)
    frames_dir = tmp_path / Path(first_frame).parent
    frames_dir.mkdir(parents=True, exist_ok=True)
    Image.new("RGB", (64, 64), color="red").save(frames_dir / "frame-000000.png")

    # Keyword arguments received by the patched encoder end up here.
    seen_kwargs = {}

    def fake_encode(imgs_dir, video_path, fps, **kwargs):
        seen_kwargs.update(kwargs)
        # Produce an empty output file so the worker has something to return.
        output = Path(video_path)
        output.parent.mkdir(parents=True, exist_ok=True)
        output.touch()

    encoder_target = "lerobot.datasets.lerobot_dataset.encode_video_frames"
    with patch(encoder_target, side_effect=fake_encode):
        # No vcodec argument: the worker's own default must be forwarded.
        _encode_video_worker(camera_key, ep_idx, tmp_path, fps=30)

    assert "vcodec" in seen_kwargs
    assert seen_kwargs["vcodec"] == "libsvtav1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_lerobot_dataset_vcodec_validation():
    """Test that LeRobotDataset.create rejects an unsupported vcodec with a ValueError."""
    # Only the call under test lives inside the `raises` block, so the expected
    # ValueError can only come from `create` itself.
    with pytest.raises(ValueError, match="Invalid vcodec"):
        LeRobotDataset.create(
            repo_id="test/invalid_codec",
            fps=30,
            features={"observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}},
            vcodec="invalid_codec",
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_valid_video_codecs_constant():
    """Check that VALID_VIDEO_CODECS holds exactly the three supported codecs."""
    expected_codecs = ("h264", "hevc", "libsvtav1")
    for codec in expected_codecs:
        assert codec in VALID_VIDEO_CODECS
    # No extra entries beyond the expected ones.
    assert len(VALID_VIDEO_CODECS) == len(expected_codecs)
|
||||||
|
|||||||
Reference in New Issue
Block a user