mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 03:30:10 +00:00
test(LeRobotDataset): add missing test and support for audio frames addition
This commit is contained in:
@@ -24,6 +24,7 @@ import torch
|
||||
from huggingface_hub import HfApi
|
||||
from PIL import Image
|
||||
from safetensors.torch import load_file
|
||||
from soundfile import write
|
||||
|
||||
import lerobot
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
@@ -77,6 +78,21 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory):
|
||||
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
|
||||
|
||||
@pytest.fixture
def audio_dataset_le_kiwi(tmp_path, empty_lerobot_dataset_factory):
    """Return an empty dataset with one "audio" feature, built with robot_type="lekiwi".

    The dataset is rooted at ``tmp_path / "test"``. Its only feature is an
    audio stream declared with shape ``(1, DUMMY_AUDIO_CHANNELS)`` and a
    sample rate of ``DEFAULT_SAMPLE_RATE``.
    """
    audio_feature = {
        "dtype": "audio",
        "shape": (1, DUMMY_AUDIO_CHANNELS),
        "names": ["channels"],
        "info": {"sample_rate": DEFAULT_SAMPLE_RATE},
    }
    return empty_lerobot_dataset_factory(
        root=tmp_path / "test",
        features={"audio": audio_feature},
        robot_type="lekiwi",
    )
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def audio_dataset(tmp_path, empty_lerobot_dataset_factory):
|
||||
features = {
|
||||
@@ -86,6 +102,7 @@ def audio_dataset(tmp_path, empty_lerobot_dataset_factory):
|
||||
"names": [
|
||||
"channels",
|
||||
],
|
||||
"info": {"sample_rate": DEFAULT_SAMPLE_RATE},
|
||||
}
|
||||
}
|
||||
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
@@ -432,8 +449,8 @@ def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory):
|
||||
)
|
||||
|
||||
|
||||
def test_add_frame_audio(audio_dataset):
|
||||
dataset = audio_dataset
|
||||
def test_add_frame_audio_array(audio_dataset_le_kiwi):
|
||||
dataset = audio_dataset_le_kiwi
|
||||
dataset.add_frame(
|
||||
{
|
||||
"audio": np.random.rand(
|
||||
@@ -448,7 +465,59 @@ def test_add_frame_audio(audio_dataset):
|
||||
(
|
||||
DUMMY_AUDIO_CHANNELS,
|
||||
int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE),
|
||||
) # Match pytorch channel-first format
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_add_frame_audio_array_wrong_shape(audio_dataset_le_kiwi):
    """Adding a 3-D audio array (extra trailing axis of size 99) must raise ValueError."""
    n_samples = int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE)
    # One axis too many compared with the (samples, channels) arrays used elsewhere.
    bad_shape = (n_samples, DUMMY_AUDIO_CHANNELS, 99)
    with pytest.raises(ValueError):
        audio_dataset_le_kiwi.add_frame({"audio": np.random.rand(*bad_shape)}, task="Dummy task")
|
||||
|
||||
|
||||
def test_add_frame_audio_array_wrong_channels_number(audio_dataset_le_kiwi):
    """Adding a (samples, 99) audio array must raise ValueError."""
    n_samples = int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE)
    # 99 channels instead of the DUMMY_AUDIO_CHANNELS the fixture declares.
    bad_shape = (n_samples, 99)
    with pytest.raises(ValueError):
        audio_dataset_le_kiwi.add_frame({"audio": np.random.rand(*bad_shape)}, task="Dummy task")
|
||||
|
||||
|
||||
def test_add_frame_audio_file(audio_dataset):
    """Add one audio frame, write the raw audio file by hand, save, and check the loaded shape.

    After ``save_episode()``, ``dataset[0]["audio"]`` is expected to have shape
    ``(DUMMY_AUDIO_CHANNELS, n_samples)``, where ``n_samples`` is the chunk
    duration multiplied by the sample rate.
    """
    dataset = audio_dataset
    n_samples = int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE)

    dataset.add_frame(
        {"audio": np.random.rand(n_samples, DUMMY_AUDIO_CHANNELS)},
        task="Dummy task",
    )

    # Create the audio file that should be created in the background by the Microphone class
    for audio_key in dataset.meta.audio_keys:
        raw_path = dataset._get_raw_audio_file_path(0, audio_key)
        raw_path.parent.mkdir(parents=True, exist_ok=True)
        samples = np.random.rand(n_samples, DUMMY_AUDIO_CHANNELS)
        write(raw_path, samples, DEFAULT_SAMPLE_RATE)

    dataset.save_episode()

    assert dataset[0]["audio"].shape == torch.Size((DUMMY_AUDIO_CHANNELS, n_samples))
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user