From 0232879245a9a2f478c9b63e0ada4a399fc43865 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Thu, 7 Aug 2025 11:35:08 +0200 Subject: [PATCH] test(Microphone): adding missing testsand support for float sample rate --- .../portaudio/microphone_portaudio.py | 77 ++++++++++++------- tests/microphones/test_portaudio.py | 27 ++++++- 2 files changed, 74 insertions(+), 30 deletions(-) diff --git a/src/lerobot/microphones/portaudio/microphone_portaudio.py b/src/lerobot/microphones/portaudio/microphone_portaudio.py index 01fbf9c87..c471584b5 100644 --- a/src/lerobot/microphones/portaudio/microphone_portaudio.py +++ b/src/lerobot/microphones/portaudio/microphone_portaudio.py @@ -25,7 +25,6 @@ from threading import Barrier, Event, Event as thread_Event, Thread from typing import Any import numpy as np -import sounddevice as sd from soundfile import SoundFile from lerobot.utils.errors import ( @@ -111,10 +110,15 @@ class PortAudioMicrophone(Microphone): return self.write_thread is not None and self.write_thread.is_alive() @staticmethod - def find_microphones(sounddevice_sdk: ISounddeviceSDK = None) -> list[dict[str, Any]]: + def find_microphones( + device: int | str | None = None, sounddevice_sdk: ISounddeviceSDK = None + ) -> list[dict[str, Any]] | dict[str, Any]: """ Detects available microphones connected to the system. + Args: + device: The device to find microphones for. If None, all microphones are found. + Returns: List[Dict[str, Any]]: A list of dictionaries, where each dictionary contains information about a detected microphone : index, name, sample rate, channels. @@ -126,15 +130,29 @@ class PortAudioMicrophone(Microphone): found_microphones_info = [] devices = sounddevice_sdk.query_devices() - for device in devices: - if device["max_input_channels"] > 0: + for d in devices: + if d["max_input_channels"] > 0: microphone_info = { - "index": device["index"], - "name": device["name"], - "sample_rate": int(device["default_samplerate"]), - "channels": list(range(1, device["max_input_channels"] + 1)), + "index": d["index"], + "name": d["name"], + "sample_rate": int(d["default_samplerate"]), + "channels": np.arange(1, d["max_input_channels"] + 1), } - found_microphones_info.append(microphone_info) + + if device is None or ( + (isinstance(device, int) and d["index"] == device) + or (isinstance(device, str) and d["name"] == device) + ): + found_microphones_info.append(microphone_info) + + if device is not None: + if len(found_microphones_info) == 0: + raise RuntimeError(f"No microphone found for device {device}") + else: + return found_microphones_info[0] + + if len(found_microphones_info) == 0: + logging.warning("No microphone found !") return found_microphones_info @@ -160,24 +178,28 @@ class PortAudioMicrophone(Microphone): def _validate_microphone_index(self) -> None: """ "Validates the microphone index against available devices by checking if it has at least one input channel.""" - is_index_input = ( - self.microphone_index >= 0 - and self.sounddevice_sdk.query_devices(self.microphone_index)["max_input_channels"] > 0 - ) - - if not is_index_input: - found_microphones_info = self.find_microphones() - available_microphones = {m["name"]: m["index"] for m in found_microphones_info} + try: + PortAudioMicrophone.find_microphones(self.microphone_index, self.sounddevice_sdk) + except RuntimeError as e: raise RuntimeError( - f"Microphone index {self.microphone_index} does not match an input device (microphone). Available input devices : {available_microphones}" - ) + f"{e}. Available microphones: {PortAudioMicrophone.find_microphones(sounddevice_sdk=self.sounddevice_sdk)}" + ) from e def _validate_sample_rate(self) -> None: """Validates the sample rate against the actual microphone's default sample rate.""" - actual_sample_rate = self.sounddevice_sdk.query_devices(self.microphone_index)["default_samplerate"] + actual_sample_rate = PortAudioMicrophone.find_microphones( + self.microphone_index, self.sounddevice_sdk + )["sample_rate"] if self.sample_rate is not None: + try: + self.sample_rate = int(self.sample_rate) + except ValueError as e: + raise RuntimeError( + f"Cannot convert the provided sample rate ({self.sample_rate} Hz) to an integer." + ) from e + if self.sample_rate > actual_sample_rate or self.sample_rate < 1000: raise RuntimeError( f"Provided sample rate {self.sample_rate} is either too low or too high compared to the sample rate of the microphone {actual_sample_rate}." @@ -187,22 +209,25 @@ class PortAudioMicrophone(Microphone): logging.warning( "Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted." ) - self.sample_rate = int(self.sample_rate) else: - self.sample_rate = int(actual_sample_rate) + self.sample_rate = actual_sample_rate def _validate_channels(self) -> None: """Validates the channels against the actual microphone's maximum input channels.""" - actual_max_microphone_channels = sd.query_devices(self.microphone_index)["max_input_channels"] + actual_channels = PortAudioMicrophone.find_microphones(self.microphone_index, self.sounddevice_sdk)[ + "channels" + ] if self.channels is not None and len(self.channels) > 0: - if any(c > actual_max_microphone_channels or c <= 0 for c in self.channels): + if any( + all(c > actual_channels) or c <= 0 or not isinstance(c, np.integer) for c in self.channels + ): raise RuntimeError( - f"Some of the provided channels {self.channels} are outside the maximum channel range of the microphone {actual_max_microphone_channels}." + f"Some of the provided channels {self.channels} are outside the possible channel range of the microphone {actual_channels}." ) else: - self.channels = np.arange(1, actual_max_microphone_channels + 1) + self.channels = actual_channels # Get channels index instead of number for slicing self.channels_index = np.array(self.channels) - 1 diff --git a/tests/microphones/test_portaudio.py b/tests/microphones/test_portaudio.py index d10bb41e3..efb52e24d 100644 --- a/tests/microphones/test_portaudio.py +++ b/tests/microphones/test_portaudio.py @@ -93,6 +93,16 @@ class TestPortAudioMicrophoneDeviceValidation(unittest.TestCase): self.default_config = self._create_config(kind="input") + def test_find_microphones(self): + microphones = PortAudioMicrophone.find_microphones(sounddevice_sdk=self.test_sdk) + + for microphone in microphones: + self.assertIsInstance(microphone["index"], int) + self.assertIsInstance(microphone["name"], str) + self.assertIsInstance(microphone["sample_rate"], int) + self.assertIsInstance(microphone["channels"], np.ndarray) + self.assertGreater(len(microphone["channels"]), 0) + def test_init_defaults(self): microphone = PortAudioMicrophone(self.default_config, sounddevice_sdk=self.test_sdk) @@ -153,6 +163,15 @@ class TestPortAudioMicrophoneDeviceValidation(unittest.TestCase): with self.assertRaises(RuntimeError): microphone.connect() + def test_connect_float_sample_rate(self): + config = deepcopy(self.default_config) + config.sample_rate = int(config.sample_rate) - 0.5 + microphone = PortAudioMicrophone(config, sounddevice_sdk=self.test_sdk) + microphone.connect() + + self.assertIsInstance(microphone.sample_rate, int) + self.assertEqual(microphone.sample_rate, int(config.sample_rate)) + def test_connect_lower_sample_rate(self): config = deepcopy(self.default_config) config.sample_rate = 1000 # Lowest possible sample rate @@ -291,7 +310,7 @@ class TestPortAudioMicrophoneDeviceValidation(unittest.TestCase): self.assertAlmostEqual( data.shape[0], RECORDING_DURATION * self.default_config.sample_rate, - delta=self.default_config.sample_rate * device_info["default_low_input_latency"], + delta=2 * self.default_config.sample_rate * device_info["default_low_input_latency"], ) def test_writing_success(self): @@ -311,7 +330,7 @@ class TestPortAudioMicrophoneDeviceValidation(unittest.TestCase): self.assertAlmostEqual( data.shape[0], RECORDING_DURATION * self.default_config.sample_rate, - delta=self.default_config.sample_rate * device_info["default_low_input_latency"], + delta=2 * self.default_config.sample_rate * device_info["default_low_input_latency"], ) def test_read_while_writing(self): @@ -330,12 +349,12 @@ class TestPortAudioMicrophoneDeviceValidation(unittest.TestCase): self.assertAlmostEqual( writing_data.shape[0], RECORDING_DURATION * self.default_config.sample_rate, - delta=self.default_config.sample_rate * device_info["default_low_input_latency"], + delta=2 * self.default_config.sample_rate * device_info["default_low_input_latency"], ) self.assertAlmostEqual( read_data.shape[0], RECORDING_DURATION * self.default_config.sample_rate, - delta=self.default_config.sample_rate * device_info["default_low_input_latency"], + delta=2 * self.default_config.sample_rate * device_info["default_low_input_latency"], ) def test_async_start_recording(self):