test(LeRobotDataset): add missing test and support for audio frames addition

This commit is contained in:
CarolinePascal
2025-08-07 11:35:51 +02:00
parent 0232879245
commit ef8f40c21b
2 changed files with 73 additions and 4 deletions
+72 -3
View File
@@ -24,6 +24,7 @@ import torch
from huggingface_hub import HfApi
from PIL import Image
from safetensors.torch import load_file
from soundfile import write
import lerobot
from lerobot.configs.default import DatasetConfig
@@ -77,6 +78,21 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory):
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
@pytest.fixture
def audio_dataset_le_kiwi(tmp_path, empty_lerobot_dataset_factory):
features = {
"audio": {
"dtype": "audio",
"shape": (1, DUMMY_AUDIO_CHANNELS),
"names": [
"channels",
],
"info": {"sample_rate": DEFAULT_SAMPLE_RATE},
}
}
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, robot_type="lekiwi")
@pytest.fixture
def audio_dataset(tmp_path, empty_lerobot_dataset_factory):
features = {
@@ -86,6 +102,7 @@ def audio_dataset(tmp_path, empty_lerobot_dataset_factory):
"names": [
"channels",
],
"info": {"sample_rate": DEFAULT_SAMPLE_RATE},
}
}
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
@@ -432,8 +449,8 @@ def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory):
)
def test_add_frame_audio(audio_dataset):
dataset = audio_dataset
def test_add_frame_audio_array(audio_dataset_le_kiwi):
dataset = audio_dataset_le_kiwi
dataset.add_frame(
{
"audio": np.random.rand(
@@ -448,7 +465,59 @@ def test_add_frame_audio(audio_dataset):
(
DUMMY_AUDIO_CHANNELS,
int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE),
) # Match pytorch channel-first format
)
)
def test_add_frame_audio_array_wrong_shape(audio_dataset_le_kiwi):
dataset = audio_dataset_le_kiwi
with pytest.raises(ValueError):
dataset.add_frame(
{
"audio": np.random.rand(
int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), DUMMY_AUDIO_CHANNELS, 99
)
},
task="Dummy task",
)
def test_add_frame_audio_array_wrong_channels_number(audio_dataset_le_kiwi):
dataset = audio_dataset_le_kiwi
with pytest.raises(ValueError):
dataset.add_frame(
{"audio": np.random.rand(int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), 99)},
task="Dummy task",
)
def test_add_frame_audio_file(audio_dataset):
dataset = audio_dataset
dataset.add_frame(
{
"audio": np.random.rand(
int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), DUMMY_AUDIO_CHANNELS
)
},
task="Dummy task",
)
# Create the audio file that should be created in the background by the Microphone class
for audio_key in dataset.meta.audio_keys:
fpath = dataset._get_raw_audio_file_path(0, audio_key)
fpath.parent.mkdir(parents=True, exist_ok=True)
write(
fpath,
np.random.rand(int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), DUMMY_AUDIO_CHANNELS),
DEFAULT_SAMPLE_RATE,
)
dataset.save_episode()
assert dataset[0]["audio"].shape == torch.Size(
(
DUMMY_AUDIO_CHANNELS,
int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE),
)
)