From 2864caad80a67030bb65eb4ee1f565506cd72ec4 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Fri, 28 Mar 2025 17:16:51 +0100 Subject: [PATCH] Adding audio modality in LeRobotDatasets --- src/lerobot/datasets/aggregate.py | 6 +- src/lerobot/datasets/audio_utils.py | 193 ++++++++++ src/lerobot/datasets/lerobot_dataset.py | 333 ++++++++++++++++-- src/lerobot/datasets/utils.py | 15 +- .../v30/convert_dataset_v21_to_v30.py | 8 +- src/lerobot/datasets/video_utils.py | 125 +++---- tests/fixtures/dataset_factories.py | 3 + 7 files changed, 582 insertions(+), 101 deletions(-) create mode 100644 src/lerobot/datasets/audio_utils.py diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py index 94ffe602e..4a5069fc8 100644 --- a/src/lerobot/datasets/aggregate.py +++ b/src/lerobot/datasets/aggregate.py @@ -41,7 +41,7 @@ from lerobot.datasets.utils import ( write_stats, write_tasks, ) -from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s +from lerobot.datasets.video_utils import concatenate_media_files, get_media_duration_in_s def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]): @@ -328,7 +328,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu file_index=file_idx, ) - src_duration = get_video_duration_in_s(src_path) + src_duration = get_media_duration_in_s(src_path, media_type="video") dst_key = (chunk_idx, file_idx) if not dst_path.exists(): @@ -367,7 +367,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu current_dst_duration = dst_file_durations.get(dst_key, 0) videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_dst_duration videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key - concatenate_video_files( + concatenate_media_files( [dst_path, src_path], dst_path, ) diff --git a/src/lerobot/datasets/audio_utils.py b/src/lerobot/datasets/audio_utils.py new file mode 100644 index 000000000..c112e1b8f --- /dev/null +++ b/src/lerobot/datasets/audio_utils.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from pathlib import Path + +import av +import torch +import torchaudio +from numpy import ceil + +CHANNELS_LAYOUTS_MAPPING = { + 1: "mono", + 2: "stereo", + 3: "2.1", + 4: "3.1", + 5: "4.1", + 6: "5.1", + 7: "6.1", + 8: "7.1", + 16: "hexadecagonal", + 24: "22.2", +} + + +def decode_audio( + audio_path: Path | str, + timestamps: list[float], + duration: float, + backend: str | None = "torchaudio", +) -> torch.Tensor: + """ + Decodes audio using the specified backend. + Args: + audio_path (Path): Path to the audio file. + timestamps (list[float]): List of timestamps to extract frames. + tolerance_s (float): Allowed deviation in seconds for frame retrieval. + backend (str, optional): Backend to use for decoding. Defaults to "torchaudio". + + Returns: + torch.Tensor: Decoded frames. + + Currently supports torchaudio. + """ + if backend == "torchcodec": + raise NotImplementedError("torchcodec is not yet supported for audio decoding") + elif backend == "torchaudio": + return decode_audio_torchaudio(audio_path, timestamps, duration) + else: + raise ValueError(f"Unsupported video backend: {backend}") + + +def decode_audio_torchaudio( + audio_path: Path | str, + timestamps: list[float], + duration: float, + log_loaded_timestamps: bool = False, +) -> torch.Tensor: + # TODO(CarolinePascal) : add channels selection + audio_path = str(audio_path) + + reader = torchaudio.io.StreamReader(src=audio_path) + audio_sampling_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate + + # TODO(CarolinePascal) : sort timestamps ? + + reader.add_basic_audio_stream( + frames_per_chunk=int(ceil(duration * audio_sampling_rate)), # Too much is better than not enough + buffer_chunk_size=-1, # No dropping frames + ) + + audio_chunks = [] + for ts in timestamps: + reader.seek(ts) # Default to closest audio sample + status = reader.fill_buffer() + if status != 0: + logging.warning("Audio stream reached end of recording before decoding desired timestamps.") + + current_audio_chunk = reader.pop_chunks()[0] + + if log_loaded_timestamps: + logging.info( + f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sampling_rate:.4f}" + ) + + audio_chunks.append(current_audio_chunk) + + audio_chunks = torch.stack(audio_chunks) + # TODO(CarolinePascal) : pytorch format conversion ? + + assert len(timestamps) == len(audio_chunks) + return audio_chunks + + +def encode_audio( + input_path: Path | str, + output_path: Path | str, + codec: str = "aac", # TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and and constant (file size control) /variable (quality control) bitrate options + bit_rate: int | None = None, + sample_rate: int | None = None, + log_level: int | None = av.logging.ERROR, + overwrite: bool = False, +) -> None: + """Encodes an audio file using ffmpeg.""" + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=overwrite) + + # Set logging level + if log_level is not None: + # "While less efficient, it is generally preferable to modify logging with Python’s logging" + logging.getLogger("libav").setLevel(log_level) + + # Open input file + with av.open(str(input_path), "r") as input: + input_stream = input.streams.audio[0] # Assuming the first stream is the audio stream to be encoded + + # Define sub-sampling options + if sample_rate is None: + sample_rate = input_stream.rate + + # Create and open output file (overwrite by default) + with av.open(str(output_path), "w") as output: + output_stream = output.add_stream( + codec, rate=sample_rate, layout=CHANNELS_LAYOUTS_MAPPING[input_stream.channels] + ) + + if bit_rate is not None: + output_stream.bit_rate = bit_rate + + # Loop through input WAV packets and encode them + for input_frame in input.decode( + input_stream + ): # This step handles both demuxing and decoding under the hood + packet = output_stream.encode(input_frame) + if packet: + output.mux(packet) + + # Flush the encoder + packet = output_stream.encode() + if packet: + output.mux(packet) + + # Reset logging level + if log_level is not None: + av.logging.restore_default_callback() + + if not output_path.exists(): + raise OSError(f"Audio encoding did not work. File not found: {output_path}.") + + +def get_audio_info(video_path: Path | str) -> dict: + # Set logging level + logging.getLogger("libav").setLevel(av.logging.ERROR) + + # Getting audio stream information + audio_info = {} + with av.open(str(video_path), "r") as audio_file: + try: + audio_stream = audio_file.streams.audio[0] + except IndexError: + # Reset logging level + av.logging.restore_default_callback() + return {"has_audio": False} + + audio_info["audio.channels"] = audio_stream.channels + audio_info["audio.codec"] = audio_stream.codec.canonical_name + # In an ideal loseless case : bit depth x sample rate x channels = bit rate. + # In an actual compressed case, the bit rate is set according to the compression level : the lower the bit rate, the more compression is applied. + audio_info["audio.bit_rate"] = audio_stream.bit_rate + audio_info["audio.sample_rate"] = audio_stream.sample_rate # Number of samples per second + # In an ideal loseless case : fixed number of bits per sample. + # In an actual compressed case : variable number of bits per sample (often reduced to match a given depth rate). + audio_info["audio.bit_depth"] = audio_stream.format.bits + audio_info["audio.channel_layout"] = audio_stream.layout.name + audio_info["has_audio"] = True + + # Reset logging level + av.logging.restore_default_callback() + + return audio_info diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 6798e7fd7..751fe017e 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -33,12 +33,15 @@ import torch.utils from huggingface_hub import HfApi, snapshot_download from huggingface_hub.errors import RevisionNotFoundError +from lerobot.datasets.audio_utils import decode_audio, encode_audio, get_audio_info from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats from lerobot.datasets.image_writer import AsyncImageWriter, write_image from lerobot.datasets.utils import ( + DEFAULT_AUDIO_CHUNK_DURATION, DEFAULT_EPISODES_PATH, DEFAULT_FEATURES, DEFAULT_IMAGE_PATH, + DEFAULT_RAW_AUDIO_PATH, INFO_PATH, _validate_feature_names, check_delta_timestamps, @@ -68,11 +71,12 @@ from lerobot.datasets.utils import ( ) from lerobot.datasets.video_utils import ( VideoFrame, - concatenate_video_files, + concatenate_media_files, decode_video_frames, encode_video_frames, + get_audio_duration_in_s, + get_media_duration_in_s, get_safe_default_codec, - get_video_duration_in_s, get_video_info, ) from lerobot.utils.constants import HF_LEROBOT_HOME @@ -214,6 +218,19 @@ class LeRobotDatasetMetadata: fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx) return Path(fpath) + def get_audio_file_path(self, ep_index: int, audio_key: str) -> Path: + if self.episodes is None: + self.episodes = load_episodes(self.root) + if ep_index >= len(self.episodes): + raise IndexError( + f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}" + ) + ep = self.episodes[ep_index] + chunk_idx = ep[f"audio/{audio_key}/chunk_index"] + file_idx = ep[f"audio/{audio_key}/file_index"] + fpath = self.audio_path.format(audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx) + return Path(fpath) + @property def data_path(self) -> str: """Formattable string for the parquet files.""" @@ -224,6 +241,11 @@ class LeRobotDatasetMetadata: """Formattable string for the video files.""" return self.info["video_path"] + @property + def audio_path(self) -> str | None: + """Formattable string for the audio files.""" + return self.info["audio_path"] + @property def robot_type(self) -> str | None: """Robot type used in recording this dataset.""" @@ -254,6 +276,11 @@ class LeRobotDatasetMetadata: """Keys to access visual modalities (regardless of their storage method).""" return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]] + @property + def audio_keys(self) -> list[str]: + """Keys to access audio modalities.""" + return [key for key, ft in self.features.items() if ft["dtype"] == "audio"] + @property def names(self) -> dict[str, list | dict]: """Names of the various dimensions of vector modalities.""" @@ -294,6 +321,11 @@ class LeRobotDatasetMetadata: """Max size of video file in mega bytes.""" return self.info["video_files_size_in_mb"] + @property + def audio_files_size_in_mb(self) -> int: + """Max size of audio file in mega bytes.""" + return self.info["audio_files_size_in_mb"] + def get_task_index(self, task: str) -> int | None: """ Given a task in natural language, returns its task_index if the task already exists in the dataset, @@ -435,11 +467,26 @@ class LeRobotDatasetMetadata: video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0) self.info["features"][key]["info"] = get_video_info(video_path) + def update_audio_info(self, audio_key: str | None = None) -> None: + """ + Warning: this function writes info from first episode audio, implicitly assuming that all audio have + been encoded the same way. Also, this means it assumes the first episode exists. + """ + if audio_key is not None and audio_key not in self.audio_keys: + raise ValueError(f"Audio key {audio_key} not found in dataset") + + audio_keys = [audio_key] if audio_key is not None else self.audio_keys + for key in audio_keys: + if not self.features[key].get("info", None): + audio_path = self.root / self.audio_path.format(audio_key=key, chunk_index=0, file_index=0) + self.info["features"][key]["info"] = get_audio_info(audio_path) + def update_chunk_settings( self, chunks_size: int | None = None, data_files_size_in_mb: int | None = None, video_files_size_in_mb: int | None = None, + audio_files_size_in_mb: int | None = None, ) -> None: """Update chunk and file size settings after dataset creation. @@ -451,6 +498,7 @@ class LeRobotDatasetMetadata: chunks_size: Maximum number of files per chunk directory. If None, keeps current value. data_files_size_in_mb: Maximum size for data parquet files in MB. If None, keeps current value. video_files_size_in_mb: Maximum size for video files in MB. If None, keeps current value. + audio_files_size_in_mb: Maximum size for audio files in MB. If None, keeps current value. """ if chunks_size is not None: if chunks_size <= 0: @@ -467,6 +515,11 @@ class LeRobotDatasetMetadata: raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}") self.info["video_files_size_in_mb"] = video_files_size_in_mb + if audio_files_size_in_mb is not None: + if audio_files_size_in_mb <= 0: + raise ValueError(f"audio_files_size_in_mb must be positive, got {audio_files_size_in_mb}") + self.info["audio_files_size_in_mb"] = audio_files_size_in_mb + # Update the info file on disk write_info(self.info, self.root) @@ -474,12 +527,13 @@ class LeRobotDatasetMetadata: """Get current chunk and file size settings. Returns: - Dict containing chunks_size, data_files_size_in_mb, and video_files_size_in_mb. + Dict containing chunks_size, data_files_size_in_mb, video_files_size_in_mb, and audio_files_size_in_mb. """ return { "chunks_size": self.chunks_size, "data_files_size_in_mb": self.data_files_size_in_mb, "video_files_size_in_mb": self.video_files_size_in_mb, + "audio_files_size_in_mb": self.audio_files_size_in_mb, } def __repr__(self): @@ -506,6 +560,7 @@ class LeRobotDatasetMetadata: chunks_size: int | None = None, data_files_size_in_mb: int | None = None, video_files_size_in_mb: int | None = None, + audio_files_size_in_mb: int | None = None, ) -> "LeRobotDatasetMetadata": """Creates metadata for a LeRobotDataset.""" obj = cls.__new__(cls) @@ -529,6 +584,7 @@ class LeRobotDatasetMetadata: chunks_size, data_files_size_in_mb, video_files_size_in_mb, + audio_files_size_in_mb, ) if len(obj.video_keys) > 0 and not use_videos: raise ValueError() @@ -565,6 +621,7 @@ class LeRobotDataset(torch.utils.data.Dataset): force_cache_sync: bool = False, download_videos: bool = True, video_backend: str | None = None, + audio_backend: str | None = None, batch_encoding_size: int = 1, vcodec: str = "libsvtav1", ): @@ -598,6 +655,7 @@ class LeRobotDataset(torch.utils.data.Dataset): task-conditioned training. - hf_dataset (from datasets.Dataset), which will read any values from parquet files. - videos (optional) from which frames are loaded to be synchronous with data from parquet files. + - audio (optional) from which audio is loaded to be synchronous with data from parquet files. A typical LeRobotDataset looks like this from its root path: . @@ -623,19 +681,37 @@ class LeRobotDataset(torch.utils.data.Dataset): │ ├── info.json │ ├── stats.json │ └── tasks.parquet - └── videos - ├── observation.images.laptop + ├── videos + │ ├── observation.images.laptop + │ │ ├── chunk-000 + │ │ │ ├── file-000.mp4 + │ │ │ ├── file-001.mp4 + │ │ │ └── ... + │ │ ├── chunk-001 + │ │ │ └── ... + │ │ └── ... + │ ├── observation.images.phone + │ │ ├── chunk-000 + │ │ │ ├── file-000.mp4 + │ │ │ ├── file-001.mp4 + │ │ │ └── ... + │ │ ├── chunk-001 + │ │ │ └── ... + │ │ └── ... + │ └── ... + └── audio + ├── observation.audio.laptop │ ├── chunk-000 - │ │ ├── file-000.mp4 - │ │ ├── file-001.mp4 + │ │ ├── file-000.m4a + │ │ ├── file-001.m4a │ │ └── ... │ ├── chunk-001 │ │ └── ... │ └── ... - ├── observation.images.phone + ├── observation.audio.phone │ ├── chunk-000 - │ │ ├── file-000.mp4 - │ │ ├── file-001.mp4 + │ │ ├── file-000.m4a + │ │ ├── file-001.m4a │ │ └── ... │ ├── chunk-001 │ │ └── ... @@ -677,6 +753,7 @@ class LeRobotDataset(torch.utils.data.Dataset): True. video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'. You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision. + audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'ffmpeg'. batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos. Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1. vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc', @@ -694,6 +771,9 @@ class LeRobotDataset(torch.utils.data.Dataset): self.tolerance_s = tolerance_s self.revision = revision if revision else CODEBASE_VERSION self.video_backend = video_backend if video_backend else get_safe_default_codec() + self.audio_backend = ( + audio_backend if audio_backend else "ffmpeg" + ) # Waiting for torchcodec release #TODO(CarolinePascal) self.delta_indices = None self.batch_encoding_size = batch_encoding_size self.episodes_since_last_encoding = 0 @@ -864,7 +944,7 @@ class LeRobotDataset(torch.utils.data.Dataset): return hf_dataset def _check_cached_episodes_sufficient(self) -> bool: - """Check if the cached dataset contains all requested episodes and their video files.""" + """Check if the cached dataset contains all requested episodes and their video and audio files.""" if self.hf_dataset is None or len(self.hf_dataset) == 0: return False @@ -892,6 +972,14 @@ class LeRobotDataset(torch.utils.data.Dataset): if not video_path.exists(): return False + # Check if all required audio files exist + if len(self.meta.audio_keys) > 0: + for ep_idx in requested_episodes: + for audio_key in self.meta.audio_keys: + audio_path = self.root / self.meta.get_audio_file_path(ep_idx, audio_key) + if not audio_path.exists(): + return False + return True def create_hf_dataset(self) -> datasets.Dataset: @@ -970,7 +1058,7 @@ class LeRobotDataset(torch.utils.data.Dataset): query_indices: dict[str, list[int]] | None = None, ) -> dict[str, list[float]]: query_timestamps = {} - for key in self.meta.video_keys: + for key in self.meta.video_keys + self.meta.audio_keys: if query_indices is not None and key in query_indices: if self._absolute_to_relative_idx is not None: relative_indices = [self._absolute_to_relative_idx[idx] for idx in query_indices[key]] @@ -985,7 +1073,7 @@ class LeRobotDataset(torch.utils.data.Dataset): def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict: """ - Query dataset for indices across keys, skipping video keys. + Query dataset for indices across keys, skipping video keys and audio keys. Tries column-first [key][indices] for speed, falls back to row-first. @@ -997,7 +1085,7 @@ class LeRobotDataset(torch.utils.data.Dataset): """ result: dict = {} for key, q_idx in query_indices.items(): - if key in self.meta.video_keys: + if key in self.meta.video_keys or key in self.meta.audio_keys: continue # Map absolute indices to relative indices if needed relative_indices = ( @@ -1032,6 +1120,25 @@ class LeRobotDataset(torch.utils.data.Dataset): return item + # TODO(CarolinePascal): add variable query durations + def _query_audio( + self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int + ) -> dict[str, torch.Tensor]: + ep = self.meta.episodes[ep_idx] + item = {} + for audio_key, query_ts in query_timestamps.items(): + # Episodes are stored sequentially on a single mp4 to reduce the number of files. + # Thus we load the start timestamp of the episode on this mp4 and, + # shift the query timestamp accordingly. + from_timestamp = ep[f"audio/{audio_key}/from_timestamp"] + shifted_query_ts = [from_timestamp + ts for ts in query_ts] + + audio_path = self.root / self.meta.get_audio_file_path(ep_idx, audio_key) + audio_chunk = decode_audio(audio_path, shifted_query_ts, query_duration, self.audio_backend) + item[audio_key] = audio_chunk.squeeze(0) + + return item + def _ensure_hf_dataset_loaded(self): """Lazy load the HF dataset only when needed for reading.""" if self._lazy_loading or self.hf_dataset is None: @@ -1061,11 +1168,12 @@ class LeRobotDataset(torch.utils.data.Dataset): for key, val in query_result.items(): item[key] = val - if len(self.meta.video_keys) > 0: + if len(self.meta.video_keys) > 0 or len(self.meta.audio_keys) > 0: current_ts = item["timestamp"].item() query_timestamps = self._get_query_timestamps(current_ts, query_indices) video_frames = self._query_videos(query_timestamps, ep_idx) - item = {**video_frames, **item} + audio_chunks = self._query_audio(query_timestamps, DEFAULT_AUDIO_CHUNK_DURATION, ep_idx) + item = {**item, **video_frames, **audio_chunks} if self.image_transforms is not None: image_keys = self.meta.camera_keys @@ -1113,6 +1221,10 @@ class LeRobotDataset(torch.utils.data.Dataset): ) return self.root / fpath + def _get_raw_audio_file_path(self, episode_index: int, audio_key: str) -> Path: + fpath = DEFAULT_RAW_AUDIO_PATH.format(audio_key=audio_key, episode_index=episode_index) + return self.root / fpath + def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path: return self._get_image_file_path(episode_index, image_key, frame_index=0).parent @@ -1211,7 +1323,7 @@ class LeRobotDataset(torch.utils.data.Dataset): for key, ft in self.features.items(): # index, episode_index, task_index are already processed above, and image and video # are processed separately by storing image path and frame info as meta data - if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]: + if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video", "audio"]: continue episode_buffer[key] = np.stack(episode_buffer[key]) @@ -1221,9 +1333,10 @@ class LeRobotDataset(torch.utils.data.Dataset): ep_metadata = self._save_episode_data(episode_buffer) has_video_keys = len(self.meta.video_keys) > 0 + has_audio_keys = len(self.meta.audio_keys) > 0 use_batched_encoding = self.batch_encoding_size > 1 - if has_video_keys and not use_batched_encoding: + if (has_video_keys or has_audio_keys) and not use_batched_encoding: num_cameras = len(self.meta.video_keys) if parallel_encoding and num_cameras > 1: # TODO(Steven): Ideally we would like to control the number of threads per encoding such that: @@ -1260,21 +1373,30 @@ class LeRobotDataset(torch.utils.data.Dataset): for video_key in self.meta.video_keys: ep_metadata.update(self._save_episode_video(video_key, episode_index)) + #TODO(Caroline): add parallel encoding for audio as well + for audio_key in self.meta.audio_keys: + ep_metadata.update(self._save_episode_audio(audio_key, episode_index)) + # `meta.save_episode` need to be executed after encoding the videos self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata) - if has_video_keys and use_batched_encoding: + if (has_video_keys or has_audio_keys) and use_batched_encoding: # Check if we should trigger batch encoding self.episodes_since_last_encoding += 1 if self.episodes_since_last_encoding == self.batch_encoding_size: start_ep = self.num_episodes - self.batch_encoding_size end_ep = self.num_episodes - self._batch_save_episode_video(start_ep, end_ep) + if has_video_keys: + self._batch_save_episode_video(start_ep, end_ep) + if has_audio_keys: + self._batch_save_episode_audio(start_ep, end_ep) self.episodes_since_last_encoding = 0 if not episode_data: # Reset episode buffer and clean up temporary images (if not already deleted during video encoding) - self.clear_episode_buffer(delete_images=len(self.meta.image_keys) > 0) + self.clear_episode_buffer( + delete_images=len(self.meta.image_keys) > 0, delete_audio=len(self.meta.audio_keys) > 0 + ) def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None) -> None: """ @@ -1325,7 +1447,70 @@ class LeRobotDataset(torch.utils.data.Dataset): dtype_backend="pyarrow" ) # allows NaN values along with integers + # Save the current episode's audio metadata to the dataframe + audio_ep_metadata = {} + for audio_key in self.meta.audio_keys: + audio_ep_metadata.update(self._save_episode_audio(audio_key, ep_idx)) + audio_ep_metadata.pop("episode_index") + audio_ep_df = pd.DataFrame(audio_ep_metadata, index=[ep_idx]).convert_dtypes( + dtype_backend="pyarrow" + ) # allows NaN values along with integers + episode_df = episode_df.combine_first(video_ep_df) + episode_df = episode_df.combine_first(audio_ep_df) + episode_df.to_parquet(episode_df_path) + self.meta.episodes = load_episodes(self.root) + + def _batch_save_episode_audio(self, start_episode: int, end_episode: int | None = None) -> None: + """ + Batch save audio for multiple episodes. + + Args: + start_episode: Starting episode index (inclusive) + end_episode: Ending episode index (exclusive). If None, encodes all episodes from start_episode to the current episode. + """ + if end_episode is None: + end_episode = self.num_episodes + + logging.info( + f"Batch encoding {self.batch_encoding_size} audio for episodes {start_episode} to {end_episode - 1}" + ) + + chunk_idx = self.meta.episodes[start_episode]["data/chunk_index"] + file_idx = self.meta.episodes[start_episode]["data/file_index"] + episode_df_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + episode_df = pd.read_parquet(episode_df_path) + + for ep_idx in range(start_episode, end_episode): + logging.info(f"Encoding audio for episode {ep_idx}") + + if ( + self.meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx + or self.meta.episodes[ep_idx]["data/file_index"] != file_idx + ): + # The current episode is in a new chunk or file. + # Save previous episode dataframe and update the Hugging Face dataset by reloading it. + episode_df.to_parquet(episode_df_path) + self.meta.episodes = load_episodes(self.root) + + # Load new episode dataframe + chunk_idx = self.meta.episodes[ep_idx]["data/chunk_index"] + file_idx = self.meta.episodes[ep_idx]["data/file_index"] + episode_df_path = self.root / DEFAULT_EPISODES_PATH.format( + chunk_index=chunk_idx, file_index=file_idx + ) + episode_df = pd.read_parquet(episode_df_path) + + # Save the current episode's video metadata to the dataframe + audio_ep_metadata = {} + for audio_key in self.meta.audio_keys: + audio_ep_metadata.update(self._save_episode_audio(audio_key, ep_idx)) + audio_ep_metadata.pop("episode_index") + audio_ep_df = pd.DataFrame(audio_ep_metadata, index=[ep_idx]).convert_dtypes( + dtype_backend="pyarrow" + ) # allows NaN values along with integers + + episode_df = episode_df.combine_first(audio_ep_df) episode_df.to_parquet(episode_df_path) self.meta.episodes = load_episodes(self.root) @@ -1436,7 +1621,7 @@ class LeRobotDataset(torch.utils.data.Dataset): ep_path = temp_path ep_size_in_mb = get_file_size_in_mb(ep_path) - ep_duration_in_s = get_video_duration_in_s(ep_path) + ep_duration_in_s = get_media_duration_in_s(ep_path, media_type="video") if ( episode_index == 0 @@ -1482,7 +1667,7 @@ class LeRobotDataset(torch.utils.data.Dataset): latest_duration_in_s = 0.0 else: # Update latest video file - concatenate_video_files( + concatenate_media_files( [latest_path, ep_path], latest_path, ) @@ -1504,7 +1689,79 @@ class LeRobotDataset(torch.utils.data.Dataset): } return metadata - def clear_episode_buffer(self, delete_images: bool = True) -> None: + def _save_episode_audio(self, audio_key: str, episode_index: int) -> dict: + # Encode episode audio into a temporary audio file + ep_path = self._encode_temporary_episode_audio(audio_key, episode_index) + ep_size_in_mb = get_file_size_in_mb(ep_path) + ep_duration_in_s = get_audio_duration_in_s(ep_path) + + if ( + episode_index == 0 + or self.meta.latest_episode is None + or f"audio/{audio_key}/chunk_index" not in self.meta.latest_episode + ): + # Initialize indices for a new dataset made of the first episode data + chunk_idx, file_idx = 0, 0 + if self.meta.episodes is not None and len(self.meta.episodes) > 0: + # It means we are resuming recording, so we need to load the latest episode + # Update the indices to avoid overwriting the latest episode + old_chunk_idx = self.meta.episodes[-1][f"audio/{audio_key}/chunk_index"] + old_file_idx = self.meta.episodes[-1][f"audio/{audio_key}/file_index"] + chunk_idx, file_idx = update_chunk_file_indices( + old_chunk_idx, old_file_idx, self.meta.chunks_size + ) + latest_duration_in_s = 0.0 + new_path = self.root / self.meta.audio_path.format( + audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx + ) + new_path.parent.mkdir(parents=True, exist_ok=True) + shutil.move(str(ep_path), str(new_path)) + else: + # Retrieve information from the latest updated audio file using latest_episode + latest_ep = self.meta.latest_episode + chunk_idx = latest_ep[f"audio/{audio_key}/chunk_index"][0] + file_idx = latest_ep[f"audio/{audio_key}/file_index"][0] + + latest_path = self.root / self.meta.audio_path.format( + audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx + ) + latest_size_in_mb = get_file_size_in_mb(latest_path) + latest_duration_in_s = latest_ep[f"audio/{audio_key}/to_timestamp"][0] + + if latest_size_in_mb + ep_size_in_mb >= self.meta.audio_files_size_in_mb: + # Move temporary episode audio to a new audio file in the dataset + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size) + new_path = self.root / self.meta.audio_path.format( + audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx + ) + new_path.parent.mkdir(parents=True, exist_ok=True) + shutil.move(str(ep_path), str(new_path)) + latest_duration_in_s = 0.0 + else: + # Update latest audio file + concatenate_media_files( + [latest_path, ep_path], + latest_path, + ) + + # Remove temporary directory + shutil.rmtree(str(ep_path.parent)) + + # Update audio info (only needed when first episode is encoded since it reads from episode 0) + if episode_index == 0: + self.meta.update_audio_info(audio_key) + write_info(self.meta.info, self.meta.root) # ensure audio info always written properly + + metadata = { + "episode_index": episode_index, + f"audio/{audio_key}/chunk_index": chunk_idx, + f"audio/{audio_key}/file_index": file_idx, + f"audio/{audio_key}/from_timestamp": latest_duration_in_s, + f"audio/{audio_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s, + } + return metadata + + def clear_episode_buffer(self, delete_images: bool = True, delete_audio: bool = True) -> None: # Clean up image files for the current episode buffer if delete_images: # Wait for the async image writer to finish @@ -1518,6 +1775,16 @@ class LeRobotDataset(torch.utils.data.Dataset): if img_dir.is_dir(): shutil.rmtree(img_dir) + # Clean up audio files for the current episode buffer + if delete_audio: + episode_index = self.episode_buffer["episode_index"] + if isinstance(episode_index, np.ndarray): + episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0] + for microphone_key in self.meta.microphone_keys: + audio_file = self.root / self.meta.get_audio_file_path(episode_index, microphone_key) + if audio_file.is_file(): + audio_file.unlink() + # Reset the buffer self.episode_buffer = self.create_episode_buffer() @@ -1554,6 +1821,18 @@ class LeRobotDataset(torch.utils.data.Dataset): """ return _encode_video_worker(video_key, episode_index, self.root, self.fps, self.vcodec) + def _encode_temporary_episode_audio(self, audio_key: str, episode_index: int) -> Path: + """ + Use ffmpeg to convert raw audio files into m4a audio files. + Note: `encode_episode_audio` is a blocking call. Making it asynchronous shouldn't speedup encoding, + since audio encoding with ffmpeg is already using multithreading. + """ + temp_path = Path(tempfile.mkdtemp(dir=self.root)) / f"{audio_key}_{episode_index:03d}.m4a" + raw_audio_file = self._get_raw_audio_file_path(episode_index, audio_key) + encode_audio(raw_audio_file, temp_path, overwrite=True) + raw_audio_file.unlink() + return temp_path + @classmethod def create( cls, @@ -1567,6 +1846,7 @@ class LeRobotDataset(torch.utils.data.Dataset): image_writer_processes: int = 0, image_writer_threads: int = 0, video_backend: str | None = None, + audio_backend: str | None = None, batch_encoding_size: int = 1, vcodec: str = "libsvtav1", ) -> "LeRobotDataset": @@ -1611,6 +1891,9 @@ class LeRobotDataset(torch.utils.data.Dataset): obj._lazy_loading = False obj._recorded_frames = 0 obj._writer_closed_for_reading = False + obj.audio_backend = ( + audio_backend if audio_backend is not None else "ffmpeg" + ) # Waiting for torchcodec release #TODO(CarolinePascal) return obj @@ -1631,6 +1914,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): tolerances_s: dict | None = None, download_videos: bool = True, video_backend: str | None = None, + audio_backend: str | None = None, ): super().__init__() self.repo_ids = repo_ids @@ -1648,6 +1932,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): tolerance_s=self.tolerances_s[repo_id], download_videos=download_videos, video_backend=video_backend, + audio_backend=audio_backend, ) for repo_id in repo_ids ] diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index ed678af6e..8bd34809f 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -50,6 +50,7 @@ from lerobot.utils.utils import SuppressProgressBars, is_valid_numpy_dtype_strin DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200 # Max size per file +DEFAULT_AUDIO_FILE_SIZE_IN_MB = 100 # Max size per file INFO_PATH = "meta/info.json" STATS_PATH = "meta/stats.json" @@ -57,13 +58,18 @@ STATS_PATH = "meta/stats.json" EPISODES_DIR = "meta/episodes" DATA_DIR = "data" VIDEO_DIR = "videos" +AUDIO_DIR = "audio" CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}" DEFAULT_TASKS_PATH = "meta/tasks.parquet" DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet" DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet" DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4" +DEFAULT_AUDIO_PATH = AUDIO_DIR + "/{audio_key}/" + CHUNK_FILE_PATTERN + ".m4a" DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png" +DEFAULT_RAW_AUDIO_PATH = "raw_audio/{audio_key}/episode_{episode_index:06d}.wav" + +DEFAULT_AUDIO_CHUNK_DURATION = 0.5 # seconds LEGACY_EPISODES_PATH = "meta/episodes.jsonl" LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl" @@ -576,7 +582,7 @@ def get_hf_features_from_features(features: dict) -> datasets.Features: """ hf_features = {} for key, ft in features.items(): - if ft["dtype"] == "video": + if ft["dtype"] == "video" or ft["dtype"] == "audio": continue elif ft["dtype"] == "image": hf_features[key] = datasets.Image() @@ -802,6 +808,7 @@ def create_empty_dataset_info( chunks_size: int | None = None, data_files_size_in_mb: int | None = None, video_files_size_in_mb: int | None = None, + audio_files_size_in_mb: int | None = None, ) -> dict: """Create a template dictionary for a new dataset's `info.json`. @@ -811,6 +818,10 @@ def create_empty_dataset_info( features (dict): The LeRobot features dictionary for the dataset. use_videos (bool): Whether the dataset will store videos. robot_type (str | None): The type of robot used, if any. + chunks_size (int | None): The maximum number of files per chunk directory. + data_files_size_in_mb (int | None): The maximum size for data files in MB. + video_files_size_in_mb (int | None): The maximum size for video files in MB. + audio_files_size_in_mb (int | None): The maximum size for audio files in MB. Returns: dict: A dictionary with the initial dataset metadata. @@ -824,10 +835,12 @@ def create_empty_dataset_info( "chunks_size": chunks_size or DEFAULT_CHUNK_SIZE, "data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB, "video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB, + "audio_files_size_in_mb": audio_files_size_in_mb or DEFAULT_AUDIO_FILE_SIZE_IN_MB, "fps": fps, "splits": {}, "data_path": DEFAULT_DATA_PATH, "video_path": DEFAULT_VIDEO_PATH if use_videos else None, + "audio_path": DEFAULT_AUDIO_PATH, "features": features, } diff --git a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py index 74be6bfa4..f5df0cdec 100644 --- a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py +++ b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py @@ -79,7 +79,7 @@ from lerobot.datasets.utils import ( write_stats, write_tasks, ) -from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s +from lerobot.datasets.video_utils import concatenate_media_files, get_media_duration_in_s from lerobot.utils.constants import HF_LEROBOT_HOME from lerobot.utils.utils import init_logging @@ -311,12 +311,12 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f for ep_path in tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}"): ep_size_in_mb = get_file_size_in_mb(ep_path) - ep_duration_in_s = get_video_duration_in_s(ep_path) + ep_duration_in_s = get_media_duration_in_s(ep_path, media_type="video") # Check if adding this episode would exceed the limit if size_in_mb + ep_size_in_mb >= video_file_size_in_mb and len(paths_to_cat) > 0: # Size limit would be exceeded, save current accumulation WITHOUT this episode - concatenate_video_files( + concatenate_media_files( paths_to_cat, new_root / DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx), @@ -352,7 +352,7 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f # Write remaining videos if any if paths_to_cat: - concatenate_video_files( + concatenate_media_files( paths_to_cat, new_root / DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx), diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index 84ce13772..77d8932e0 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -397,42 +397,42 @@ def encode_video_frames( raise OSError(f"Video encoding did not work. File not found: {video_path}.") -def concatenate_video_files( - input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True +def concatenate_media_files( + input_media_paths: list[Path | str], output_media_path: Path, overwrite: bool = True ): """ - Concatenate multiple video files into a single video file using pyav. + Concatenate multiple media files (video & audio) into a single media file using pyav. - This function takes a list of video input file paths and concatenates them into a single - output video file. It uses ffmpeg's concat demuxer with stream copy mode for fast + This function takes a list of input media file paths and concatenates them into a single + output media file. It uses ffmpeg's concat demuxer with stream copy mode for fast concatenation without re-encoding. Args: - input_video_paths: Ordered list of input video file paths to concatenate. - output_video_path: Path to the output video file. - overwrite: Whether to overwrite the output video file if it already exists. Default is True. + input_media_paths: Ordered list of input media file paths to concatenate. + output_media_path: Path to the output media file. + overwrite: Whether to overwrite the output media file if it already exists. Default is True. Note: - - Creates a temporary directory for intermediate files that is cleaned up after use. - - Uses ffmpeg's concat demuxer which requires all input videos to have the same + - Creates a temporary .ffconcat file and container audio/video file that are cleaned up after use. + - Uses ffmpeg's concat demuxer which requires all input media files to have the same codec, resolution, and frame rate for proper concatenation. """ - output_video_path = Path(output_video_path) + output_media_path = Path(output_media_path) - if output_video_path.exists() and not overwrite: - logging.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.") + if output_media_path.exists() and not overwrite: + logging.warning(f"Media file already exists: {output_media_path}. Skipping concatenation.") return - output_video_path.parent.mkdir(parents=True, exist_ok=True) + output_media_path.parent.mkdir(parents=True, exist_ok=True) - if len(input_video_paths) == 0: - raise FileNotFoundError("No input video paths provided.") + if len(input_media_paths) == 0: + raise FileNotFoundError("No input media paths provided.") - # Create a temporary .ffconcat file to list the input video paths + # Create a temporary .ffconcat file to list the input media paths with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file: tmp_concatenate_file.write("ffconcat version 1.0\n") - for input_path in input_video_paths: + for input_path in input_media_paths: tmp_concatenate_file.write(f"file '{str(input_path.resolve())}'\n") tmp_concatenate_file.flush() tmp_concatenate_path = tmp_concatenate_file.name @@ -442,11 +442,11 @@ def concatenate_video_files( tmp_concatenate_path, mode="r", format="concat", options={"safe": "0"} ) # safe = 0 allows absolute paths as well as relative paths - with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file: - tmp_output_video_path = tmp_named_file.name + with tempfile.NamedTemporaryFile(suffix=output_media_path.suffix, delete=False) as tmp_named_file: + tmp_output_media_path = tmp_named_file.name output_container = av.open( - tmp_output_video_path, mode="w", options={"movflags": "faststart"} + tmp_output_media_path, mode="w", options={"movflags": "faststart"} ) # faststart is to move the metadata to the beginning of the file to speed up loading # Replicate input streams in output container @@ -476,7 +476,7 @@ def concatenate_video_files( input_container.close() output_container.close() - shutil.move(tmp_output_video_path, output_video_path) + shutil.move(tmp_output_media_path, output_media_path) Path(tmp_concatenate_path).unlink() @@ -512,38 +512,6 @@ with warnings.catch_warnings(): register_feature(VideoFrame, "VideoFrame") -def get_audio_info(video_path: Path | str) -> dict: - # Set logging level - logging.getLogger("libav").setLevel(av.logging.ERROR) - - # Getting audio stream information - audio_info = {} - with av.open(str(video_path), "r") as audio_file: - try: - audio_stream = audio_file.streams.audio[0] - except IndexError: - # Reset logging level - av.logging.restore_default_callback() - return {"has_audio": False} - - audio_info["audio.channels"] = audio_stream.channels - audio_info["audio.codec"] = audio_stream.codec.canonical_name - # In an ideal loseless case : bit depth x sample rate x channels = bit rate. - # In an actual compressed case, the bit rate is set according to the compression level : the lower the bit rate, the more compression is applied. - audio_info["audio.bit_rate"] = audio_stream.bit_rate - audio_info["audio.sample_rate"] = audio_stream.sample_rate # Number of samples per second - # In an ideal loseless case : fixed number of bits per sample. - # In an actual compressed case : variable number of bits per sample (often reduced to match a given depth rate). - audio_info["audio.bit_depth"] = audio_stream.format.bits - audio_info["audio.channel_layout"] = audio_stream.layout.name - audio_info["has_audio"] = True - - # Reset logging level - av.logging.restore_default_callback() - - return audio_info - - def get_video_info(video_path: Path | str) -> dict: # Set logging level logging.getLogger("libav").setLevel(av.logging.ERROR) @@ -573,9 +541,6 @@ def get_video_info(video_path: Path | str) -> dict: # Reset logging level av.logging.restore_default_callback() - # Adding audio stream information - video_info.update(**get_audio_info(video_path)) - return video_info @@ -590,22 +555,22 @@ def get_video_pixel_channels(pix_fmt: str) -> int: raise ValueError("Unknown format") -def get_video_duration_in_s(video_path: Path | str) -> float: +def get_media_duration_in_s(media_path: Path | str, media_type: str = "video") -> float: """ - Get the duration of a video file in seconds using PyAV. + Get the duration of a media file (video & audio) in seconds using PyAV. Args: - video_path: Path to the video file. + media_path: Path to the media file. Returns: - Duration of the video in seconds. + Duration of the media file in seconds. """ - with av.open(str(video_path)) as container: - # Get the first video stream - video_stream = container.streams.video[0] + with av.open(str(media_path)) as container: + # Get the first stream + stream = container.streams.video[0] if media_type == "video" else container.streams.audio[0] # Calculate duration: stream.duration * stream.time_base gives duration in seconds - if video_stream.duration is not None: - duration = float(video_stream.duration * video_stream.time_base) + if stream.duration is not None: + duration = float(stream.duration * stream.time_base) else: # Fallback to container duration if stream duration is not available duration = float(container.duration / av.time_base) @@ -614,12 +579,12 @@ def get_video_duration_in_s(video_path: Path | str) -> float: class VideoEncodingManager: """ - Context manager that ensures proper video encoding and data cleanup even if exceptions occur. + Context manager that ensures proper video and audio encoding and data cleanup even if exceptions occur. This manager handles: - Batch encoding for any remaining episodes when recording interrupted - - Cleaning up temporary image files from interrupted episodes - - Removing empty image directories + - Cleaning up temporary image and audio files from interrupted episodes + - Removing empty image and audio directories Args: dataset: The LeRobotDataset instance @@ -646,6 +611,7 @@ class VideoEncodingManager: f"from episode {start_ep} to {end_ep - 1}" ) self.dataset._batch_save_episode_video(start_ep, end_ep) + self.dataset._batch_save_episode_audio(start_ep, end_ep) # Finalize the dataset to properly close all writers self.dataset.finalize() @@ -662,6 +628,15 @@ class VideoEncodingManager: f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}" ) shutil.rmtree(img_dir) + for key in self.dataset.meta.audio_keys: + audio_file = self.dataset._get_raw_audio_file_path( + episode_index=interrupted_episode_index, audio_key=key + ) + if audio_file.exists(): + logging.debug( + f"Cleaning up interrupted episode audio for episode {interrupted_episode_index}, microphone {key}" + ) + audio_file.unlink() # Clean up any remaining images directory if it's empty img_dir = self.dataset.root / "images" @@ -675,4 +650,16 @@ class VideoEncodingManager: else: logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files") + # Clean up any remaining audio directory if it's empty + audio_dir = self.dataset.root / "raw_audio" + # Check for any remaining WAV files + wav_files = list(audio_dir.rglob("*.wav")) + if len(wav_files) == 0: + # Only remove the raw_audio directory if no WAV files remain + if audio_dir.exists(): + shutil.rmtree(audio_dir) + logging.debug("Cleaned up empty audio directory") + else: + logging.debug(f"Audio directory is not empty, containing {len(wav_files)} WAV files") + return False # Don't suppress the original exception diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py index c33fdcb72..e98e626e2 100644 --- a/tests/fixtures/dataset_factories.py +++ b/tests/fixtures/dataset_factories.py @@ -28,6 +28,7 @@ from datasets import Dataset from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata from lerobot.datasets.utils import ( + DEFAULT_AUDIO_PATH, DEFAULT_CHUNK_SIZE, DEFAULT_DATA_FILE_SIZE_IN_MB, DEFAULT_DATA_PATH, @@ -162,6 +163,7 @@ def info_factory(features_factory): video_files_size_in_mb: float = DEFAULT_VIDEO_FILE_SIZE_IN_MB, data_path: str = DEFAULT_DATA_PATH, video_path: str = DEFAULT_VIDEO_PATH, + audio_path: str = DEFAULT_AUDIO_PATH, motor_features: dict = DUMMY_MOTOR_FEATURES, camera_features: dict = DUMMY_CAMERA_FEATURES, use_videos: bool = True, @@ -181,6 +183,7 @@ def info_factory(features_factory): "splits": {}, "data_path": data_path, "video_path": video_path if use_videos else None, + "audio_path": audio_path, "features": features, }