mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 13:40:00 +00:00
test(Microphone): adding missing testsand support for float sample rate
This commit is contained in:
@@ -25,7 +25,6 @@ from threading import Barrier, Event, Event as thread_Event, Thread
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import sounddevice as sd
|
|
||||||
from soundfile import SoundFile
|
from soundfile import SoundFile
|
||||||
|
|
||||||
from lerobot.utils.errors import (
|
from lerobot.utils.errors import (
|
||||||
@@ -111,10 +110,15 @@ class PortAudioMicrophone(Microphone):
|
|||||||
return self.write_thread is not None and self.write_thread.is_alive()
|
return self.write_thread is not None and self.write_thread.is_alive()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def find_microphones(sounddevice_sdk: ISounddeviceSDK = None) -> list[dict[str, Any]]:
|
def find_microphones(
|
||||||
|
device: int | str | None = None, sounddevice_sdk: ISounddeviceSDK = None
|
||||||
|
) -> list[dict[str, Any]] | dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Detects available microphones connected to the system.
|
Detects available microphones connected to the system.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device: The device to find microphones for. If None, all microphones are found.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Dict[str, Any]]: A list of dictionaries,
|
List[Dict[str, Any]]: A list of dictionaries,
|
||||||
where each dictionary contains information about a detected microphone : index, name, sample rate, channels.
|
where each dictionary contains information about a detected microphone : index, name, sample rate, channels.
|
||||||
@@ -126,15 +130,29 @@ class PortAudioMicrophone(Microphone):
|
|||||||
found_microphones_info = []
|
found_microphones_info = []
|
||||||
|
|
||||||
devices = sounddevice_sdk.query_devices()
|
devices = sounddevice_sdk.query_devices()
|
||||||
for device in devices:
|
for d in devices:
|
||||||
if device["max_input_channels"] > 0:
|
if d["max_input_channels"] > 0:
|
||||||
microphone_info = {
|
microphone_info = {
|
||||||
"index": device["index"],
|
"index": d["index"],
|
||||||
"name": device["name"],
|
"name": d["name"],
|
||||||
"sample_rate": int(device["default_samplerate"]),
|
"sample_rate": int(d["default_samplerate"]),
|
||||||
"channels": list(range(1, device["max_input_channels"] + 1)),
|
"channels": np.arange(1, d["max_input_channels"] + 1),
|
||||||
}
|
}
|
||||||
found_microphones_info.append(microphone_info)
|
|
||||||
|
if device is None or (
|
||||||
|
(isinstance(device, int) and d["index"] == device)
|
||||||
|
or (isinstance(device, str) and d["name"] == device)
|
||||||
|
):
|
||||||
|
found_microphones_info.append(microphone_info)
|
||||||
|
|
||||||
|
if device is not None:
|
||||||
|
if len(found_microphones_info) == 0:
|
||||||
|
raise RuntimeError(f"No microphone found for device {device}")
|
||||||
|
else:
|
||||||
|
return found_microphones_info[0]
|
||||||
|
|
||||||
|
if len(found_microphones_info) == 0:
|
||||||
|
logging.warning("No microphone found !")
|
||||||
|
|
||||||
return found_microphones_info
|
return found_microphones_info
|
||||||
|
|
||||||
@@ -160,24 +178,28 @@ class PortAudioMicrophone(Microphone):
|
|||||||
def _validate_microphone_index(self) -> None:
|
def _validate_microphone_index(self) -> None:
|
||||||
""" "Validates the microphone index against available devices by checking if it has at least one input channel."""
|
""" "Validates the microphone index against available devices by checking if it has at least one input channel."""
|
||||||
|
|
||||||
is_index_input = (
|
try:
|
||||||
self.microphone_index >= 0
|
PortAudioMicrophone.find_microphones(self.microphone_index, self.sounddevice_sdk)
|
||||||
and self.sounddevice_sdk.query_devices(self.microphone_index)["max_input_channels"] > 0
|
except RuntimeError as e:
|
||||||
)
|
|
||||||
|
|
||||||
if not is_index_input:
|
|
||||||
found_microphones_info = self.find_microphones()
|
|
||||||
available_microphones = {m["name"]: m["index"] for m in found_microphones_info}
|
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Microphone index {self.microphone_index} does not match an input device (microphone). Available input devices : {available_microphones}"
|
f"{e}. Available microphones: {PortAudioMicrophone.find_microphones(sounddevice_sdk=self.sounddevice_sdk)}"
|
||||||
)
|
) from e
|
||||||
|
|
||||||
def _validate_sample_rate(self) -> None:
|
def _validate_sample_rate(self) -> None:
|
||||||
"""Validates the sample rate against the actual microphone's default sample rate."""
|
"""Validates the sample rate against the actual microphone's default sample rate."""
|
||||||
|
|
||||||
actual_sample_rate = self.sounddevice_sdk.query_devices(self.microphone_index)["default_samplerate"]
|
actual_sample_rate = PortAudioMicrophone.find_microphones(
|
||||||
|
self.microphone_index, self.sounddevice_sdk
|
||||||
|
)["sample_rate"]
|
||||||
|
|
||||||
if self.sample_rate is not None:
|
if self.sample_rate is not None:
|
||||||
|
try:
|
||||||
|
self.sample_rate = int(self.sample_rate)
|
||||||
|
except ValueError as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot convert the provided sample rate ({self.sample_rate} Hz) to an integer."
|
||||||
|
) from e
|
||||||
|
|
||||||
if self.sample_rate > actual_sample_rate or self.sample_rate < 1000:
|
if self.sample_rate > actual_sample_rate or self.sample_rate < 1000:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Provided sample rate {self.sample_rate} is either too low or too high compared to the sample rate of the microphone {actual_sample_rate}."
|
f"Provided sample rate {self.sample_rate} is either too low or too high compared to the sample rate of the microphone {actual_sample_rate}."
|
||||||
@@ -187,22 +209,25 @@ class PortAudioMicrophone(Microphone):
|
|||||||
logging.warning(
|
logging.warning(
|
||||||
"Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted."
|
"Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted."
|
||||||
)
|
)
|
||||||
self.sample_rate = int(self.sample_rate)
|
|
||||||
else:
|
else:
|
||||||
self.sample_rate = int(actual_sample_rate)
|
self.sample_rate = actual_sample_rate
|
||||||
|
|
||||||
def _validate_channels(self) -> None:
|
def _validate_channels(self) -> None:
|
||||||
"""Validates the channels against the actual microphone's maximum input channels."""
|
"""Validates the channels against the actual microphone's maximum input channels."""
|
||||||
|
|
||||||
actual_max_microphone_channels = sd.query_devices(self.microphone_index)["max_input_channels"]
|
actual_channels = PortAudioMicrophone.find_microphones(self.microphone_index, self.sounddevice_sdk)[
|
||||||
|
"channels"
|
||||||
|
]
|
||||||
|
|
||||||
if self.channels is not None and len(self.channels) > 0:
|
if self.channels is not None and len(self.channels) > 0:
|
||||||
if any(c > actual_max_microphone_channels or c <= 0 for c in self.channels):
|
if any(
|
||||||
|
all(c > actual_channels) or c <= 0 or not isinstance(c, np.integer) for c in self.channels
|
||||||
|
):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Some of the provided channels {self.channels} are outside the maximum channel range of the microphone {actual_max_microphone_channels}."
|
f"Some of the provided channels {self.channels} are outside the possible channel range of the microphone {actual_channels}."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.channels = np.arange(1, actual_max_microphone_channels + 1)
|
self.channels = actual_channels
|
||||||
|
|
||||||
# Get channels index instead of number for slicing
|
# Get channels index instead of number for slicing
|
||||||
self.channels_index = np.array(self.channels) - 1
|
self.channels_index = np.array(self.channels) - 1
|
||||||
|
|||||||
@@ -93,6 +93,16 @@ class TestPortAudioMicrophoneDeviceValidation(unittest.TestCase):
|
|||||||
|
|
||||||
self.default_config = self._create_config(kind="input")
|
self.default_config = self._create_config(kind="input")
|
||||||
|
|
||||||
|
def test_find_microphones(self):
|
||||||
|
microphones = PortAudioMicrophone.find_microphones(sounddevice_sdk=self.test_sdk)
|
||||||
|
|
||||||
|
for microphone in microphones:
|
||||||
|
self.assertIsInstance(microphone["index"], int)
|
||||||
|
self.assertIsInstance(microphone["name"], str)
|
||||||
|
self.assertIsInstance(microphone["sample_rate"], int)
|
||||||
|
self.assertIsInstance(microphone["channels"], np.ndarray)
|
||||||
|
self.assertGreater(len(microphone["channels"]), 0)
|
||||||
|
|
||||||
def test_init_defaults(self):
|
def test_init_defaults(self):
|
||||||
microphone = PortAudioMicrophone(self.default_config, sounddevice_sdk=self.test_sdk)
|
microphone = PortAudioMicrophone(self.default_config, sounddevice_sdk=self.test_sdk)
|
||||||
|
|
||||||
@@ -153,6 +163,15 @@ class TestPortAudioMicrophoneDeviceValidation(unittest.TestCase):
|
|||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
microphone.connect()
|
microphone.connect()
|
||||||
|
|
||||||
|
def test_connect_float_sample_rate(self):
|
||||||
|
config = deepcopy(self.default_config)
|
||||||
|
config.sample_rate = int(config.sample_rate) - 0.5
|
||||||
|
microphone = PortAudioMicrophone(config, sounddevice_sdk=self.test_sdk)
|
||||||
|
microphone.connect()
|
||||||
|
|
||||||
|
self.assertIsInstance(microphone.sample_rate, int)
|
||||||
|
self.assertEqual(microphone.sample_rate, int(config.sample_rate))
|
||||||
|
|
||||||
def test_connect_lower_sample_rate(self):
|
def test_connect_lower_sample_rate(self):
|
||||||
config = deepcopy(self.default_config)
|
config = deepcopy(self.default_config)
|
||||||
config.sample_rate = 1000 # Lowest possible sample rate
|
config.sample_rate = 1000 # Lowest possible sample rate
|
||||||
@@ -291,7 +310,7 @@ class TestPortAudioMicrophoneDeviceValidation(unittest.TestCase):
|
|||||||
self.assertAlmostEqual(
|
self.assertAlmostEqual(
|
||||||
data.shape[0],
|
data.shape[0],
|
||||||
RECORDING_DURATION * self.default_config.sample_rate,
|
RECORDING_DURATION * self.default_config.sample_rate,
|
||||||
delta=self.default_config.sample_rate * device_info["default_low_input_latency"],
|
delta=2 * self.default_config.sample_rate * device_info["default_low_input_latency"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_writing_success(self):
|
def test_writing_success(self):
|
||||||
@@ -311,7 +330,7 @@ class TestPortAudioMicrophoneDeviceValidation(unittest.TestCase):
|
|||||||
self.assertAlmostEqual(
|
self.assertAlmostEqual(
|
||||||
data.shape[0],
|
data.shape[0],
|
||||||
RECORDING_DURATION * self.default_config.sample_rate,
|
RECORDING_DURATION * self.default_config.sample_rate,
|
||||||
delta=self.default_config.sample_rate * device_info["default_low_input_latency"],
|
delta=2 * self.default_config.sample_rate * device_info["default_low_input_latency"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_read_while_writing(self):
|
def test_read_while_writing(self):
|
||||||
@@ -330,12 +349,12 @@ class TestPortAudioMicrophoneDeviceValidation(unittest.TestCase):
|
|||||||
self.assertAlmostEqual(
|
self.assertAlmostEqual(
|
||||||
writing_data.shape[0],
|
writing_data.shape[0],
|
||||||
RECORDING_DURATION * self.default_config.sample_rate,
|
RECORDING_DURATION * self.default_config.sample_rate,
|
||||||
delta=self.default_config.sample_rate * device_info["default_low_input_latency"],
|
delta=2 * self.default_config.sample_rate * device_info["default_low_input_latency"],
|
||||||
)
|
)
|
||||||
self.assertAlmostEqual(
|
self.assertAlmostEqual(
|
||||||
read_data.shape[0],
|
read_data.shape[0],
|
||||||
RECORDING_DURATION * self.default_config.sample_rate,
|
RECORDING_DURATION * self.default_config.sample_rate,
|
||||||
delta=self.default_config.sample_rate * device_info["default_low_input_latency"],
|
delta=2 * self.default_config.sample_rate * device_info["default_low_input_latency"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_async_start_recording(self):
|
def test_async_start_recording(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user