From ef8f40c21b08e05a457bc057acba0e69c40f983e Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Thu, 7 Aug 2025 11:35:51 +0200 Subject: [PATCH] test(LeRobotDataset): add missing test and support for audio frames addition --- src/lerobot/datasets/utils.py | 2 +- tests/datasets/test_datasets.py | 75 +++++++++++++++++++++++++++++++-- 2 files changed, 73 insertions(+), 4 deletions(-) diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 8926e3963..953a544b7 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -690,7 +690,7 @@ def hw_to_dataset_features( "dtype": "audio", "shape": (features[1],), "names": ["channels"], - "sample_rate": features[0], + "info": {"sample_rate": features[0]}, } _validate_feature_names(features) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index ab25a21fe..4742df028 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -24,6 +24,7 @@ import torch from huggingface_hub import HfApi from PIL import Image from safetensors.torch import load_file +from soundfile import write import lerobot from lerobot.configs.default import DatasetConfig @@ -77,6 +78,21 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory): return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) +@pytest.fixture +def audio_dataset_le_kiwi(tmp_path, empty_lerobot_dataset_factory): + features = { + "audio": { + "dtype": "audio", + "shape": (1, DUMMY_AUDIO_CHANNELS), + "names": [ + "channels", + ], + "info": {"sample_rate": DEFAULT_SAMPLE_RATE}, + } + } + return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, robot_type="lekiwi") + + @pytest.fixture def audio_dataset(tmp_path, empty_lerobot_dataset_factory): features = { @@ -86,6 +102,7 @@ def audio_dataset(tmp_path, empty_lerobot_dataset_factory): "names": [ "channels", ], + "info": {"sample_rate": DEFAULT_SAMPLE_RATE}, } } return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) @@ -432,8 +449,8 @@ def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory): ) -def test_add_frame_audio(audio_dataset): - dataset = audio_dataset +def test_add_frame_audio_array(audio_dataset_le_kiwi): + dataset = audio_dataset_le_kiwi dataset.add_frame( { "audio": np.random.rand( @@ -448,7 +465,59 @@ def test_add_frame_audio(audio_dataset): ( DUMMY_AUDIO_CHANNELS, int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), - ) # Match pytorch channel-first format + ) + ) + + +def test_add_frame_audio_array_wrong_shape(audio_dataset_le_kiwi): + dataset = audio_dataset_le_kiwi + with pytest.raises(ValueError): + dataset.add_frame( + { + "audio": np.random.rand( + int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), DUMMY_AUDIO_CHANNELS, 99 + ) + }, + task="Dummy task", + ) + + +def test_add_frame_audio_array_wrong_channels_number(audio_dataset_le_kiwi): + dataset = audio_dataset_le_kiwi + with pytest.raises(ValueError): + dataset.add_frame( + {"audio": np.random.rand(int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), 99)}, + task="Dummy task", + ) + + +def test_add_frame_audio_file(audio_dataset): + dataset = audio_dataset + dataset.add_frame( + { + "audio": np.random.rand( + int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), DUMMY_AUDIO_CHANNELS + ) + }, + task="Dummy task", + ) + # Create the audio file that should be created in the background by the Microphone class + for audio_key in dataset.meta.audio_keys: + fpath = dataset._get_raw_audio_file_path(0, audio_key) + fpath.parent.mkdir(parents=True, exist_ok=True) + write( + fpath, + np.random.rand(int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), DUMMY_AUDIO_CHANNELS), + DEFAULT_SAMPLE_RATE, + ) + + dataset.save_episode() + + assert dataset[0]["audio"].shape == torch.Size( + ( + DUMMY_AUDIO_CHANNELS, + int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), + ) )