Adding audio tests

This commit is contained in:
CarolinePascal
2025-04-07 16:36:04 +02:00
parent 373a169bd2
commit be09a59e05
8 changed files with 122 additions and 8 deletions
+13
View File
@@ -40,5 +40,18 @@ DUMMY_VIDEO_INFO = {
"video.is_depth_map": False,
"has_audio": False,
}
DUMMY_MICROPHONE_FEATURES = {
"laptop": {"dtype": "audio", "shape": (1,), "names": ["channels"], "info": None},
"phone": {"dtype": "audio", "shape": (1,), "names": ["channels"], "info": None},
}
DEFAULT_SAMPLE_RATE = 48000
DUMMY_AUDIO_CHANNELS = 2
DUMMY_AUDIO_INFO = {
"has_audio": True,
"audio.sample_rate": DEFAULT_SAMPLE_RATE,
"audio.codec": "aac",
"audio.channels": DUMMY_AUDIO_CHANNELS,
"audio.channel_layout": "stereo",
}
DUMMY_CHW = (3, 96, 128)
DUMMY_HWC = (96, 128, 3)
+13 -1
View File
@@ -43,6 +43,7 @@ from lerobot.datasets.video_utils import encode_video_frames
from tests.fixtures.constants import (
DEFAULT_FPS,
DUMMY_CAMERA_FEATURES,
DUMMY_MICROPHONE_FEATURES,
DUMMY_MOTOR_FEATURES,
DUMMY_REPO_ID,
DUMMY_ROBOT_TYPE,
@@ -131,6 +132,7 @@ def features_factory():
def _create_features(
motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES,
audio_features: dict = DUMMY_MICROPHONE_FEATURES,
use_videos: bool = True,
) -> dict:
if use_videos:
@@ -142,6 +144,7 @@ def features_factory():
return {
**motor_features,
**camera_ft,
**audio_features,
**DEFAULT_FEATURES,
}
@@ -166,9 +169,10 @@ def info_factory(features_factory):
audio_path: str = DEFAULT_AUDIO_PATH,
motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES,
audio_features: dict = DUMMY_MICROPHONE_FEATURES,
use_videos: bool = True,
) -> dict:
features = features_factory(motor_features, camera_features, use_videos)
features = features_factory(motor_features, camera_features, audio_features, use_videos)
return {
"codebase_version": codebase_version,
"robot_type": robot_type,
@@ -207,6 +211,14 @@ def stats_factory():
"std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(),
"count": [10],
}
elif dtype == "audio":
stats[key] = {
"mean": np.full((shape[0],), 0.0, dtype=np.float32).tolist(),
"max": np.full((shape[0],), 1, dtype=np.float32).tolist(),
"min": np.full((shape[0],), -1, dtype=np.float32).tolist(),
"std": np.full((shape[0],), 0.5, dtype=np.float32).tolist(),
"count": [10],
}
else:
stats[key] = {
"max": np.full(shape, 1, dtype=dtype).tolist(),