diff --git a/src/lerobot/microphones/portaudio/microphone_portaudio.py b/src/lerobot/microphones/portaudio/microphone_portaudio.py index 408dbf293..1fca952ab 100644 --- a/src/lerobot/microphones/portaudio/microphone_portaudio.py +++ b/src/lerobot/microphones/portaudio/microphone_portaudio.py @@ -31,13 +31,14 @@ from typing import Any import numpy as np from soundfile import SoundFile +from lerobot.microphones.portaudio.interface_sounddevice_sdk import ISounddeviceSDK, SounddeviceSDKAdapter from lerobot.utils.errors import ( DeviceAlreadyConnectedError, DeviceAlreadyRecordingError, DeviceNotConnectedError, DeviceNotRecordingError, ) -from lerobot.microphones.portaudio.interface_sounddevice_sdk import ISounddeviceSDK, SounddeviceSDKAdapter +from lerobot.utils.shared_array import SharedArray from ..microphone import Microphone from .configuration_portaudio import PortAudioMicrophoneConfig @@ -93,10 +94,13 @@ class PortAudioMicrophone(Microphone): self.record_is_started_event = process_Event() self.audio_callback_start_event = process_Event() - # Process-safe concurrent queues to store the written/read audio + # Process-safe concurrent queue to send audio from the recording process to the writing process/thread + # TODO(CarolinePascal): replace by a Pipe (more efficient !) self.write_queue = process_Queue() - self.read_queue = process_Queue() + # SharedArray to store audio from the recording process. + self.read_shared_array = None + self.local_read_shared_array = None # Thread/Process to handle data writing in a separate thread/process (safely) self.write_thread = None self.write_stop_event = None @@ -246,9 +250,13 @@ class PortAudioMicrophone(Microphone): self._configure_capture_settings() - # Create or reset queues + # Create or reset queue and shared array + self.read_shared_array = SharedArray( + shape=(self.sample_rate * 10, len(self.channels)), + dtype=np.dtype("float32"), + ) + self.local_read_shared_array = self.read_shared_array.get_local_array() self.write_queue = process_Queue() - self.read_queue = process_Queue() # Reset events self.record_start_event.clear() @@ -271,7 +279,7 @@ class PortAudioMicrophone(Microphone): self.record_is_started_event, self.audio_callback_start_event, self.write_queue, - self.read_queue, + self.read_shared_array, self.sounddevice_sdk, ), ) @@ -297,7 +305,7 @@ class PortAudioMicrophone(Microphone): self.stop_recording() self.record_close_event.set() - self.read_queue.close() + self.read_shared_array.delete() self.write_queue.close() self.record_process.join() @@ -310,16 +318,7 @@ class PortAudioMicrophone(Microphone): """ Thread/Process-safe callback to read available audio data """ - audio_readings = np.empty((0, len(self.channels))) - - while True: - try: - audio_readings = np.concatenate((audio_readings, self.read_queue.get_nowait()), axis=0) - self.read_queue.task_done() - except Empty: - break - - return audio_readings + return self.read_shared_array.read(self.local_read_shared_array, flush=True) def read(self) -> np.ndarray: """ @@ -353,7 +352,7 @@ class PortAudioMicrophone(Microphone): record_is_started_event, audio_callback_start_event, write_queue, - read_queue, + read_shared_array, sounddevice_sdk, ) -> None: """ @@ -361,6 +360,7 @@ class PortAudioMicrophone(Microphone): """ channels_index = np.array(channels) - 1 + local_read_shared_array = read_shared_array.get_local_array() def audio_callback(indata, frames, timestamp, status) -> None: """ @@ -370,7 +370,7 @@ class PortAudioMicrophone(Microphone): logger.warning(status) if audio_callback_start_event.is_set(): write_queue.put_nowait(indata[:, channels_index]) - read_queue.put_nowait(indata[:, channels_index]) + read_shared_array.write(local_read_shared_array, indata[:, channels_index]) # Create the audio stream # InputStream must be instantiated in the process as it is not pickable. @@ -413,8 +413,8 @@ class PortAudioMicrophone(Microphone): if self.is_recording: raise DeviceAlreadyRecordingError(f"Microphone {self.microphone_index} is already recording.") - # Reset queues - self._clear_queue(self.read_queue) + # Reset queue and shared memory + self.read_shared_array.reset() self._clear_queue(self.write_queue) # Reset stop event @@ -491,10 +491,7 @@ class PortAudioMicrophone(Microphone): self.record_start_event.clear() # Ensures the audio stream is not started again ! self.record_stop_event.set() - while self.is_recording: - time.sleep(0.01) - - self._clear_queue(self.read_queue, join_queue=True) + self.read_shared_array.reset() self._clear_queue(self.write_queue, join_queue=True) if self.is_writing: diff --git a/src/lerobot/utils/shared_array.py b/src/lerobot/utils/shared_array.py new file mode 100644 index 000000000..20520c0d7 --- /dev/null +++ b/src/lerobot/utils/shared_array.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from multiprocessing import Lock, Value, shared_memory + +import numpy as np + + +class SharedArray: + """ + A SharedArray is a numpy array shared between multiple processes in a shared_memory object. + - Data is written to the array using the `write` method, which appends data to the array. + - Data is read from the array (and eventually flushed) using the `read` method, which copies the _whole_ array. + + SharedArray offers quasi-instantaneous array-wide read and flush capabilities in comparison to Queues, but has a limited size defined at initialization. + + Example: + _Main_process_ + shared_array = SharedArray(shape=(10, 10), dtype=np.dtype("float32")) + local_array = shared_array.get_local_array() + shared_array.write(local_array, np.array([[1, 2, 3], [4, 5, 6]])) + + _Child_process_ + local_array = shared_array.get_local_array() + data = shared_array.read(local_array, flush=True) + """ + + def __init__(self, shape: tuple[int], dtype: np.dtype | str): + """ + Initialize a SharedArray. + + Args: + shape: The shape of the shared array. + dtype: The dtype of the shared array. + """ + self.shape = shape + self.dtype = dtype + + self.shared_memory = shared_memory.SharedMemory( + create=True, size=np.prod(shape) * np.dtype(dtype).itemsize + ) + self.read_index = Value("i", 0) + self.lock = Lock() + + def get_local_array(self) -> np.ndarray: + """ + Get a process-local instance of the shared array. + + Returns: + A process-local instance of the shared array. + """ + return np.ndarray(self.shape, dtype=np.dtype(self.dtype), buffer=self.shared_memory.buf) + + def delete(self): + """ + Delete the shared array. + """ + self.shared_memory.close() + self.shared_memory.unlink() + + def write(self, local_array: np.ndarray, data: np.ndarray): + """ + Write data to the shared array. + + Args: + local_array: The process-local instance of the shared array to write to. + data: The data to write to the shared array. + """ + with self.lock: + local_array[self.read_index.value : self.read_index.value + len(data)] = data + self.read_index.value += len(data) + + def read(self, local_array: np.ndarray, flush: bool = True) -> np.ndarray: + """ + Read data from the shared array. + + Args: + local_array: The process-local instance of the shared array to read from. + flush: Whether to flush the shared array after reading. + """ + with self.lock: + data = np.copy(local_array[: self.read_index.value]) + if flush: + self.read_index.value = 0 + return data + + def reset(self): + """ + Reset the read index to 0. + """ + with self.lock: + self.read_index.value = 0 diff --git a/tests/microphones/test_shared_array.py b/tests/microphones/test_shared_array.py new file mode 100644 index 000000000..8f218193a --- /dev/null +++ b/tests/microphones/test_shared_array.py @@ -0,0 +1,508 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import multiprocessing +import time +from multiprocessing import Event, Process, Queue + +import numpy as np +import pytest + +from lerobot.utils.shared_array import SharedArray + + +def writer_process(shared_array, data_queue, stop_event, barrier, process_id): + """Writer process that continuously writes data to shared array.""" + local_array = shared_array.get_local_array() + + # Wait for all processes to be ready + barrier.wait() + + write_count = 0 + while not stop_event.is_set() and write_count < 10: + # Generate unique data for this process and write iteration + data = np.full((5, 2), process_id * 100 + write_count, dtype=np.float32) + + try: + shared_array.write(local_array, data) + data_queue.put(f"writer_{process_id}_wrote_{write_count}") + write_count += 1 + time.sleep(0.01) # Small delay to allow race conditions + except IndexError: + # Array is full, stop writing + break + + +def reader_process(shared_array, data_queue, stop_event, barrier, process_id): + """Reader process that continuously reads data from shared array.""" + local_array = shared_array.get_local_array() + + # Wait for all processes to be ready + barrier.wait() + + read_count = 0 + while not stop_event.is_set() and read_count < 5: + time.sleep(0.02) # Allow some writes to accumulate + + data = shared_array.read(local_array, flush=True) + data_queue.put(f"reader_{process_id}_read_{len(data)}_items") + read_count += 1 + + +def stress_writer_process(shared_array, data_queue, stop_event, barrier, process_id): + """High-frequency writer process for stress testing.""" + local_array = shared_array.get_local_array() + + barrier.wait() + + write_count = 0 + while not stop_event.is_set() and write_count < 50: + # Write single row at a time for more frequent operations + data = np.array([[process_id, write_count]], dtype=np.float32) + + try: + shared_array.write(local_array, data) + write_count += 1 + # No sleep - stress test + except IndexError: + break + + data_queue.put(f"stress_writer_{process_id}_completed_{write_count}") + + +# Basic functionality tests + + +def test_shared_array_creation(): + """Test basic SharedArray creation and properties.""" + shape = (100, 4) + dtype = np.float32 + + shared_array = SharedArray(shape=shape, dtype=dtype) + + assert shared_array.shape == shape + assert shared_array.dtype == dtype + assert shared_array.read_index.value == 0 + + # Clean up + shared_array.delete() + + +def test_local_array_access(): + """Test getting local array instances.""" + shape = (50, 2) + shared_array = SharedArray(shape=shape, dtype=np.float32) + + local_array = shared_array.get_local_array() + + assert local_array.shape == shape + assert local_array.dtype == np.float32 + assert isinstance(local_array, np.ndarray) + + # Test that we can get multiple local array instances + local_array2 = shared_array.get_local_array() + assert local_array2.shape == shape + + shared_array.delete() + + +def test_write_and_read_single_process(): + """Test basic write and read operations in single process.""" + shape = (20, 3) + shared_array = SharedArray(shape=shape, dtype=np.float32) + local_array = shared_array.get_local_array() + + # Write some data + data1 = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) + shared_array.write(local_array, data1) + + assert shared_array.read_index.value == 2 + + # Write more data + data2 = np.array([[7, 8, 9]], dtype=np.float32) + shared_array.write(local_array, data2) + + assert shared_array.read_index.value == 3 + + # Read all data + read_data = shared_array.read(local_array, flush=False) + expected = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32) + np.testing.assert_array_equal(read_data, expected) + + # Read with flush + read_data_flush = shared_array.read(local_array, flush=True) + np.testing.assert_array_equal(read_data_flush, expected) + assert shared_array.read_index.value == 0 + + shared_array.delete() + + +def test_array_overflow(): + """Test behavior when writing more data than array capacity.""" + shape = (5, 2) # Small array + shared_array = SharedArray(shape=shape, dtype=np.float32) + local_array = shared_array.get_local_array() + + # Fill the array + data = np.ones((5, 2), dtype=np.float32) + shared_array.write(local_array, data) + + # Try to write more data - should raise IndexError + with pytest.raises(ValueError): + extra_data = np.ones((2, 2), dtype=np.float32) + shared_array.write(local_array, extra_data) + + shared_array.delete() + + +def test_reset_functionality(): + """Test the reset method.""" + shape = (10, 2) + shared_array = SharedArray(shape=shape, dtype=np.float32) + local_array = shared_array.get_local_array() + + # Write some data + data = np.ones((3, 2), dtype=np.float32) + shared_array.write(local_array, data) + assert shared_array.read_index.value == 3 + + # Reset + shared_array.reset() + assert shared_array.read_index.value == 0 + + # Read should return empty array + read_data = shared_array.read(local_array, flush=False) + assert len(read_data) == 0 + + shared_array.delete() + + +# Multi-process tests + + +def test_single_writer_single_reader(): + """Test basic writer-reader scenario with one process each.""" + shape = (100, 2) + shared_array = SharedArray(shape=shape, dtype=np.float32) + + data_queue = Queue() + stop_event = Event() + barrier = multiprocessing.Barrier(2) # Writer + reader + + # Start writer process + writer = Process(target=writer_process, args=(shared_array, data_queue, stop_event, barrier, 1)) + + # Start reader process + reader = Process(target=reader_process, args=(shared_array, data_queue, stop_event, barrier, 1)) + + writer.start() + reader.start() + + # Let them run for a bit + time.sleep(0.5) + stop_event.set() + + # Wait for completion + writer.join(timeout=2.0) + reader.join(timeout=2.0) + + # Verify both processes completed + assert not writer.is_alive() + assert not reader.is_alive() + + # Check that we got messages from both processes + messages = [] + while not data_queue.empty(): + messages.append(data_queue.get()) + + writer_messages = [msg for msg in messages if msg.startswith("writer_")] + reader_messages = [msg for msg in messages if msg.startswith("reader_")] + + assert len(writer_messages) > 0 + assert len(reader_messages) > 0 + + shared_array.delete() + + +def test_multiple_writers_single_reader(): + """Test multiple writers with single reader - check for race conditions.""" + shape = (200, 2) + shared_array = SharedArray(shape=shape, dtype=np.float32) + + data_queue = Queue() + stop_event = Event() + num_writers = 3 + barrier = multiprocessing.Barrier(num_writers + 1) # Writers + reader + + processes = [] + + # Start multiple writer processes + for i in range(num_writers): + writer = Process(target=writer_process, args=(shared_array, data_queue, stop_event, barrier, i + 1)) + processes.append(writer) + writer.start() + + # Start reader process + reader = Process(target=reader_process, args=(shared_array, data_queue, stop_event, barrier, 1)) + processes.append(reader) + reader.start() + + # Let them run + time.sleep(1.0) + stop_event.set() + + # Wait for all processes + for process in processes: + process.join(timeout=3.0) + assert not process.is_alive() + + # Verify we got messages from all processes + messages = [] + while not data_queue.empty(): + messages.append(data_queue.get()) + + writer_messages = [msg for msg in messages if msg.startswith("writer_")] + reader_messages = [msg for msg in messages if msg.startswith("reader_")] + + # Should have messages from all writers + assert len(writer_messages) >= num_writers + assert len(reader_messages) > 0 + + shared_array.delete() + + +def test_data_integrity_with_concurrent_access(): + """Test that data integrity is maintained under concurrent access using standard reader/writer processes.""" + shape = (500, 2) # Use standard 2-column format + shared_array = SharedArray(shape=shape, dtype=np.float32) + + data_queue = Queue() + stop_event = Event() + barrier = multiprocessing.Barrier(3) # 2 writers + 1 reader + + # Start two writer processes + writer1 = Process(target=writer_process, args=(shared_array, data_queue, stop_event, barrier, 1)) + writer2 = Process(target=writer_process, args=(shared_array, data_queue, stop_event, barrier, 2)) + + # Start one reader process + reader = Process(target=reader_process, args=(shared_array, data_queue, stop_event, barrier, 1)) + + writer1.start() + writer2.start() + reader.start() + + # Let them run for integrity test duration + time.sleep(1.0) + stop_event.set() + + # Wait for completion + writer1.join(timeout=3.0) + writer2.join(timeout=3.0) + reader.join(timeout=3.0) + + # Verify all processes completed successfully + assert not writer1.is_alive() + assert not writer2.is_alive() + assert not reader.is_alive() + + # Verify data integrity by checking messages + messages = [] + while not data_queue.empty(): + messages.append(data_queue.get()) + + writer1_messages = [msg for msg in messages if "writer_1_wrote" in msg] + writer2_messages = [msg for msg in messages if "writer_2_wrote" in msg] + reader_messages = [msg for msg in messages if "reader_1_read" in msg] + + # Verify both writers wrote data + assert len(writer1_messages) > 0 + assert len(writer2_messages) > 0 + # Verify reader read data + assert len(reader_messages) > 0 + + # Verify the shared array is in a consistent state + local_array = shared_array.get_local_array() + final_data = shared_array.read(local_array, flush=False) + + # Should have some data written by the writers + assert len(final_data) >= 0 # Could be empty if reader flushed everything + # Should not exceed array capacity + assert len(final_data) <= shape[0] + + # If there's data, verify it contains the expected writer signatures + if len(final_data) > 0: + # Data should contain values like 100, 101, 102... (writer 1) or 200, 201, 202... (writer 2) + unique_values = np.unique(final_data.flatten()) + writer1_values = unique_values[(unique_values >= 100) & (unique_values < 200)] + writer2_values = unique_values[(unique_values >= 200) & (unique_values < 300)] + + # Should have data from at least one writer + assert len(writer1_values) > 0 or len(writer2_values) > 0 + + shared_array.delete() + + +def test_stress_test_high_frequency_operations(): + """Stress test with high frequency read/write operations.""" + shape = (1000, 2) + shared_array = SharedArray(shape=shape, dtype=np.float32) + + data_queue = Queue() + stop_event = Event() + num_writers = 4 + barrier = multiprocessing.Barrier(num_writers) + + processes = [] + + # Start multiple high-frequency writers + for i in range(num_writers): + writer = Process( + target=stress_writer_process, args=(shared_array, data_queue, stop_event, barrier, i + 1) + ) + processes.append(writer) + writer.start() + + # Let them run for stress test duration + time.sleep(0.5) + stop_event.set() + + # Wait for completion + for process in processes: + process.join(timeout=3.0) + assert not process.is_alive() + + # Verify all writers completed successfully + messages = [] + while not data_queue.empty(): + messages.append(data_queue.get()) + + completed_messages = [msg for msg in messages if "completed" in msg] + assert len(completed_messages) == num_writers + + # Verify the shared array is in a consistent state + local_array = shared_array.get_local_array() + final_data = shared_array.read(local_array, flush=False) + + # Should have some data written + assert len(final_data) > 0 + # Should not exceed array capacity + assert len(final_data) <= shape[0] + + shared_array.delete() + + +def test_concurrent_readers(): + """Test multiple concurrent readers with writers to ensure thread safety.""" + shape = (200, 2) + shared_array = SharedArray(shape=shape, dtype=np.float32) + + data_queue = Queue() + stop_event = Event() + num_readers = 3 + num_writers = 2 + barrier = multiprocessing.Barrier(num_readers + num_writers) + + processes = [] + + # Start multiple writer processes to generate data + for i in range(num_writers): + writer = Process(target=writer_process, args=(shared_array, data_queue, stop_event, barrier, i + 1)) + processes.append(writer) + writer.start() + + # Start multiple reader processes + for i in range(num_readers): + reader = Process(target=reader_process, args=(shared_array, data_queue, stop_event, barrier, i + 1)) + processes.append(reader) + reader.start() + + # Let them run to test concurrent access + time.sleep(1.0) + stop_event.set() + + # Wait for all processes to complete + for process in processes: + process.join(timeout=3.0) + assert not process.is_alive() + + # Verify all readers and writers completed + messages = [] + while not data_queue.empty(): + messages.append(data_queue.get()) + + reader_messages = [msg for msg in messages if msg.startswith("reader_")] + writer_messages = [msg for msg in messages if msg.startswith("writer_")] + + # Should have messages from all readers and writers + assert len(reader_messages) >= num_readers + assert len(writer_messages) >= num_writers + + # Verify different readers generated different messages (proving they ran concurrently) + reader_ids = set() + for msg in reader_messages: + # Extract reader ID from message like "reader_1_read_5_items" + parts = msg.split("_") + if len(parts) >= 2: + reader_ids.add(parts[1]) + + assert len(reader_ids) == num_readers # All readers should have participated + + shared_array.delete() + + +def test_edge_case_empty_reads(): + """Test reading from empty array and after flushes.""" + shape = (10, 2) + shared_array = SharedArray(shape=shape, dtype=np.float32) + local_array = shared_array.get_local_array() + + # Read from empty array + empty_data = shared_array.read(local_array, flush=False) + assert len(empty_data) == 0 + + # Write some data + data = np.ones((3, 2), dtype=np.float32) + shared_array.write(local_array, data) + + # Read with flush + read_data = shared_array.read(local_array, flush=True) + assert len(read_data) == 3 + + # Read again after flush - should be empty + empty_again = shared_array.read(local_array, flush=False) + assert len(empty_again) == 0 + + shared_array.delete() + + +def test_different_dtypes(): + """Test SharedArray with different numpy dtypes.""" + dtypes_to_test = [np.float32, np.float64, np.int32, np.int16] + + for dtype in dtypes_to_test: + shape = (20, 2) + shared_array = SharedArray(shape=shape, dtype=dtype) + local_array = shared_array.get_local_array() + + assert local_array.dtype == dtype + + # Write and read data of this dtype + data = np.ones((5, 2), dtype=dtype) + shared_array.write(local_array, data) + + read_data = shared_array.read(local_array, flush=True) + assert read_data.dtype == dtype + assert len(read_data) == 5 + + shared_array.delete()