Adding audio tests

This commit is contained in:
CarolinePascal
2025-04-07 16:36:04 +02:00
parent 373a169bd2
commit be09a59e05
8 changed files with 122 additions and 8 deletions
+45 -1
View File
@@ -37,6 +37,7 @@ from lerobot.datasets.lerobot_dataset import (
_encode_video_worker,
)
from lerobot.datasets.utils import (
DEFAULT_AUDIO_CHUNK_DURATION,
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
@@ -49,7 +50,13 @@ from lerobot.envs.factory import make_env_config
from lerobot.policies.factory import make_policy_config
from lerobot.robots import make_robot_from_config
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
from tests.fixtures.constants import (
DEFAULT_SAMPLE_RATE,
DUMMY_AUDIO_CHANNELS,
DUMMY_CHW,
DUMMY_HWC,
DUMMY_REPO_ID,
)
from tests.mocks.mock_robot import MockRobotConfig
from tests.utils import require_x86_64_kernel
@@ -70,6 +77,20 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory):
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
@pytest.fixture
def audio_dataset(tmp_path, empty_lerobot_dataset_factory):
features = {
"audio": {
"dtype": "audio",
"shape": (DUMMY_AUDIO_CHANNELS,),
"names": [
"channels",
],
}
}
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
"""
Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
@@ -411,6 +432,23 @@ def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory):
)
def test_add_frame_audio(audio_dataset):
dataset = audio_dataset
dataset.add_frame(
{
"audio": np.random.rand(
int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), DUMMY_AUDIO_CHANNELS
)
},
task="Dummy task",
)
dataset.save_episode()
assert dataset[0]["audio"].shape == torch.Size(
(int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), DUMMY_AUDIO_CHANNELS)
)
# TODO(aliberts):
# - [ ] test various attributes & state from init and create
# - [ ] test init with episodes and check num_frames
@@ -450,6 +488,7 @@ def test_factory(env_name, repo_id, policy_name):
dataset = make_dataset(cfg)
delta_timestamps = dataset.delta_timestamps
camera_keys = dataset.meta.camera_keys
audio_keys = dataset.meta.audio_keys
item = dataset[0]
@@ -492,6 +531,11 @@ def test_factory(env_name, repo_id, policy_name):
# test c,h,w
assert item[key].shape[0] == 3, f"{key}"
for key in audio_keys:
assert item[key].dtype == torch.float32, f"{key}"
assert item[key].max() <= 1.0, f"{key}"
assert item[key].min() >= -1.0, f"{key}"
if delta_timestamps is not None:
# test missing keys in delta_timestamps
for key in delta_timestamps: