From e4dd00c8f5870cca22c4f34d6d240cce5c57e457 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Tue, 22 Apr 2025 17:06:10 +0200 Subject: [PATCH] fix(audio feature shape): fixing audio feature shape ordering (frames first, channels second) --- src/lerobot/datasets/utils.py | 6 ++---- tests/datasets/test_datasets.py | 2 +- tests/fixtures/constants.py | 4 ++-- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index ed2b77233..c75aef4ef 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -1143,11 +1143,9 @@ def validate_feature_audio(name: str, expected_shape: list[str], value: np.ndarr if isinstance(value, np.ndarray): actual_shape = value.shape c = expected_shape - if len(actual_shape) != 2 or ( - actual_shape[-1] != c[-1] and actual_shape[0] != c[0] - ): # The number of frames might be different + if len(actual_shape) != 2 or actual_shape[-1] != c[-1]: # The number of frames might be different error_message += ( - f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c,)}'.\n" + f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{c}'.\n" ) else: error_message += f"The feature '{name}' is expected to be of type 'np.ndarray', but type '{type(value)}' provided instead.\n" diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index e5ee05cee..a443c1899 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -82,7 +82,7 @@ def audio_dataset(tmp_path, empty_lerobot_dataset_factory): features = { "audio": { "dtype": "audio", - "shape": (DUMMY_AUDIO_CHANNELS,), + "shape": (1, DUMMY_AUDIO_CHANNELS), "names": [ "channels", ], diff --git a/tests/fixtures/constants.py b/tests/fixtures/constants.py index 017f5e54a..80388c12d 100644 --- a/tests/fixtures/constants.py +++ b/tests/fixtures/constants.py @@ -41,8 +41,8 @@ DUMMY_VIDEO_INFO = { "has_audio": False, } DUMMY_MICROPHONE_FEATURES = { - "laptop": {"dtype": "audio", "shape": (1,), "names": ["channels"], "info": None}, - "phone": {"dtype": "audio", "shape": (1,), "names": ["channels"], "info": None}, + "laptop": {"dtype": "audio", "shape": (1, 2), "names": ["channels"], "info": None}, + "phone": {"dtype": "audio", "shape": (1, 2), "names": ["channels"], "info": None}, } DEFAULT_SAMPLE_RATE = 48000 DUMMY_AUDIO_CHANNELS = 2