mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 13:09:43 +00:00
test(LeRobotDataset): add missing test and support for audio frames addition
This commit is contained in:
@@ -690,7 +690,7 @@ def hw_to_dataset_features(
|
|||||||
"dtype": "audio",
|
"dtype": "audio",
|
||||||
"shape": (features[1],),
|
"shape": (features[1],),
|
||||||
"names": ["channels"],
|
"names": ["channels"],
|
||||||
"sample_rate": features[0],
|
"info": {"sample_rate": features[0]},
|
||||||
}
|
}
|
||||||
|
|
||||||
_validate_feature_names(features)
|
_validate_feature_names(features)
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import torch
|
|||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
from soundfile import write
|
||||||
|
|
||||||
import lerobot
|
import lerobot
|
||||||
from lerobot.configs.default import DatasetConfig
|
from lerobot.configs.default import DatasetConfig
|
||||||
@@ -77,6 +78,21 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def audio_dataset_le_kiwi(tmp_path, empty_lerobot_dataset_factory):
|
||||||
|
features = {
|
||||||
|
"audio": {
|
||||||
|
"dtype": "audio",
|
||||||
|
"shape": (1, DUMMY_AUDIO_CHANNELS),
|
||||||
|
"names": [
|
||||||
|
"channels",
|
||||||
|
],
|
||||||
|
"info": {"sample_rate": DEFAULT_SAMPLE_RATE},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, robot_type="lekiwi")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def audio_dataset(tmp_path, empty_lerobot_dataset_factory):
|
def audio_dataset(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {
|
features = {
|
||||||
@@ -86,6 +102,7 @@ def audio_dataset(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
"names": [
|
"names": [
|
||||||
"channels",
|
"channels",
|
||||||
],
|
],
|
||||||
|
"info": {"sample_rate": DEFAULT_SAMPLE_RATE},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
@@ -432,8 +449,8 @@ def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_audio(audio_dataset):
|
def test_add_frame_audio_array(audio_dataset_le_kiwi):
|
||||||
dataset = audio_dataset
|
dataset = audio_dataset_le_kiwi
|
||||||
dataset.add_frame(
|
dataset.add_frame(
|
||||||
{
|
{
|
||||||
"audio": np.random.rand(
|
"audio": np.random.rand(
|
||||||
@@ -448,7 +465,59 @@ def test_add_frame_audio(audio_dataset):
|
|||||||
(
|
(
|
||||||
DUMMY_AUDIO_CHANNELS,
|
DUMMY_AUDIO_CHANNELS,
|
||||||
int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE),
|
int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE),
|
||||||
) # Match pytorch channel-first format
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_frame_audio_array_wrong_shape(audio_dataset_le_kiwi):
|
||||||
|
dataset = audio_dataset_le_kiwi
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
dataset.add_frame(
|
||||||
|
{
|
||||||
|
"audio": np.random.rand(
|
||||||
|
int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), DUMMY_AUDIO_CHANNELS, 99
|
||||||
|
)
|
||||||
|
},
|
||||||
|
task="Dummy task",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_frame_audio_array_wrong_channels_number(audio_dataset_le_kiwi):
|
||||||
|
dataset = audio_dataset_le_kiwi
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
dataset.add_frame(
|
||||||
|
{"audio": np.random.rand(int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), 99)},
|
||||||
|
task="Dummy task",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_frame_audio_file(audio_dataset):
|
||||||
|
dataset = audio_dataset
|
||||||
|
dataset.add_frame(
|
||||||
|
{
|
||||||
|
"audio": np.random.rand(
|
||||||
|
int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), DUMMY_AUDIO_CHANNELS
|
||||||
|
)
|
||||||
|
},
|
||||||
|
task="Dummy task",
|
||||||
|
)
|
||||||
|
# Create the audio file that should be created in the background by the Microphone class
|
||||||
|
for audio_key in dataset.meta.audio_keys:
|
||||||
|
fpath = dataset._get_raw_audio_file_path(0, audio_key)
|
||||||
|
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
write(
|
||||||
|
fpath,
|
||||||
|
np.random.rand(int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), DUMMY_AUDIO_CHANNELS),
|
||||||
|
DEFAULT_SAMPLE_RATE,
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset.save_episode()
|
||||||
|
|
||||||
|
assert dataset[0]["audio"].shape == torch.Size(
|
||||||
|
(
|
||||||
|
DUMMY_AUDIO_CHANNELS,
|
||||||
|
int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user