Adding missing features for audio frames verification and stats

This commit is contained in:
CarolinePascal
2025-04-04 19:48:57 +02:00
parent cdd3a859ef
commit 00536c6c5b
5 changed files with 47 additions and 5 deletions
+3 -3
View File
@@ -78,9 +78,9 @@ def decode_audio_torchaudio(
# TODO(CarolinePascal) : sort timestamps ?
reader.add_basic_audio_stream(
frames_per_chunk = int(ceil(duration * audio_sampling_rate)), #Too much is better than not enough
buffer_chunk_size = -1, #No dropping frames
format = "fltp", #Format as float32
frames_per_chunk=int(ceil(duration * audio_sampling_rate)), # Too much is better than not enough
buffer_chunk_size=-1, # No dropping frames
format="fltp", # Format as float32
)
audio_chunks = []
+12 -1
View File
@@ -15,7 +15,7 @@
# limitations under the License.
import numpy as np
from lerobot.datasets.utils import load_image_as_numpy
from lerobot.datasets.utils import load_audio, load_image_as_numpy
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
@@ -245,6 +245,13 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
return images
def sample_audio(audio_path: str) -> np.ndarray:
data = load_audio(audio_path)
sampled_indices = sample_indices(len(data))
return data[sampled_indices]
def _reshape_stats_by_axis(
stats: dict[str, np.ndarray],
axis: int | tuple[int, ...] | None,
@@ -512,6 +519,10 @@ def compute_episode_stats(
ep_ft_array = sample_images(data)
axes_to_reduce = (0, 2, 3)
keepdims = True
elif features[key]["dtype"] == "audio":
ep_ft_array = sample_audio(data[0])
axes_to_reduce = 0
keepdims = True
else:
ep_ft_array = data
axes_to_reduce = 0
+6
View File
@@ -1278,6 +1278,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
compress_level = 1 if self.features[key]["dtype"] == "video" else 6
self._save_image(frame[key], img_path, compress_level)
self.episode_buffer[key].append(str(img_path))
elif self.features[key]["dtype"] == "audio":
if frame_index == 0:
audio_path = self._get_raw_audio_file_path(
episode_index=self.episode_buffer["episode_index"], audio_key=key
)
self.episode_buffer[key].append(str(audio_path))
else:
self.episode_buffer[key].append(frame[key])
+25
View File
@@ -36,6 +36,7 @@ from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage
from soundfile import read
from torchvision import transforms
from lerobot.configs.types import FeatureType, PolicyFeature
@@ -414,6 +415,11 @@ def load_image_as_numpy(
return img_array
def load_audio(fpath: str | Path) -> np.ndarray:
audio_data, _ = read(fpath, dtype="float32")
return audio_data
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
"""Convert a batch from a Hugging Face dataset to torch tensors.
@@ -1064,6 +1070,8 @@ def validate_feature_dtype_and_shape(
return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
elif expected_dtype in ["image", "video"]:
return validate_feature_image_or_video(name, expected_shape, value)
elif expected_dtype == "audio":
return validate_feature_audio(name, expected_shape, value)
elif expected_dtype == "string":
return validate_feature_string(name, value)
else:
@@ -1130,6 +1138,23 @@ def validate_feature_image_or_video(
return error_message
def validate_feature_audio(name: str, expected_shape: list[str], value: np.ndarray):
error_message = ""
if isinstance(value, np.ndarray):
actual_shape = value.shape
c = expected_shape
if len(actual_shape) != 2 or (
actual_shape[-1] != c[-1] and actual_shape[0] != c[0]
): # The number of frames might be different
error_message += (
f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c,)}'.\n"
)
else:
error_message += f"The feature '{name}' is expected to be of type 'np.ndarray', but type '{type(value)}' provided instead.\n"
return error_message
def validate_feature_string(name: str, value: str) -> str:
"""Validate a feature that is expected to be a string.
+1 -1
View File
@@ -236,7 +236,7 @@ class Microphone:
with self.read_queue.mutex:
self.read_queue.queue.clear()
# self.read_queue.all_tasks_done.notify_all()
audio_readings = np.array(audio_readings).reshape(-1, len(self.channels))
audio_readings = np.array(audio_readings, dtype=np.float32).reshape(-1, len(self.channels))
return audio_readings