diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py
index d9d4b22d0..6da5eb49e 100644
--- a/src/lerobot/datasets/lerobot_dataset.py
+++ b/src/lerobot/datasets/lerobot_dataset.py
@@ -28,6 +28,7 @@ import pandas as pd
 import PIL.Image
 import pyarrow as pa
 import pyarrow.parquet as pq
+from concurrent.futures import ProcessPoolExecutor
 import torch
 import torch.utils
 from huggingface_hub import HfApi, snapshot_download
@@ -1199,6 +1200,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
         use_batched_encoding = self.batch_encoding_size > 1
 
         if has_video_keys and not use_batched_encoding:
+            video_paths = self._encode_multiple_temporary_episode_videos(self.meta.video_keys, episode_index)
+            for (video_key, video_path) in zip(self.meta.video_keys, video_paths):
+                ep_metadata.update(self._save_episode_video(video_key, episode_index, video_path))
             num_cameras = len(self.meta.video_keys)
             if parallel_encoding and num_cameras > 1:
                 # TODO(Steven): Ideally we would like to control the number of threads per encoding such that:
@@ -1397,6 +1401,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
         return metadata
 
+    def _save_episode_video(self, video_key: str, episode_index: int, video_path: str | Path | None = None) -> dict:
     def _save_episode_video(
         self,
         video_key: str,
@@ -1404,6 +1409,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
         temp_path: Path | None = None,
     ) -> dict:
         # Encode episode frames into a temporary video
+        if video_path is None:
+            ep_path = self._encode_temporary_episode_video(video_key, episode_index)
+        else:
+            ep_path = video_path
         if temp_path is None:
             ep_path = self._encode_temporary_episode_video(video_key, episode_index)
         else:
@@ -1528,6 +1537,38 @@ class LeRobotDataset(torch.utils.data.Dataset):
         """
         return _encode_video_worker(video_key, episode_index, self.root, self.fps)
 
+    def _encode_multiple_temporary_episode_videos(self, video_keys, episode_index):
+        """Encode one temporary video per video key using a pool of worker processes.
+
+        Args:
+            video_keys: Video feature keys whose image frames should be encoded.
+            episode_index: Index of the episode the frames belong to.
+
+        Returns:
+            Paths of the encoded temporary videos, in the same order as `video_keys`.
+
+        Raises:
+            Exception: Whatever `encode_video_frames` raised in a worker process; in that
+                case the source image directories are left on disk.
+        """
+        temp_paths = []
+        img_dirs = []
+        for video_key in video_keys:
+            temp_paths.append(Path(tempfile.mkdtemp(dir=self.root)) / f"{video_key}_{episode_index:03d}.mp4")
+            img_dirs.append(self._get_image_file_dir(episode_index, video_key))
+        fps = [self.fps] * len(video_keys)
+
+        with ProcessPoolExecutor() as executor:
+            # Executor.map only re-raises a worker's exception when its result is retrieved,
+            # so drain the iterator: a failed encode must surface here, before the source
+            # frames are deleted below.
+            list(executor.map(encode_video_frames, img_dirs, temp_paths, fps))
+
+        for img_dir in img_dirs:
+            shutil.rmtree(img_dir)
+
+        return temp_paths
+
     @classmethod
     def create(
         cls,
diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py
index 8a8a8df60..c7945e6e7 100644
--- a/src/lerobot/datasets/video_utils.py
+++ b/src/lerobot/datasets/video_utils.py
@@ -310,7 +310,7 @@ def encode_video_frames(
     crf: int | None = 30,
     fast_decode: int = 0,
     log_level: int | None = av.logging.ERROR,
-    overwrite: bool = False,
+    overwrite: bool = True,
     preset: int | None = None,
 ) -> None:
     """More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""