test(Microphone): adding missing tests and support for float sample rate

This commit is contained in:
CarolinePascal
2025-08-07 11:35:08 +02:00
parent 2726b4e865
commit 0232879245
2 changed files with 74 additions and 30 deletions
@@ -25,7 +25,6 @@ from threading import Barrier, Event, Event as thread_Event, Thread
from typing import Any from typing import Any
import numpy as np import numpy as np
import sounddevice as sd
from soundfile import SoundFile from soundfile import SoundFile
from lerobot.utils.errors import ( from lerobot.utils.errors import (
@@ -111,10 +110,15 @@ class PortAudioMicrophone(Microphone):
return self.write_thread is not None and self.write_thread.is_alive() return self.write_thread is not None and self.write_thread.is_alive()
@staticmethod @staticmethod
def find_microphones(sounddevice_sdk: ISounddeviceSDK = None) -> list[dict[str, Any]]: def find_microphones(
device: int | str | None = None, sounddevice_sdk: ISounddeviceSDK = None
) -> list[dict[str, Any]] | dict[str, Any]:
""" """
Detects available microphones connected to the system. Detects available microphones connected to the system.
Args:
device: The device to find microphones for. If None, all microphones are found.
Returns: Returns:
List[Dict[str, Any]]: A list of dictionaries, List[Dict[str, Any]]: A list of dictionaries,
where each dictionary contains information about a detected microphone : index, name, sample rate, channels. where each dictionary contains information about a detected microphone : index, name, sample rate, channels.
@@ -126,15 +130,29 @@ class PortAudioMicrophone(Microphone):
found_microphones_info = [] found_microphones_info = []
devices = sounddevice_sdk.query_devices() devices = sounddevice_sdk.query_devices()
for device in devices: for d in devices:
if device["max_input_channels"] > 0: if d["max_input_channels"] > 0:
microphone_info = { microphone_info = {
"index": device["index"], "index": d["index"],
"name": device["name"], "name": d["name"],
"sample_rate": int(device["default_samplerate"]), "sample_rate": int(d["default_samplerate"]),
"channels": list(range(1, device["max_input_channels"] + 1)), "channels": np.arange(1, d["max_input_channels"] + 1),
} }
found_microphones_info.append(microphone_info)
if device is None or (
(isinstance(device, int) and d["index"] == device)
or (isinstance(device, str) and d["name"] == device)
):
found_microphones_info.append(microphone_info)
if device is not None:
if len(found_microphones_info) == 0:
raise RuntimeError(f"No microphone found for device {device}")
else:
return found_microphones_info[0]
if len(found_microphones_info) == 0:
logging.warning("No microphone found !")
return found_microphones_info return found_microphones_info
@@ -160,24 +178,28 @@ class PortAudioMicrophone(Microphone):
def _validate_microphone_index(self) -> None: def _validate_microphone_index(self) -> None:
""" "Validates the microphone index against available devices by checking if it has at least one input channel.""" """ "Validates the microphone index against available devices by checking if it has at least one input channel."""
is_index_input = ( try:
self.microphone_index >= 0 PortAudioMicrophone.find_microphones(self.microphone_index, self.sounddevice_sdk)
and self.sounddevice_sdk.query_devices(self.microphone_index)["max_input_channels"] > 0 except RuntimeError as e:
)
if not is_index_input:
found_microphones_info = self.find_microphones()
available_microphones = {m["name"]: m["index"] for m in found_microphones_info}
raise RuntimeError( raise RuntimeError(
f"Microphone index {self.microphone_index} does not match an input device (microphone). Available input devices : {available_microphones}" f"{e}. Available microphones: {PortAudioMicrophone.find_microphones(sounddevice_sdk=self.sounddevice_sdk)}"
) ) from e
def _validate_sample_rate(self) -> None: def _validate_sample_rate(self) -> None:
"""Validates the sample rate against the actual microphone's default sample rate.""" """Validates the sample rate against the actual microphone's default sample rate."""
actual_sample_rate = self.sounddevice_sdk.query_devices(self.microphone_index)["default_samplerate"] actual_sample_rate = PortAudioMicrophone.find_microphones(
self.microphone_index, self.sounddevice_sdk
)["sample_rate"]
if self.sample_rate is not None: if self.sample_rate is not None:
try:
self.sample_rate = int(self.sample_rate)
except ValueError as e:
raise RuntimeError(
f"Cannot convert the provided sample rate ({self.sample_rate} Hz) to an integer."
) from e
if self.sample_rate > actual_sample_rate or self.sample_rate < 1000: if self.sample_rate > actual_sample_rate or self.sample_rate < 1000:
raise RuntimeError( raise RuntimeError(
f"Provided sample rate {self.sample_rate} is either too low or too high compared to the sample rate of the microphone {actual_sample_rate}." f"Provided sample rate {self.sample_rate} is either too low or too high compared to the sample rate of the microphone {actual_sample_rate}."
@@ -187,22 +209,25 @@ class PortAudioMicrophone(Microphone):
logging.warning( logging.warning(
"Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted." "Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted."
) )
self.sample_rate = int(self.sample_rate)
else: else:
self.sample_rate = int(actual_sample_rate) self.sample_rate = actual_sample_rate
def _validate_channels(self) -> None: def _validate_channels(self) -> None:
"""Validates the channels against the actual microphone's maximum input channels.""" """Validates the channels against the actual microphone's maximum input channels."""
actual_max_microphone_channels = sd.query_devices(self.microphone_index)["max_input_channels"] actual_channels = PortAudioMicrophone.find_microphones(self.microphone_index, self.sounddevice_sdk)[
"channels"
]
if self.channels is not None and len(self.channels) > 0: if self.channels is not None and len(self.channels) > 0:
if any(c > actual_max_microphone_channels or c <= 0 for c in self.channels): if any(
all(c > actual_channels) or c <= 0 or not isinstance(c, np.integer) for c in self.channels
):
raise RuntimeError( raise RuntimeError(
f"Some of the provided channels {self.channels} are outside the maximum channel range of the microphone {actual_max_microphone_channels}." f"Some of the provided channels {self.channels} are outside the possible channel range of the microphone {actual_channels}."
) )
else: else:
self.channels = np.arange(1, actual_max_microphone_channels + 1) self.channels = actual_channels
# Get channels index instead of number for slicing # Get channels index instead of number for slicing
self.channels_index = np.array(self.channels) - 1 self.channels_index = np.array(self.channels) - 1
+23 -4
View File
@@ -93,6 +93,16 @@ class TestPortAudioMicrophoneDeviceValidation(unittest.TestCase):
self.default_config = self._create_config(kind="input") self.default_config = self._create_config(kind="input")
def test_find_microphones(self):
microphones = PortAudioMicrophone.find_microphones(sounddevice_sdk=self.test_sdk)
for microphone in microphones:
self.assertIsInstance(microphone["index"], int)
self.assertIsInstance(microphone["name"], str)
self.assertIsInstance(microphone["sample_rate"], int)
self.assertIsInstance(microphone["channels"], np.ndarray)
self.assertGreater(len(microphone["channels"]), 0)
def test_init_defaults(self): def test_init_defaults(self):
microphone = PortAudioMicrophone(self.default_config, sounddevice_sdk=self.test_sdk) microphone = PortAudioMicrophone(self.default_config, sounddevice_sdk=self.test_sdk)
@@ -153,6 +163,15 @@ class TestPortAudioMicrophoneDeviceValidation(unittest.TestCase):
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
microphone.connect() microphone.connect()
def test_connect_float_sample_rate(self):
config = deepcopy(self.default_config)
config.sample_rate = int(config.sample_rate) - 0.5
microphone = PortAudioMicrophone(config, sounddevice_sdk=self.test_sdk)
microphone.connect()
self.assertIsInstance(microphone.sample_rate, int)
self.assertEqual(microphone.sample_rate, int(config.sample_rate))
def test_connect_lower_sample_rate(self): def test_connect_lower_sample_rate(self):
config = deepcopy(self.default_config) config = deepcopy(self.default_config)
config.sample_rate = 1000 # Lowest possible sample rate config.sample_rate = 1000 # Lowest possible sample rate
@@ -291,7 +310,7 @@ class TestPortAudioMicrophoneDeviceValidation(unittest.TestCase):
self.assertAlmostEqual( self.assertAlmostEqual(
data.shape[0], data.shape[0],
RECORDING_DURATION * self.default_config.sample_rate, RECORDING_DURATION * self.default_config.sample_rate,
delta=self.default_config.sample_rate * device_info["default_low_input_latency"], delta=2 * self.default_config.sample_rate * device_info["default_low_input_latency"],
) )
def test_writing_success(self): def test_writing_success(self):
@@ -311,7 +330,7 @@ class TestPortAudioMicrophoneDeviceValidation(unittest.TestCase):
self.assertAlmostEqual( self.assertAlmostEqual(
data.shape[0], data.shape[0],
RECORDING_DURATION * self.default_config.sample_rate, RECORDING_DURATION * self.default_config.sample_rate,
delta=self.default_config.sample_rate * device_info["default_low_input_latency"], delta=2 * self.default_config.sample_rate * device_info["default_low_input_latency"],
) )
def test_read_while_writing(self): def test_read_while_writing(self):
@@ -330,12 +349,12 @@ class TestPortAudioMicrophoneDeviceValidation(unittest.TestCase):
self.assertAlmostEqual( self.assertAlmostEqual(
writing_data.shape[0], writing_data.shape[0],
RECORDING_DURATION * self.default_config.sample_rate, RECORDING_DURATION * self.default_config.sample_rate,
delta=self.default_config.sample_rate * device_info["default_low_input_latency"], delta=2 * self.default_config.sample_rate * device_info["default_low_input_latency"],
) )
self.assertAlmostEqual( self.assertAlmostEqual(
read_data.shape[0], read_data.shape[0],
RECORDING_DURATION * self.default_config.sample_rate, RECORDING_DURATION * self.default_config.sample_rate,
delta=self.default_config.sample_rate * device_info["default_low_input_latency"], delta=2 * self.default_config.sample_rate * device_info["default_low_input_latency"],
) )
def test_async_start_recording(self): def test_async_start_recording(self):