Adding audio tests

This commit is contained in:
CarolinePascal
2025-04-07 16:36:04 +02:00
parent 373a169bd2
commit be09a59e05
8 changed files with 122 additions and 8 deletions
+35 -5
View File
@@ -26,16 +26,21 @@ from lerobot.datasets.compute_stats import (
compute_episode_stats,
estimate_num_samples,
get_feature_stats,
sample_audio,
sample_images,
sample_indices,
)
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
from lerobot.utils.constants import OBS_AUDIO, OBS_IMAGE, OBS_STATE
def mock_load_image_as_numpy(path, dtype, channel_first):
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
def mock_load_audio(path):
return np.ones((16000, 2), dtype=np.float32)
@pytest.fixture
def sample_array():
return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
@@ -73,6 +78,16 @@ def test_sample_images(mock_load):
assert len(images) == estimate_num_samples(100)
@patch("lerobot.datasets.compute_stats.load_audio", side_effect=mock_load_audio)
def test_sample_audio(mock_load):
audio_path = "audio.wav"
audio_samples = sample_audio(audio_path)
assert isinstance(audio_samples, np.ndarray)
assert audio_samples.shape[1] == 2
assert audio_samples.dtype == np.float32
assert len(audio_samples) == estimate_num_samples(16000)
def test_get_feature_stats_images():
data = np.random.rand(100, 3, 32, 32)
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
@@ -81,6 +96,14 @@ def test_get_feature_stats_images():
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
def test_get_feature_stats_audio():
data = np.random.uniform(-1, 1, (16000, 2))
stats = get_feature_stats(data, axis=0, keepdims=True)
assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats
np.testing.assert_equal(stats["count"], np.array([16000]))
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
def test_get_feature_stats_axis_0_keepdims(sample_array):
expected = {
"min": np.array([[1, 2, 3]]),
@@ -145,20 +168,27 @@ def test_get_feature_stats_single_value():
def test_compute_episode_stats():
episode_data = {
OBS_IMAGE: [f"image_{i}.jpg" for i in range(100)],
OBS_AUDIO: "audio.wav",
OBS_STATE: np.random.rand(100, 10),
}
features = {
OBS_IMAGE: {"dtype": "image"},
OBS_AUDIO: {"dtype": "audio"},
OBS_STATE: {"dtype": "numeric"},
}
with patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy):
with (
patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy),
patch("lerobot.datasets.compute_stats.load_audio", side_effect=mock_load_audio),
):
stats = compute_episode_stats(episode_data, features)
assert OBS_IMAGE in stats and OBS_STATE in stats
assert stats[OBS_IMAGE]["count"].item() == 100
assert stats[OBS_STATE]["count"].item() == 100
assert OBS_IMAGE in stats and OBS_AUDIO in stats and OBS_STATE in stats
assert stats[OBS_IMAGE]["count"].item() == estimate_num_samples(100)
assert stats[OBS_AUDIO]["count"].item() == estimate_num_samples(16000)
assert stats[OBS_STATE]["count"].item() == estimate_num_samples(100)
assert stats[OBS_IMAGE]["mean"].shape == (3, 1, 1)
assert stats[OBS_AUDIO]["mean"].shape == (1, 2)
def test_assert_type_and_shape_valid():