mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
feat(shared array): removing queues copy and flush delays with a SharedArray inter-process communication
This commit is contained in:
@@ -31,13 +31,14 @@ from typing import Any
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from soundfile import SoundFile
|
from soundfile import SoundFile
|
||||||
|
|
||||||
|
from lerobot.microphones.portaudio.interface_sounddevice_sdk import ISounddeviceSDK, SounddeviceSDKAdapter
|
||||||
from lerobot.utils.errors import (
|
from lerobot.utils.errors import (
|
||||||
DeviceAlreadyConnectedError,
|
DeviceAlreadyConnectedError,
|
||||||
DeviceAlreadyRecordingError,
|
DeviceAlreadyRecordingError,
|
||||||
DeviceNotConnectedError,
|
DeviceNotConnectedError,
|
||||||
DeviceNotRecordingError,
|
DeviceNotRecordingError,
|
||||||
)
|
)
|
||||||
from lerobot.microphones.portaudio.interface_sounddevice_sdk import ISounddeviceSDK, SounddeviceSDKAdapter
|
from lerobot.utils.shared_array import SharedArray
|
||||||
|
|
||||||
from ..microphone import Microphone
|
from ..microphone import Microphone
|
||||||
from .configuration_portaudio import PortAudioMicrophoneConfig
|
from .configuration_portaudio import PortAudioMicrophoneConfig
|
||||||
@@ -93,10 +94,13 @@ class PortAudioMicrophone(Microphone):
|
|||||||
self.record_is_started_event = process_Event()
|
self.record_is_started_event = process_Event()
|
||||||
self.audio_callback_start_event = process_Event()
|
self.audio_callback_start_event = process_Event()
|
||||||
|
|
||||||
# Process-safe concurrent queues to store the written/read audio
|
# Process-safe concurrent queue to send audio from the recording process to the writing process/thread
|
||||||
|
# TODO(CarolinePascal): replace by a Pipe (more efficient !)
|
||||||
self.write_queue = process_Queue()
|
self.write_queue = process_Queue()
|
||||||
self.read_queue = process_Queue()
|
|
||||||
|
|
||||||
|
# SharedArray to store audio from the recording process.
|
||||||
|
self.read_shared_array = None
|
||||||
|
self.local_read_shared_array = None
|
||||||
# Thread/Process to handle data writing in a separate thread/process (safely)
|
# Thread/Process to handle data writing in a separate thread/process (safely)
|
||||||
self.write_thread = None
|
self.write_thread = None
|
||||||
self.write_stop_event = None
|
self.write_stop_event = None
|
||||||
@@ -246,9 +250,13 @@ class PortAudioMicrophone(Microphone):
|
|||||||
|
|
||||||
self._configure_capture_settings()
|
self._configure_capture_settings()
|
||||||
|
|
||||||
# Create or reset queues
|
# Create or reset queue and shared array
|
||||||
|
self.read_shared_array = SharedArray(
|
||||||
|
shape=(self.sample_rate * 10, len(self.channels)),
|
||||||
|
dtype=np.dtype("float32"),
|
||||||
|
)
|
||||||
|
self.local_read_shared_array = self.read_shared_array.get_local_array()
|
||||||
self.write_queue = process_Queue()
|
self.write_queue = process_Queue()
|
||||||
self.read_queue = process_Queue()
|
|
||||||
|
|
||||||
# Reset events
|
# Reset events
|
||||||
self.record_start_event.clear()
|
self.record_start_event.clear()
|
||||||
@@ -271,7 +279,7 @@ class PortAudioMicrophone(Microphone):
|
|||||||
self.record_is_started_event,
|
self.record_is_started_event,
|
||||||
self.audio_callback_start_event,
|
self.audio_callback_start_event,
|
||||||
self.write_queue,
|
self.write_queue,
|
||||||
self.read_queue,
|
self.read_shared_array,
|
||||||
self.sounddevice_sdk,
|
self.sounddevice_sdk,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -297,7 +305,7 @@ class PortAudioMicrophone(Microphone):
|
|||||||
self.stop_recording()
|
self.stop_recording()
|
||||||
|
|
||||||
self.record_close_event.set()
|
self.record_close_event.set()
|
||||||
self.read_queue.close()
|
self.read_shared_array.delete()
|
||||||
self.write_queue.close()
|
self.write_queue.close()
|
||||||
self.record_process.join()
|
self.record_process.join()
|
||||||
|
|
||||||
@@ -310,16 +318,7 @@ class PortAudioMicrophone(Microphone):
|
|||||||
"""
|
"""
|
||||||
Thread/Process-safe callback to read available audio data
|
Thread/Process-safe callback to read available audio data
|
||||||
"""
|
"""
|
||||||
audio_readings = np.empty((0, len(self.channels)))
|
return self.read_shared_array.read(self.local_read_shared_array, flush=True)
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
audio_readings = np.concatenate((audio_readings, self.read_queue.get_nowait()), axis=0)
|
|
||||||
self.read_queue.task_done()
|
|
||||||
except Empty:
|
|
||||||
break
|
|
||||||
|
|
||||||
return audio_readings
|
|
||||||
|
|
||||||
def read(self) -> np.ndarray:
|
def read(self) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
@@ -353,7 +352,7 @@ class PortAudioMicrophone(Microphone):
|
|||||||
record_is_started_event,
|
record_is_started_event,
|
||||||
audio_callback_start_event,
|
audio_callback_start_event,
|
||||||
write_queue,
|
write_queue,
|
||||||
read_queue,
|
read_shared_array,
|
||||||
sounddevice_sdk,
|
sounddevice_sdk,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -361,6 +360,7 @@ class PortAudioMicrophone(Microphone):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
channels_index = np.array(channels) - 1
|
channels_index = np.array(channels) - 1
|
||||||
|
local_read_shared_array = read_shared_array.get_local_array()
|
||||||
|
|
||||||
def audio_callback(indata, frames, timestamp, status) -> None:
|
def audio_callback(indata, frames, timestamp, status) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -370,7 +370,7 @@ class PortAudioMicrophone(Microphone):
|
|||||||
logger.warning(status)
|
logger.warning(status)
|
||||||
if audio_callback_start_event.is_set():
|
if audio_callback_start_event.is_set():
|
||||||
write_queue.put_nowait(indata[:, channels_index])
|
write_queue.put_nowait(indata[:, channels_index])
|
||||||
read_queue.put_nowait(indata[:, channels_index])
|
read_shared_array.write(local_read_shared_array, indata[:, channels_index])
|
||||||
|
|
||||||
# Create the audio stream
|
# Create the audio stream
|
||||||
# InputStream must be instantiated in the process as it is not pickable.
|
# InputStream must be instantiated in the process as it is not pickable.
|
||||||
@@ -413,8 +413,8 @@ class PortAudioMicrophone(Microphone):
|
|||||||
if self.is_recording:
|
if self.is_recording:
|
||||||
raise DeviceAlreadyRecordingError(f"Microphone {self.microphone_index} is already recording.")
|
raise DeviceAlreadyRecordingError(f"Microphone {self.microphone_index} is already recording.")
|
||||||
|
|
||||||
# Reset queues
|
# Reset queue and shared memory
|
||||||
self._clear_queue(self.read_queue)
|
self.read_shared_array.reset()
|
||||||
self._clear_queue(self.write_queue)
|
self._clear_queue(self.write_queue)
|
||||||
|
|
||||||
# Reset stop event
|
# Reset stop event
|
||||||
@@ -491,10 +491,7 @@ class PortAudioMicrophone(Microphone):
|
|||||||
self.record_start_event.clear() # Ensures the audio stream is not started again !
|
self.record_start_event.clear() # Ensures the audio stream is not started again !
|
||||||
self.record_stop_event.set()
|
self.record_stop_event.set()
|
||||||
|
|
||||||
while self.is_recording:
|
self.read_shared_array.reset()
|
||||||
time.sleep(0.01)
|
|
||||||
|
|
||||||
self._clear_queue(self.read_queue, join_queue=True)
|
|
||||||
self._clear_queue(self.write_queue, join_queue=True)
|
self._clear_queue(self.write_queue, join_queue=True)
|
||||||
|
|
||||||
if self.is_writing:
|
if self.is_writing:
|
||||||
|
|||||||
@@ -0,0 +1,105 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from multiprocessing import Lock, Value, shared_memory
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class SharedArray:
|
||||||
|
"""
|
||||||
|
A SharedArray is a numpy array shared between multiple processes in a shared_memory object.
|
||||||
|
- Data is written to the array using the `write` method, which appends data to the array.
|
||||||
|
- Data is read from the array (and eventually flushed) using the `read` method, which copies the _whole_ array.
|
||||||
|
|
||||||
|
SharedArray offers quasi-instantaneous array-wide read and flush capabilities in comparison to Queues, but has a limited size defined at initialization.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
_Main_process_
|
||||||
|
shared_array = SharedArray(shape=(10, 10), dtype=np.dtype("float32"))
|
||||||
|
local_array = shared_array.get_local_array()
|
||||||
|
shared_array.write(local_array, np.array([[1, 2, 3], [4, 5, 6]]))
|
||||||
|
|
||||||
|
_Child_process_
|
||||||
|
local_array = shared_array.get_local_array()
|
||||||
|
data = shared_array.read(local_array, flush=True)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, shape: tuple[int], dtype: np.dtype | str):
|
||||||
|
"""
|
||||||
|
Initialize a SharedArray.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shape: The shape of the shared array.
|
||||||
|
dtype: The dtype of the shared array.
|
||||||
|
"""
|
||||||
|
self.shape = shape
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
self.shared_memory = shared_memory.SharedMemory(
|
||||||
|
create=True, size=np.prod(shape) * np.dtype(dtype).itemsize
|
||||||
|
)
|
||||||
|
self.read_index = Value("i", 0)
|
||||||
|
self.lock = Lock()
|
||||||
|
|
||||||
|
def get_local_array(self) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Get a process-local instance of the shared array.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A process-local instance of the shared array.
|
||||||
|
"""
|
||||||
|
return np.ndarray(self.shape, dtype=np.dtype(self.dtype), buffer=self.shared_memory.buf)
|
||||||
|
|
||||||
|
def delete(self):
|
||||||
|
"""
|
||||||
|
Delete the shared array.
|
||||||
|
"""
|
||||||
|
self.shared_memory.close()
|
||||||
|
self.shared_memory.unlink()
|
||||||
|
|
||||||
|
def write(self, local_array: np.ndarray, data: np.ndarray):
|
||||||
|
"""
|
||||||
|
Write data to the shared array.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
local_array: The process-local instance of the shared array to write to.
|
||||||
|
data: The data to write to the shared array.
|
||||||
|
"""
|
||||||
|
with self.lock:
|
||||||
|
local_array[self.read_index.value : self.read_index.value + len(data)] = data
|
||||||
|
self.read_index.value += len(data)
|
||||||
|
|
||||||
|
def read(self, local_array: np.ndarray, flush: bool = True) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Read data from the shared array.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
local_array: The process-local instance of the shared array to read from.
|
||||||
|
flush: Whether to flush the shared array after reading.
|
||||||
|
"""
|
||||||
|
with self.lock:
|
||||||
|
data = np.copy(local_array[: self.read_index.value])
|
||||||
|
if flush:
|
||||||
|
self.read_index.value = 0
|
||||||
|
return data
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""
|
||||||
|
Reset the read index to 0.
|
||||||
|
"""
|
||||||
|
with self.lock:
|
||||||
|
self.read_index.value = 0
|
||||||
@@ -0,0 +1,508 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import multiprocessing
|
||||||
|
import time
|
||||||
|
from multiprocessing import Event, Process, Queue
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from lerobot.utils.shared_array import SharedArray
|
||||||
|
|
||||||
|
|
||||||
|
def writer_process(shared_array, data_queue, stop_event, barrier, process_id):
    """Writer process that repeatedly writes data to the shared array.

    Writes at most 10 blocks of 5 rows, each filled with `process_id * 100 + write_count`, and reports
    every successful write on `data_queue`. Stops early when `stop_event` is set or the array is full.
    """
    local_array = shared_array.get_local_array()

    # Wait for all processes to be ready
    barrier.wait()

    write_count = 0
    while not stop_event.is_set() and write_count < 10:
        # Generate unique data for this process and write iteration
        data = np.full((5, 2), process_id * 100 + write_count, dtype=np.float32)

        try:
            shared_array.write(local_array, data)
        except (IndexError, ValueError):
            # Array is full, stop writing.
            # An overflowing numpy slice assignment raises ValueError (not IndexError).
            break

        data_queue.put(f"writer_{process_id}_wrote_{write_count}")
        write_count += 1
        time.sleep(0.01)  # Small delay to allow race conditions
|
||||||
|
|
||||||
|
|
||||||
|
def reader_process(shared_array, data_queue, stop_event, barrier, process_id):
    """Reader process that repeatedly reads (and flushes) data from the shared array.

    Performs at most 5 flushing reads, 20 ms apart, and reports the number of rows of each read on
    `data_queue`. Stops early when `stop_event` is set.
    """
    view = shared_array.get_local_array()

    # Synchronize the start with the other processes
    barrier.wait()

    max_reads = 5
    for _ in range(max_reads):
        if stop_event.is_set():
            break

        # Give the writers some time to accumulate rows
        time.sleep(0.02)

        chunk = shared_array.read(view, flush=True)
        data_queue.put(f"reader_{process_id}_read_{len(chunk)}_items")
|
||||||
|
|
||||||
|
|
||||||
|
def stress_writer_process(shared_array, data_queue, stop_event, barrier, process_id):
    """High-frequency writer process for stress testing.

    Writes at most 50 single rows `[process_id, write_count]` back-to-back, then reports the number of
    completed writes on `data_queue`. Stops early when `stop_event` is set or a write is rejected.
    """
    local_array = shared_array.get_local_array()

    barrier.wait()

    write_count = 0
    while not stop_event.is_set() and write_count < 50:
        # Write single row at a time for more frequent operations
        data = np.array([[process_id, write_count]], dtype=np.float32)

        try:
            shared_array.write(local_array, data)
        except (IndexError, ValueError):
            # Write rejected (e.g. array full): numpy reports this as a ValueError, not an IndexError.
            break

        write_count += 1
        # No sleep - stress test

    data_queue.put(f"stress_writer_{process_id}_completed_{write_count}")
|
||||||
|
|
||||||
|
|
||||||
|
# Basic functionality tests
|
||||||
|
|
||||||
|
|
||||||
|
def test_shared_array_creation():
    """Test basic SharedArray creation and properties."""
    shape = (100, 4)
    dtype = np.float32

    shared_array = SharedArray(shape=shape, dtype=dtype)
    try:
        assert shared_array.shape == shape
        assert shared_array.dtype == dtype
        # A freshly created array holds no rows
        assert shared_array.read_index.value == 0
    finally:
        # Clean up the shared memory segment even if an assertion fails
        shared_array.delete()
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_array_access():
    """Test getting local array instances."""
    shape = (50, 2)
    shared_array = SharedArray(shape=shape, dtype=np.float32)
    try:
        local_array = shared_array.get_local_array()

        assert isinstance(local_array, np.ndarray)
        assert local_array.shape == shape
        assert local_array.dtype == np.float32

        # Test that we can get multiple local array instances
        local_array2 = shared_array.get_local_array()
        assert local_array2.shape == shape
    finally:
        # Clean up the shared memory segment even if an assertion fails
        shared_array.delete()
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_and_read_single_process():
    """Test basic write and read operations in single process."""
    shape = (20, 3)
    shared_array = SharedArray(shape=shape, dtype=np.float32)
    try:
        local_array = shared_array.get_local_array()

        # Write some data
        data1 = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
        shared_array.write(local_array, data1)
        assert shared_array.read_index.value == 2

        # Write more data: it must be appended after the first two rows
        data2 = np.array([[7, 8, 9]], dtype=np.float32)
        shared_array.write(local_array, data2)
        assert shared_array.read_index.value == 3

        expected = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)

        # Read all data without flushing: the index must stay untouched for the next read
        read_data = shared_array.read(local_array, flush=False)
        np.testing.assert_array_equal(read_data, expected)

        # Read with flush: same data, and the index is rewound
        read_data_flush = shared_array.read(local_array, flush=True)
        np.testing.assert_array_equal(read_data_flush, expected)
        assert shared_array.read_index.value == 0
    finally:
        # Clean up the shared memory segment even if an assertion fails
        shared_array.delete()
|
||||||
|
|
||||||
|
|
||||||
|
def test_array_overflow():
    """Test behavior when writing more data than array capacity."""
    shape = (5, 2)  # Small array
    shared_array = SharedArray(shape=shape, dtype=np.float32)
    try:
        local_array = shared_array.get_local_array()

        # Fill the array
        data = np.ones((5, 2), dtype=np.float32)
        shared_array.write(local_array, data)

        # Try to write more data - should raise ValueError
        extra_data = np.ones((2, 2), dtype=np.float32)
        # Only the overflowing write belongs inside the `raises` block
        with pytest.raises(ValueError):
            shared_array.write(local_array, extra_data)
    finally:
        # Clean up the shared memory segment even if an assertion fails
        shared_array.delete()
|
||||||
|
|
||||||
|
|
||||||
|
def test_reset_functionality():
    """Test the reset method."""
    shape = (10, 2)
    shared_array = SharedArray(shape=shape, dtype=np.float32)
    try:
        local_array = shared_array.get_local_array()

        # Write some data
        data = np.ones((3, 2), dtype=np.float32)
        shared_array.write(local_array, data)
        assert shared_array.read_index.value == 3

        # Reset
        shared_array.reset()
        assert shared_array.read_index.value == 0

        # Read should return empty array
        read_data = shared_array.read(local_array, flush=False)
        assert len(read_data) == 0
    finally:
        # Clean up the shared memory segment even if an assertion fails
        shared_array.delete()
|
||||||
|
|
||||||
|
|
||||||
|
# Multi-process tests
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_writer_single_reader():
    """Test basic writer-reader scenario with one process each."""
    shape = (100, 2)
    shared_array = SharedArray(shape=shape, dtype=np.float32)
    processes = []
    try:
        data_queue = Queue()
        stop_event = Event()
        barrier = multiprocessing.Barrier(2)  # Writer + reader

        writer = Process(target=writer_process, args=(shared_array, data_queue, stop_event, barrier, 1))
        reader = Process(target=reader_process, args=(shared_array, data_queue, stop_event, barrier, 1))
        processes = [writer, reader]

        writer.start()
        reader.start()

        # Let them run for a bit
        time.sleep(0.5)
        stop_event.set()

        # Wait for completion
        writer.join(timeout=2.0)
        reader.join(timeout=2.0)

        # Verify both processes completed
        assert not writer.is_alive()
        assert not reader.is_alive()

        # Check that we got messages from both processes
        messages = []
        while not data_queue.empty():
            messages.append(data_queue.get())

        writer_messages = [msg for msg in messages if msg.startswith("writer_")]
        reader_messages = [msg for msg in messages if msg.startswith("reader_")]

        assert len(writer_messages) > 0
        assert len(reader_messages) > 0
    finally:
        # Never leave child processes or the shared memory segment behind on failure
        for process in processes:
            if process.is_alive():
                process.terminate()
        shared_array.delete()
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_writers_single_reader():
    """Test multiple writers with single reader - check for race conditions."""
    shape = (200, 2)
    shared_array = SharedArray(shape=shape, dtype=np.float32)
    processes = []
    try:
        data_queue = Queue()
        stop_event = Event()
        num_writers = 3
        barrier = multiprocessing.Barrier(num_writers + 1)  # Writers + reader

        # Start multiple writer processes
        for i in range(num_writers):
            writer = Process(
                target=writer_process, args=(shared_array, data_queue, stop_event, barrier, i + 1)
            )
            processes.append(writer)
            writer.start()

        # Start reader process
        reader = Process(target=reader_process, args=(shared_array, data_queue, stop_event, barrier, 1))
        processes.append(reader)
        reader.start()

        # Let them run
        time.sleep(1.0)
        stop_event.set()

        # Wait for all processes
        for process in processes:
            process.join(timeout=3.0)
            assert not process.is_alive()

        # Verify we got messages from all processes
        messages = []
        while not data_queue.empty():
            messages.append(data_queue.get())

        writer_messages = [msg for msg in messages if msg.startswith("writer_")]
        reader_messages = [msg for msg in messages if msg.startswith("reader_")]

        # Should have messages from all writers
        assert len(writer_messages) >= num_writers
        assert len(reader_messages) > 0
    finally:
        # Never leave child processes or the shared memory segment behind on failure
        for process in processes:
            if process.is_alive():
                process.terminate()
        shared_array.delete()
|
||||||
|
|
||||||
|
|
||||||
|
def test_data_integrity_with_concurrent_access():
    """Test that data integrity is maintained under concurrent access using standard reader/writer processes."""
    shape = (500, 2)  # Use standard 2-column format
    shared_array = SharedArray(shape=shape, dtype=np.float32)
    processes = []
    try:
        data_queue = Queue()
        stop_event = Event()
        barrier = multiprocessing.Barrier(3)  # 2 writers + 1 reader

        # Two writer processes and one reader process
        writer1 = Process(target=writer_process, args=(shared_array, data_queue, stop_event, barrier, 1))
        writer2 = Process(target=writer_process, args=(shared_array, data_queue, stop_event, barrier, 2))
        reader = Process(target=reader_process, args=(shared_array, data_queue, stop_event, barrier, 1))
        processes = [writer1, writer2, reader]

        writer1.start()
        writer2.start()
        reader.start()

        # Let them run for integrity test duration
        time.sleep(1.0)
        stop_event.set()

        # Wait for completion
        writer1.join(timeout=3.0)
        writer2.join(timeout=3.0)
        reader.join(timeout=3.0)

        # Verify all processes completed successfully
        assert not writer1.is_alive()
        assert not writer2.is_alive()
        assert not reader.is_alive()

        # Verify data integrity by checking messages
        messages = []
        while not data_queue.empty():
            messages.append(data_queue.get())

        writer1_messages = [msg for msg in messages if "writer_1_wrote" in msg]
        writer2_messages = [msg for msg in messages if "writer_2_wrote" in msg]
        reader_messages = [msg for msg in messages if "reader_1_read" in msg]

        # Verify both writers wrote data
        assert len(writer1_messages) > 0
        assert len(writer2_messages) > 0
        # Verify reader read data
        assert len(reader_messages) > 0

        # Verify the shared array is in a consistent state
        local_array = shared_array.get_local_array()
        final_data = shared_array.read(local_array, flush=False)

        # final_data may legitimately be empty if the reader flushed everything,
        # but it must never exceed the array capacity
        assert len(final_data) <= shape[0]

        # If there's data, verify it contains the expected writer signatures
        if len(final_data) > 0:
            # Data should contain values like 100, 101, 102... (writer 1) or 200, 201, 202... (writer 2)
            unique_values = np.unique(final_data.flatten())
            writer1_values = unique_values[(unique_values >= 100) & (unique_values < 200)]
            writer2_values = unique_values[(unique_values >= 200) & (unique_values < 300)]

            # Should have data from at least one writer
            assert len(writer1_values) > 0 or len(writer2_values) > 0
    finally:
        # Never leave child processes or the shared memory segment behind on failure
        for process in processes:
            if process.is_alive():
                process.terminate()
        shared_array.delete()
|
||||||
|
|
||||||
|
|
||||||
|
def test_stress_test_high_frequency_operations():
    """Stress test with high frequency write operations from several concurrent writers."""
    shape = (1000, 2)
    shared_array = SharedArray(shape=shape, dtype=np.float32)
    processes = []
    try:
        data_queue = Queue()
        stop_event = Event()
        num_writers = 4
        barrier = multiprocessing.Barrier(num_writers)

        # Start multiple high-frequency writers
        for i in range(num_writers):
            writer = Process(
                target=stress_writer_process, args=(shared_array, data_queue, stop_event, barrier, i + 1)
            )
            processes.append(writer)
            writer.start()

        # Let them run for stress test duration
        time.sleep(0.5)
        stop_event.set()

        # Wait for completion
        for process in processes:
            process.join(timeout=3.0)
            assert not process.is_alive()

        # Verify all writers completed successfully
        messages = []
        while not data_queue.empty():
            messages.append(data_queue.get())

        completed_messages = [msg for msg in messages if "completed" in msg]
        assert len(completed_messages) == num_writers

        # Verify the shared array is in a consistent state
        local_array = shared_array.get_local_array()
        final_data = shared_array.read(local_array, flush=False)

        # Should have some data written
        assert len(final_data) > 0
        # Should not exceed array capacity
        assert len(final_data) <= shape[0]
    finally:
        # Never leave child processes or the shared memory segment behind on failure
        for process in processes:
            if process.is_alive():
                process.terminate()
        shared_array.delete()
|
||||||
|
|
||||||
|
|
||||||
|
def test_concurrent_readers():
    """Test multiple concurrent readers with writers to ensure process safety."""
    shape = (200, 2)
    shared_array = SharedArray(shape=shape, dtype=np.float32)
    processes = []
    try:
        data_queue = Queue()
        stop_event = Event()
        num_readers = 3
        num_writers = 2
        barrier = multiprocessing.Barrier(num_readers + num_writers)

        # Start multiple writer processes to generate data
        for i in range(num_writers):
            writer = Process(
                target=writer_process, args=(shared_array, data_queue, stop_event, barrier, i + 1)
            )
            processes.append(writer)
            writer.start()

        # Start multiple reader processes
        for i in range(num_readers):
            reader = Process(
                target=reader_process, args=(shared_array, data_queue, stop_event, barrier, i + 1)
            )
            processes.append(reader)
            reader.start()

        # Let them run to test concurrent access
        time.sleep(1.0)
        stop_event.set()

        # Wait for all processes to complete
        for process in processes:
            process.join(timeout=3.0)
            assert not process.is_alive()

        # Verify all readers and writers completed
        messages = []
        while not data_queue.empty():
            messages.append(data_queue.get())

        reader_messages = [msg for msg in messages if msg.startswith("reader_")]
        writer_messages = [msg for msg in messages if msg.startswith("writer_")]

        # Should have messages from all readers and writers
        assert len(reader_messages) >= num_readers
        assert len(writer_messages) >= num_writers

        # Verify every reader reported at least once.
        # Messages look like "reader_1_read_5_items": the reader ID is the second "_"-separated field.
        reader_ids = {msg.split("_")[1] for msg in reader_messages}
        assert len(reader_ids) == num_readers  # All readers should have participated
    finally:
        # Never leave child processes or the shared memory segment behind on failure
        for process in processes:
            if process.is_alive():
                process.terminate()
        shared_array.delete()
|
||||||
|
|
||||||
|
|
||||||
|
def test_edge_case_empty_reads():
    """Test reading from empty array and after flushes."""
    shape = (10, 2)
    shared_array = SharedArray(shape=shape, dtype=np.float32)
    try:
        local_array = shared_array.get_local_array()

        # Read from empty array
        empty_data = shared_array.read(local_array, flush=False)
        assert len(empty_data) == 0

        # Write some data
        data = np.ones((3, 2), dtype=np.float32)
        shared_array.write(local_array, data)

        # Read with flush
        read_data = shared_array.read(local_array, flush=True)
        assert len(read_data) == 3

        # Read again after flush - should be empty
        empty_again = shared_array.read(local_array, flush=False)
        assert len(empty_again) == 0
    finally:
        # Clean up the shared memory segment even if an assertion fails
        shared_array.delete()
|
||||||
|
|
||||||
|
|
||||||
|
def test_different_dtypes():
    """Test SharedArray with different numpy dtypes."""
    dtypes_to_test = [np.float32, np.float64, np.int32, np.int16]

    for dtype in dtypes_to_test:
        shape = (20, 2)
        shared_array = SharedArray(shape=shape, dtype=dtype)
        try:
            local_array = shared_array.get_local_array()
            assert local_array.dtype == dtype

            # Write and read data of this dtype
            data = np.ones((5, 2), dtype=dtype)
            shared_array.write(local_array, data)

            read_data = shared_array.read(local_array, flush=True)
            assert read_data.dtype == dtype
            assert len(read_data) == 5
        finally:
            # Release this iteration's segment even if an assertion fails for this dtype
            shared_array.delete()
|
||||||
Reference in New Issue
Block a user