mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-02 23:57:24 +00:00
Adding last missing audio features in LeRobotDataset
This commit is contained in:
@@ -621,6 +621,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
download_videos: bool = True,
|
||||
download_audio: bool = True,
|
||||
video_backend: str | None = None,
|
||||
audio_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
@@ -752,9 +753,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
|
||||
video files are already present on local disk, they won't be downloaded again. Defaults to
|
||||
True.
|
||||
download_audio (bool, optional): Flag to download the audio. Defaults to True.
|
||||
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'.
|
||||
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
||||
audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'ffmpeg'.
|
||||
audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'ffmpeg' decoder used by 'torchaudio'.
|
||||
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
|
||||
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
|
||||
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
|
||||
@@ -847,6 +849,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
license: str | None = "apache-2.0",
|
||||
tag_version: bool = True,
|
||||
push_videos: bool = True,
|
||||
push_audio: bool = True,
|
||||
private: bool = False,
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
upload_large_folder: bool = False,
|
||||
@@ -855,6 +858,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
ignore_patterns = ["images/"]
|
||||
if not push_videos:
|
||||
ignore_patterns.append("videos/")
|
||||
if not push_audio:
|
||||
ignore_patterns.append("audio/")
|
||||
|
||||
hub_api = HfApi()
|
||||
hub_api.create_repo(
|
||||
@@ -909,7 +914,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
|
||||
def download(self, download_videos: bool = True) -> None:
|
||||
def download(self, download_videos: bool = True, download_audio: bool = True) -> None:
|
||||
"""Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
|
||||
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
|
||||
dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
|
||||
@@ -917,8 +922,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
# TODO(rcadene, aliberts): implement faster transfer
|
||||
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
||||
ignore_patterns = None if download_videos else "videos/"
|
||||
files = None
|
||||
ignore_patterns = []
|
||||
if not download_videos:
|
||||
ignore_patterns.append("videos/")
|
||||
if not download_audio:
|
||||
ignore_patterns.append("audio/")
|
||||
if self.episodes is not None:
|
||||
files = self.get_episodes_file_paths()
|
||||
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
|
||||
@@ -933,6 +942,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
for ep_idx in episodes
|
||||
]
|
||||
fpaths += video_files
|
||||
|
||||
if len(self.meta.audio_keys) > 0:
|
||||
audio_files = [
|
||||
str(self.meta.get_compressed_audio_file_path(ep_idx, audio_key))
|
||||
for audio_key in self.meta.audio_keys
|
||||
for ep_idx in episodes
|
||||
]
|
||||
fpaths += audio_files
|
||||
|
||||
# episodes are stored in the same files, so we return unique paths only
|
||||
fpaths = list(set(fpaths))
|
||||
return fpaths
|
||||
|
||||
Reference in New Issue
Block a user