From 00536c6c5befae023797ca46b5719785043a5695 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Fri, 4 Apr 2025 19:48:57 +0200 Subject: [PATCH] Adding missing features for audio frames verification and stats --- src/lerobot/datasets/audio_utils.py | 6 +++--- src/lerobot/datasets/compute_stats.py | 13 ++++++++++++- src/lerobot/datasets/lerobot_dataset.py | 6 ++++++ src/lerobot/datasets/utils.py | 25 +++++++++++++++++++++++++ src/lerobot/microphones/microphone.py | 2 +- 5 files changed, 47 insertions(+), 5 deletions(-) diff --git a/src/lerobot/datasets/audio_utils.py b/src/lerobot/datasets/audio_utils.py index 69101304b..5cc0b83f8 100644 --- a/src/lerobot/datasets/audio_utils.py +++ b/src/lerobot/datasets/audio_utils.py @@ -78,9 +78,9 @@ def decode_audio_torchaudio( # TODO(CarolinePascal) : sort timestamps ? reader.add_basic_audio_stream( - frames_per_chunk = int(ceil(duration * audio_sampling_rate)), #Too much is better than not enough - buffer_chunk_size = -1, #No dropping frames - format = "fltp", #Format as float32 + frames_per_chunk=int(ceil(duration * audio_sampling_rate)), # Too much is better than not enough + buffer_chunk_size=-1, # No dropping frames + format="fltp", # Format as float32 ) audio_chunks = [] diff --git a/src/lerobot/datasets/compute_stats.py b/src/lerobot/datasets/compute_stats.py index 61e174d5c..bba1206a5 100644 --- a/src/lerobot/datasets/compute_stats.py +++ b/src/lerobot/datasets/compute_stats.py @@ -15,7 +15,7 @@ # limitations under the License. import numpy as np -from lerobot.datasets.utils import load_image_as_numpy +from lerobot.datasets.utils import load_audio, load_image_as_numpy DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99] @@ -245,6 +245,13 @@ def sample_images(image_paths: list[str]) -> np.ndarray: return images +def sample_audio(audio_path: str) -> np.ndarray: + data = load_audio(audio_path) + sampled_indices = sample_indices(len(data)) + + return data[sampled_indices] + + def _reshape_stats_by_axis( stats: dict[str, np.ndarray], axis: int | tuple[int, ...] | None, @@ -512,6 +519,10 @@ def compute_episode_stats( ep_ft_array = sample_images(data) axes_to_reduce = (0, 2, 3) keepdims = True + elif features[key]["dtype"] == "audio": + ep_ft_array = sample_audio(data[0]) + axes_to_reduce = 0 + keepdims = True else: ep_ft_array = data axes_to_reduce = 0 diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 855398689..aa1fa6a0a 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -1278,6 +1278,12 @@ class LeRobotDataset(torch.utils.data.Dataset): compress_level = 1 if self.features[key]["dtype"] == "video" else 6 self._save_image(frame[key], img_path, compress_level) self.episode_buffer[key].append(str(img_path)) + elif self.features[key]["dtype"] == "audio": + if frame_index == 0: + audio_path = self._get_raw_audio_file_path( + episode_index=self.episode_buffer["episode_index"], audio_key=key + ) + self.episode_buffer[key].append(str(audio_path)) else: self.episode_buffer[key].append(frame[key]) diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 8bd34809f..8ace0de48 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -36,6 +36,7 @@ from datasets.table import embed_table_storage from huggingface_hub import DatasetCard, DatasetCardData, HfApi from huggingface_hub.errors import RevisionNotFoundError from PIL import Image as PILImage +from soundfile import read from torchvision import transforms from lerobot.configs.types import FeatureType, PolicyFeature @@ -414,6 +415,11 @@ def load_image_as_numpy( return img_array +def load_audio(fpath: str | Path) -> np.ndarray: + audio_data, _ = read(fpath, dtype="float32") + return audio_data + + def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]: """Convert a batch from a Hugging Face dataset to torch tensors. @@ -1064,6 +1070,8 @@ def validate_feature_dtype_and_shape( return validate_feature_numpy_array(name, expected_dtype, expected_shape, value) elif expected_dtype in ["image", "video"]: return validate_feature_image_or_video(name, expected_shape, value) + elif expected_dtype == "audio": + return validate_feature_audio(name, expected_shape, value) elif expected_dtype == "string": return validate_feature_string(name, value) else: @@ -1130,6 +1138,23 @@ def validate_feature_image_or_video( return error_message +def validate_feature_audio(name: str, expected_shape: list[str], value: np.ndarray): + error_message = "" + if isinstance(value, np.ndarray): + actual_shape = value.shape + c = expected_shape + if len(actual_shape) != 2 or ( + actual_shape[-1] != c[-1] and actual_shape[0] != c[0] + ): # The number of frames might be different + error_message += ( + f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c,)}'.\n" + ) + else: + error_message += f"The feature '{name}' is expected to be of type 'np.ndarray', but type '{type(value)}' provided instead.\n" + + return error_message + + def validate_feature_string(name: str, value: str) -> str: """Validate a feature that is expected to be a string. diff --git a/src/lerobot/microphones/microphone.py b/src/lerobot/microphones/microphone.py index e8c8045b2..c9eaf4b9a 100644 --- a/src/lerobot/microphones/microphone.py +++ b/src/lerobot/microphones/microphone.py @@ -236,7 +236,7 @@ class Microphone: with self.read_queue.mutex: self.read_queue.queue.clear() # self.read_queue.all_tasks_done.notify_all() - audio_readings = np.array(audio_readings).reshape(-1, len(self.channels)) + audio_readings = np.array(audio_readings, dtype=np.float32).reshape(-1, len(self.channels)) return audio_readings