Adding last missing audio features in LeRobotDataset

This commit is contained in:
CarolinePascal
2025-04-11 13:43:17 +02:00
parent 16de8b3f19
commit 99eb0bbafc
+21 -3
View File
@@ -621,6 +621,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
revision: str | None = None,
force_cache_sync: bool = False,
download_videos: bool = True,
download_audio: bool = True,
video_backend: str | None = None,
audio_backend: str | None = None,
batch_encoding_size: int = 1,
@@ -752,9 +753,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
video files are already present on local disk, they won't be downloaded again. Defaults to
True.
download_audio (bool, optional): Flag to download the audio. Defaults to True.
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'.
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'ffmpeg'.
audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'ffmpeg' decoder used by 'torchaudio'.
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
@@ -847,6 +849,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
license: str | None = "apache-2.0",
tag_version: bool = True,
push_videos: bool = True,
push_audio: bool = True,
private: bool = False,
allow_patterns: list[str] | str | None = None,
upload_large_folder: bool = False,
@@ -855,6 +858,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
ignore_patterns = ["images/"]
if not push_videos:
ignore_patterns.append("videos/")
if not push_audio:
ignore_patterns.append("audio/")
hub_api = HfApi()
hub_api.create_repo(
@@ -909,7 +914,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
ignore_patterns=ignore_patterns,
)
def download(self, download_videos: bool = True) -> None:
def download(self, download_videos: bool = True, download_audio: bool = True) -> None:
"""Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
@@ -917,8 +922,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""
# TODO(rcadene, aliberts): implement faster transfer
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
ignore_patterns = None if download_videos else "videos/"
files = None
ignore_patterns = []
if not download_videos:
ignore_patterns.append("videos/")
if not download_audio:
ignore_patterns.append("audio/")
if self.episodes is not None:
files = self.get_episodes_file_paths()
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
@@ -933,6 +942,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
for ep_idx in episodes
]
fpaths += video_files
if len(self.meta.audio_keys) > 0:
audio_files = [
str(self.meta.get_compressed_audio_file_path(ep_idx, audio_key))
for audio_key in self.meta.audio_keys
for ep_idx in episodes
]
fpaths += audio_files
# episodes are stored in the same files, so we return unique paths only
fpaths = list(set(fpaths))
return fpaths