From 99eb0bbafcde8aa84a32db19810b92746a549196 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Fri, 11 Apr 2025 13:43:17 +0200 Subject: [PATCH] Adding last missing audio features in LeRobotDataset --- src/lerobot/datasets/lerobot_dataset.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index c46ca3034..221e160e8 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -621,6 +621,7 @@ class LeRobotDataset(torch.utils.data.Dataset): revision: str | None = None, force_cache_sync: bool = False, download_videos: bool = True, + download_audio: bool = True, video_backend: str | None = None, audio_backend: str | None = None, batch_encoding_size: int = 1, @@ -752,9 +753,10 @@ class LeRobotDataset(torch.utils.data.Dataset): download_videos (bool, optional): Flag to download the videos. Note that when set to True but the video files are already present on local disk, they won't be downloaded again. Defaults to True. + download_audio (bool, optional): Flag to download the audio. Defaults to True. video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'. You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision. - audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'ffmpeg'. + audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'ffmpeg' decoder used by 'torchaudio'. batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos. Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1. vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc', @@ -847,6 +849,7 @@ class LeRobotDataset(torch.utils.data.Dataset): license: str | None = "apache-2.0", tag_version: bool = True, push_videos: bool = True, + push_audio: bool = True, private: bool = False, allow_patterns: list[str] | str | None = None, upload_large_folder: bool = False, @@ -855,6 +858,8 @@ class LeRobotDataset(torch.utils.data.Dataset): ignore_patterns = ["images/"] if not push_videos: ignore_patterns.append("videos/") + if not push_audio: + ignore_patterns.append("audio/") hub_api = HfApi() hub_api.create_repo( @@ -909,7 +914,7 @@ class LeRobotDataset(torch.utils.data.Dataset): ignore_patterns=ignore_patterns, ) - def download(self, download_videos: bool = True) -> None: + def download(self, download_videos: bool = True, download_audio: bool = True) -> None: """Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present @@ -917,8 +922,12 @@ class LeRobotDataset(torch.utils.data.Dataset): """ # TODO(rcadene, aliberts): implement faster transfer # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads - ignore_patterns = None if download_videos else "videos/" files = None + ignore_patterns = [] + if not download_videos: + ignore_patterns.append("videos/") + if not download_audio: + ignore_patterns.append("audio/") if self.episodes is not None: files = self.get_episodes_file_paths() self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns) @@ -933,6 +942,15 @@ class LeRobotDataset(torch.utils.data.Dataset): for ep_idx in episodes ] fpaths += video_files + + if len(self.meta.audio_keys) > 0: + audio_files = [ + str(self.meta.get_compressed_audio_file_path(ep_idx, audio_key)) + for audio_key in self.meta.audio_keys + for ep_idx in episodes + ] + fpaths += audio_files + # episodes are stored in the same files, so we return unique paths only fpaths = list(set(fpaths)) return fpaths