mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
Adding audio tests
This commit is contained in:
@@ -26,16 +26,21 @@ from lerobot.datasets.compute_stats import (
|
||||
compute_episode_stats,
|
||||
estimate_num_samples,
|
||||
get_feature_stats,
|
||||
sample_audio,
|
||||
sample_images,
|
||||
sample_indices,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
|
||||
from lerobot.utils.constants import OBS_AUDIO, OBS_IMAGE, OBS_STATE
|
||||
|
||||
|
||||
def mock_load_image_as_numpy(path, dtype, channel_first):
|
||||
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
|
||||
|
||||
|
||||
def mock_load_audio(path):
|
||||
return np.ones((16000, 2), dtype=np.float32)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_array():
|
||||
return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
||||
@@ -73,6 +78,16 @@ def test_sample_images(mock_load):
|
||||
assert len(images) == estimate_num_samples(100)
|
||||
|
||||
|
||||
@patch("lerobot.datasets.compute_stats.load_audio", side_effect=mock_load_audio)
|
||||
def test_sample_audio(mock_load):
|
||||
audio_path = "audio.wav"
|
||||
audio_samples = sample_audio(audio_path)
|
||||
assert isinstance(audio_samples, np.ndarray)
|
||||
assert audio_samples.shape[1] == 2
|
||||
assert audio_samples.dtype == np.float32
|
||||
assert len(audio_samples) == estimate_num_samples(16000)
|
||||
|
||||
|
||||
def test_get_feature_stats_images():
|
||||
data = np.random.rand(100, 3, 32, 32)
|
||||
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
|
||||
@@ -81,6 +96,14 @@ def test_get_feature_stats_images():
|
||||
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
|
||||
|
||||
|
||||
def test_get_feature_stats_audio():
|
||||
data = np.random.uniform(-1, 1, (16000, 2))
|
||||
stats = get_feature_stats(data, axis=0, keepdims=True)
|
||||
assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats
|
||||
np.testing.assert_equal(stats["count"], np.array([16000]))
|
||||
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
|
||||
|
||||
|
||||
def test_get_feature_stats_axis_0_keepdims(sample_array):
|
||||
expected = {
|
||||
"min": np.array([[1, 2, 3]]),
|
||||
@@ -145,20 +168,27 @@ def test_get_feature_stats_single_value():
|
||||
def test_compute_episode_stats():
|
||||
episode_data = {
|
||||
OBS_IMAGE: [f"image_{i}.jpg" for i in range(100)],
|
||||
OBS_AUDIO: "audio.wav",
|
||||
OBS_STATE: np.random.rand(100, 10),
|
||||
}
|
||||
features = {
|
||||
OBS_IMAGE: {"dtype": "image"},
|
||||
OBS_AUDIO: {"dtype": "audio"},
|
||||
OBS_STATE: {"dtype": "numeric"},
|
||||
}
|
||||
|
||||
with patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy):
|
||||
with (
|
||||
patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy),
|
||||
patch("lerobot.datasets.compute_stats.load_audio", side_effect=mock_load_audio),
|
||||
):
|
||||
stats = compute_episode_stats(episode_data, features)
|
||||
|
||||
assert OBS_IMAGE in stats and OBS_STATE in stats
|
||||
assert stats[OBS_IMAGE]["count"].item() == 100
|
||||
assert stats[OBS_STATE]["count"].item() == 100
|
||||
assert OBS_IMAGE in stats and OBS_AUDIO in stats and OBS_STATE in stats
|
||||
assert stats[OBS_IMAGE]["count"].item() == estimate_num_samples(100)
|
||||
assert stats[OBS_AUDIO]["count"].item() == estimate_num_samples(16000)
|
||||
assert stats[OBS_STATE]["count"].item() == estimate_num_samples(100)
|
||||
assert stats[OBS_IMAGE]["mean"].shape == (3, 1, 1)
|
||||
assert stats[OBS_AUDIO]["mean"].shape == (1, 2)
|
||||
|
||||
|
||||
def test_assert_type_and_shape_valid():
|
||||
|
||||
Reference in New Issue
Block a user