test(Microphone): adding missing testsand support for float sample rate

This commit is contained in:
CarolinePascal
2025-08-07 11:35:08 +02:00
parent 2726b4e865
commit 0232879245
2 changed files with 74 additions and 30 deletions
@@ -25,7 +25,6 @@ from threading import Barrier, Event, Event as thread_Event, Thread
from typing import Any
import numpy as np
import sounddevice as sd
from soundfile import SoundFile
from lerobot.utils.errors import (
@@ -111,10 +110,15 @@ class PortAudioMicrophone(Microphone):
return self.write_thread is not None and self.write_thread.is_alive()
@staticmethod
def find_microphones(sounddevice_sdk: ISounddeviceSDK = None) -> list[dict[str, Any]]:
def find_microphones(
device: int | str | None = None, sounddevice_sdk: ISounddeviceSDK = None
) -> list[dict[str, Any]] | dict[str, Any]:
"""
Detects available microphones connected to the system.
Args:
device: The device to find microphones for. If None, all microphones are found.
Returns:
List[Dict[str, Any]]: A list of dictionaries,
where each dictionary contains information about a detected microphone : index, name, sample rate, channels.
@@ -126,15 +130,29 @@ class PortAudioMicrophone(Microphone):
found_microphones_info = []
devices = sounddevice_sdk.query_devices()
for device in devices:
if device["max_input_channels"] > 0:
for d in devices:
if d["max_input_channels"] > 0:
microphone_info = {
"index": device["index"],
"name": device["name"],
"sample_rate": int(device["default_samplerate"]),
"channels": list(range(1, device["max_input_channels"] + 1)),
"index": d["index"],
"name": d["name"],
"sample_rate": int(d["default_samplerate"]),
"channels": np.arange(1, d["max_input_channels"] + 1),
}
found_microphones_info.append(microphone_info)
if device is None or (
(isinstance(device, int) and d["index"] == device)
or (isinstance(device, str) and d["name"] == device)
):
found_microphones_info.append(microphone_info)
if device is not None:
if len(found_microphones_info) == 0:
raise RuntimeError(f"No microphone found for device {device}")
else:
return found_microphones_info[0]
if len(found_microphones_info) == 0:
logging.warning("No microphone found !")
return found_microphones_info
@@ -160,24 +178,28 @@ class PortAudioMicrophone(Microphone):
def _validate_microphone_index(self) -> None:
""" "Validates the microphone index against available devices by checking if it has at least one input channel."""
is_index_input = (
self.microphone_index >= 0
and self.sounddevice_sdk.query_devices(self.microphone_index)["max_input_channels"] > 0
)
if not is_index_input:
found_microphones_info = self.find_microphones()
available_microphones = {m["name"]: m["index"] for m in found_microphones_info}
try:
PortAudioMicrophone.find_microphones(self.microphone_index, self.sounddevice_sdk)
except RuntimeError as e:
raise RuntimeError(
f"Microphone index {self.microphone_index} does not match an input device (microphone). Available input devices : {available_microphones}"
)
f"{e}. Available microphones: {PortAudioMicrophone.find_microphones(sounddevice_sdk=self.sounddevice_sdk)}"
) from e
def _validate_sample_rate(self) -> None:
"""Validates the sample rate against the actual microphone's default sample rate."""
actual_sample_rate = self.sounddevice_sdk.query_devices(self.microphone_index)["default_samplerate"]
actual_sample_rate = PortAudioMicrophone.find_microphones(
self.microphone_index, self.sounddevice_sdk
)["sample_rate"]
if self.sample_rate is not None:
try:
self.sample_rate = int(self.sample_rate)
except ValueError as e:
raise RuntimeError(
f"Cannot convert the provided sample rate ({self.sample_rate} Hz) to an integer."
) from e
if self.sample_rate > actual_sample_rate or self.sample_rate < 1000:
raise RuntimeError(
f"Provided sample rate {self.sample_rate} is either too low or too high compared to the sample rate of the microphone {actual_sample_rate}."
@@ -187,22 +209,25 @@ class PortAudioMicrophone(Microphone):
logging.warning(
"Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted."
)
self.sample_rate = int(self.sample_rate)
else:
self.sample_rate = int(actual_sample_rate)
self.sample_rate = actual_sample_rate
def _validate_channels(self) -> None:
"""Validates the channels against the actual microphone's maximum input channels."""
actual_max_microphone_channels = sd.query_devices(self.microphone_index)["max_input_channels"]
actual_channels = PortAudioMicrophone.find_microphones(self.microphone_index, self.sounddevice_sdk)[
"channels"
]
if self.channels is not None and len(self.channels) > 0:
if any(c > actual_max_microphone_channels or c <= 0 for c in self.channels):
if any(
all(c > actual_channels) or c <= 0 or not isinstance(c, np.integer) for c in self.channels
):
raise RuntimeError(
f"Some of the provided channels {self.channels} are outside the maximum channel range of the microphone {actual_max_microphone_channels}."
f"Some of the provided channels {self.channels} are outside the possible channel range of the microphone {actual_channels}."
)
else:
self.channels = np.arange(1, actual_max_microphone_channels + 1)
self.channels = actual_channels
# Get channels index instead of number for slicing
self.channels_index = np.array(self.channels) - 1
+23 -4
View File
@@ -93,6 +93,16 @@ class TestPortAudioMicrophoneDeviceValidation(unittest.TestCase):
self.default_config = self._create_config(kind="input")
def test_find_microphones(self):
microphones = PortAudioMicrophone.find_microphones(sounddevice_sdk=self.test_sdk)
for microphone in microphones:
self.assertIsInstance(microphone["index"], int)
self.assertIsInstance(microphone["name"], str)
self.assertIsInstance(microphone["sample_rate"], int)
self.assertIsInstance(microphone["channels"], np.ndarray)
self.assertGreater(len(microphone["channels"]), 0)
def test_init_defaults(self):
microphone = PortAudioMicrophone(self.default_config, sounddevice_sdk=self.test_sdk)
@@ -153,6 +163,15 @@ class TestPortAudioMicrophoneDeviceValidation(unittest.TestCase):
with self.assertRaises(RuntimeError):
microphone.connect()
def test_connect_float_sample_rate(self):
config = deepcopy(self.default_config)
config.sample_rate = int(config.sample_rate) - 0.5
microphone = PortAudioMicrophone(config, sounddevice_sdk=self.test_sdk)
microphone.connect()
self.assertIsInstance(microphone.sample_rate, int)
self.assertEqual(microphone.sample_rate, int(config.sample_rate))
def test_connect_lower_sample_rate(self):
config = deepcopy(self.default_config)
config.sample_rate = 1000 # Lowest possible sample rate
@@ -291,7 +310,7 @@ class TestPortAudioMicrophoneDeviceValidation(unittest.TestCase):
self.assertAlmostEqual(
data.shape[0],
RECORDING_DURATION * self.default_config.sample_rate,
delta=self.default_config.sample_rate * device_info["default_low_input_latency"],
delta=2 * self.default_config.sample_rate * device_info["default_low_input_latency"],
)
def test_writing_success(self):
@@ -311,7 +330,7 @@ class TestPortAudioMicrophoneDeviceValidation(unittest.TestCase):
self.assertAlmostEqual(
data.shape[0],
RECORDING_DURATION * self.default_config.sample_rate,
delta=self.default_config.sample_rate * device_info["default_low_input_latency"],
delta=2 * self.default_config.sample_rate * device_info["default_low_input_latency"],
)
def test_read_while_writing(self):
@@ -330,12 +349,12 @@ class TestPortAudioMicrophoneDeviceValidation(unittest.TestCase):
self.assertAlmostEqual(
writing_data.shape[0],
RECORDING_DURATION * self.default_config.sample_rate,
delta=self.default_config.sample_rate * device_info["default_low_input_latency"],
delta=2 * self.default_config.sample_rate * device_info["default_low_input_latency"],
)
self.assertAlmostEqual(
read_data.shape[0],
RECORDING_DURATION * self.default_config.sample_rate,
delta=self.default_config.sample_rate * device_info["default_low_input_latency"],
delta=2 * self.default_config.sample_rate * device_info["default_low_input_latency"],
)
def test_async_start_recording(self):