mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
Adding audio tests
This commit is contained in:
@@ -37,6 +37,7 @@ from lerobot.datasets.lerobot_dataset import (
|
||||
_encode_video_worker,
|
||||
)
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_AUDIO_CHUNK_DURATION,
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
@@ -49,7 +50,13 @@ from lerobot.envs.factory import make_env_config
|
||||
from lerobot.policies.factory import make_policy_config
|
||||
from lerobot.robots import make_robot_from_config
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||
from tests.fixtures.constants import (
|
||||
DEFAULT_SAMPLE_RATE,
|
||||
DUMMY_AUDIO_CHANNELS,
|
||||
DUMMY_CHW,
|
||||
DUMMY_HWC,
|
||||
DUMMY_REPO_ID,
|
||||
)
|
||||
from tests.mocks.mock_robot import MockRobotConfig
|
||||
from tests.utils import require_x86_64_kernel
|
||||
|
||||
@@ -70,6 +77,20 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory):
|
||||
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def audio_dataset(tmp_path, empty_lerobot_dataset_factory):
|
||||
features = {
|
||||
"audio": {
|
||||
"dtype": "audio",
|
||||
"shape": (DUMMY_AUDIO_CHANNELS,),
|
||||
"names": [
|
||||
"channels",
|
||||
],
|
||||
}
|
||||
}
|
||||
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
|
||||
|
||||
def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
||||
"""
|
||||
Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
|
||||
@@ -411,6 +432,23 @@ def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory):
|
||||
)
|
||||
|
||||
|
||||
def test_add_frame_audio(audio_dataset):
|
||||
dataset = audio_dataset
|
||||
dataset.add_frame(
|
||||
{
|
||||
"audio": np.random.rand(
|
||||
int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), DUMMY_AUDIO_CHANNELS
|
||||
)
|
||||
},
|
||||
task="Dummy task",
|
||||
)
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset[0]["audio"].shape == torch.Size(
|
||||
(int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), DUMMY_AUDIO_CHANNELS)
|
||||
)
|
||||
|
||||
|
||||
# TODO(aliberts):
|
||||
# - [ ] test various attributes & state from init and create
|
||||
# - [ ] test init with episodes and check num_frames
|
||||
@@ -450,6 +488,7 @@ def test_factory(env_name, repo_id, policy_name):
|
||||
dataset = make_dataset(cfg)
|
||||
delta_timestamps = dataset.delta_timestamps
|
||||
camera_keys = dataset.meta.camera_keys
|
||||
audio_keys = dataset.meta.audio_keys
|
||||
|
||||
item = dataset[0]
|
||||
|
||||
@@ -492,6 +531,11 @@ def test_factory(env_name, repo_id, policy_name):
|
||||
# test c,h,w
|
||||
assert item[key].shape[0] == 3, f"{key}"
|
||||
|
||||
for key in audio_keys:
|
||||
assert item[key].dtype == torch.float32, f"{key}"
|
||||
assert item[key].max() <= 1.0, f"{key}"
|
||||
assert item[key].min() >= -1.0, f"{key}"
|
||||
|
||||
if delta_timestamps is not None:
|
||||
# test missing keys in delta_timestamps
|
||||
for key in delta_timestamps:
|
||||
|
||||
Reference in New Issue
Block a user