mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
test(Microphone): adding missing testsand support for float sample rate
This commit is contained in:
@@ -25,7 +25,6 @@ from threading import Barrier, Event, Event as thread_Event, Thread
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import sounddevice as sd
|
||||
from soundfile import SoundFile
|
||||
|
||||
from lerobot.utils.errors import (
|
||||
@@ -111,10 +110,15 @@ class PortAudioMicrophone(Microphone):
|
||||
return self.write_thread is not None and self.write_thread.is_alive()
|
||||
|
||||
@staticmethod
|
||||
def find_microphones(sounddevice_sdk: ISounddeviceSDK = None) -> list[dict[str, Any]]:
|
||||
def find_microphones(
|
||||
device: int | str | None = None, sounddevice_sdk: ISounddeviceSDK = None
|
||||
) -> list[dict[str, Any]] | dict[str, Any]:
|
||||
"""
|
||||
Detects available microphones connected to the system.
|
||||
|
||||
Args:
|
||||
device: The device to find microphones for. If None, all microphones are found.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: A list of dictionaries,
|
||||
where each dictionary contains information about a detected microphone : index, name, sample rate, channels.
|
||||
@@ -126,15 +130,29 @@ class PortAudioMicrophone(Microphone):
|
||||
found_microphones_info = []
|
||||
|
||||
devices = sounddevice_sdk.query_devices()
|
||||
for device in devices:
|
||||
if device["max_input_channels"] > 0:
|
||||
for d in devices:
|
||||
if d["max_input_channels"] > 0:
|
||||
microphone_info = {
|
||||
"index": device["index"],
|
||||
"name": device["name"],
|
||||
"sample_rate": int(device["default_samplerate"]),
|
||||
"channels": list(range(1, device["max_input_channels"] + 1)),
|
||||
"index": d["index"],
|
||||
"name": d["name"],
|
||||
"sample_rate": int(d["default_samplerate"]),
|
||||
"channels": np.arange(1, d["max_input_channels"] + 1),
|
||||
}
|
||||
found_microphones_info.append(microphone_info)
|
||||
|
||||
if device is None or (
|
||||
(isinstance(device, int) and d["index"] == device)
|
||||
or (isinstance(device, str) and d["name"] == device)
|
||||
):
|
||||
found_microphones_info.append(microphone_info)
|
||||
|
||||
if device is not None:
|
||||
if len(found_microphones_info) == 0:
|
||||
raise RuntimeError(f"No microphone found for device {device}")
|
||||
else:
|
||||
return found_microphones_info[0]
|
||||
|
||||
if len(found_microphones_info) == 0:
|
||||
logging.warning("No microphone found !")
|
||||
|
||||
return found_microphones_info
|
||||
|
||||
@@ -160,24 +178,28 @@ class PortAudioMicrophone(Microphone):
|
||||
def _validate_microphone_index(self) -> None:
|
||||
""" "Validates the microphone index against available devices by checking if it has at least one input channel."""
|
||||
|
||||
is_index_input = (
|
||||
self.microphone_index >= 0
|
||||
and self.sounddevice_sdk.query_devices(self.microphone_index)["max_input_channels"] > 0
|
||||
)
|
||||
|
||||
if not is_index_input:
|
||||
found_microphones_info = self.find_microphones()
|
||||
available_microphones = {m["name"]: m["index"] for m in found_microphones_info}
|
||||
try:
|
||||
PortAudioMicrophone.find_microphones(self.microphone_index, self.sounddevice_sdk)
|
||||
except RuntimeError as e:
|
||||
raise RuntimeError(
|
||||
f"Microphone index {self.microphone_index} does not match an input device (microphone). Available input devices : {available_microphones}"
|
||||
)
|
||||
f"{e}. Available microphones: {PortAudioMicrophone.find_microphones(sounddevice_sdk=self.sounddevice_sdk)}"
|
||||
) from e
|
||||
|
||||
def _validate_sample_rate(self) -> None:
|
||||
"""Validates the sample rate against the actual microphone's default sample rate."""
|
||||
|
||||
actual_sample_rate = self.sounddevice_sdk.query_devices(self.microphone_index)["default_samplerate"]
|
||||
actual_sample_rate = PortAudioMicrophone.find_microphones(
|
||||
self.microphone_index, self.sounddevice_sdk
|
||||
)["sample_rate"]
|
||||
|
||||
if self.sample_rate is not None:
|
||||
try:
|
||||
self.sample_rate = int(self.sample_rate)
|
||||
except ValueError as e:
|
||||
raise RuntimeError(
|
||||
f"Cannot convert the provided sample rate ({self.sample_rate} Hz) to an integer."
|
||||
) from e
|
||||
|
||||
if self.sample_rate > actual_sample_rate or self.sample_rate < 1000:
|
||||
raise RuntimeError(
|
||||
f"Provided sample rate {self.sample_rate} is either too low or too high compared to the sample rate of the microphone {actual_sample_rate}."
|
||||
@@ -187,22 +209,25 @@ class PortAudioMicrophone(Microphone):
|
||||
logging.warning(
|
||||
"Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted."
|
||||
)
|
||||
self.sample_rate = int(self.sample_rate)
|
||||
else:
|
||||
self.sample_rate = int(actual_sample_rate)
|
||||
self.sample_rate = actual_sample_rate
|
||||
|
||||
def _validate_channels(self) -> None:
|
||||
"""Validates the channels against the actual microphone's maximum input channels."""
|
||||
|
||||
actual_max_microphone_channels = sd.query_devices(self.microphone_index)["max_input_channels"]
|
||||
actual_channels = PortAudioMicrophone.find_microphones(self.microphone_index, self.sounddevice_sdk)[
|
||||
"channels"
|
||||
]
|
||||
|
||||
if self.channels is not None and len(self.channels) > 0:
|
||||
if any(c > actual_max_microphone_channels or c <= 0 for c in self.channels):
|
||||
if any(
|
||||
all(c > actual_channels) or c <= 0 or not isinstance(c, np.integer) for c in self.channels
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"Some of the provided channels {self.channels} are outside the maximum channel range of the microphone {actual_max_microphone_channels}."
|
||||
f"Some of the provided channels {self.channels} are outside the possible channel range of the microphone {actual_channels}."
|
||||
)
|
||||
else:
|
||||
self.channels = np.arange(1, actual_max_microphone_channels + 1)
|
||||
self.channels = actual_channels
|
||||
|
||||
# Get channels index instead of number for slicing
|
||||
self.channels_index = np.array(self.channels) - 1
|
||||
|
||||
@@ -93,6 +93,16 @@ class TestPortAudioMicrophoneDeviceValidation(unittest.TestCase):
|
||||
|
||||
self.default_config = self._create_config(kind="input")
|
||||
|
||||
def test_find_microphones(self):
|
||||
microphones = PortAudioMicrophone.find_microphones(sounddevice_sdk=self.test_sdk)
|
||||
|
||||
for microphone in microphones:
|
||||
self.assertIsInstance(microphone["index"], int)
|
||||
self.assertIsInstance(microphone["name"], str)
|
||||
self.assertIsInstance(microphone["sample_rate"], int)
|
||||
self.assertIsInstance(microphone["channels"], np.ndarray)
|
||||
self.assertGreater(len(microphone["channels"]), 0)
|
||||
|
||||
def test_init_defaults(self):
|
||||
microphone = PortAudioMicrophone(self.default_config, sounddevice_sdk=self.test_sdk)
|
||||
|
||||
@@ -153,6 +163,15 @@ class TestPortAudioMicrophoneDeviceValidation(unittest.TestCase):
|
||||
with self.assertRaises(RuntimeError):
|
||||
microphone.connect()
|
||||
|
||||
def test_connect_float_sample_rate(self):
|
||||
config = deepcopy(self.default_config)
|
||||
config.sample_rate = int(config.sample_rate) - 0.5
|
||||
microphone = PortAudioMicrophone(config, sounddevice_sdk=self.test_sdk)
|
||||
microphone.connect()
|
||||
|
||||
self.assertIsInstance(microphone.sample_rate, int)
|
||||
self.assertEqual(microphone.sample_rate, int(config.sample_rate))
|
||||
|
||||
def test_connect_lower_sample_rate(self):
|
||||
config = deepcopy(self.default_config)
|
||||
config.sample_rate = 1000 # Lowest possible sample rate
|
||||
@@ -291,7 +310,7 @@ class TestPortAudioMicrophoneDeviceValidation(unittest.TestCase):
|
||||
self.assertAlmostEqual(
|
||||
data.shape[0],
|
||||
RECORDING_DURATION * self.default_config.sample_rate,
|
||||
delta=self.default_config.sample_rate * device_info["default_low_input_latency"],
|
||||
delta=2 * self.default_config.sample_rate * device_info["default_low_input_latency"],
|
||||
)
|
||||
|
||||
def test_writing_success(self):
|
||||
@@ -311,7 +330,7 @@ class TestPortAudioMicrophoneDeviceValidation(unittest.TestCase):
|
||||
self.assertAlmostEqual(
|
||||
data.shape[0],
|
||||
RECORDING_DURATION * self.default_config.sample_rate,
|
||||
delta=self.default_config.sample_rate * device_info["default_low_input_latency"],
|
||||
delta=2 * self.default_config.sample_rate * device_info["default_low_input_latency"],
|
||||
)
|
||||
|
||||
def test_read_while_writing(self):
|
||||
@@ -330,12 +349,12 @@ class TestPortAudioMicrophoneDeviceValidation(unittest.TestCase):
|
||||
self.assertAlmostEqual(
|
||||
writing_data.shape[0],
|
||||
RECORDING_DURATION * self.default_config.sample_rate,
|
||||
delta=self.default_config.sample_rate * device_info["default_low_input_latency"],
|
||||
delta=2 * self.default_config.sample_rate * device_info["default_low_input_latency"],
|
||||
)
|
||||
self.assertAlmostEqual(
|
||||
read_data.shape[0],
|
||||
RECORDING_DURATION * self.default_config.sample_rate,
|
||||
delta=self.default_config.sample_rate * device_info["default_low_input_latency"],
|
||||
delta=2 * self.default_config.sample_rate * device_info["default_low_input_latency"],
|
||||
)
|
||||
|
||||
def test_async_start_recording(self):
|
||||
|
||||
Reference in New Issue
Block a user