fix: add thread synchronization to ReplayBuffer to prevent race condition between add() and sample() (#3372)

This commit is contained in:
Khalil Meftah
2026-04-14 13:16:45 +02:00
committed by GitHub
parent b3e76a92f2
commit d57c58a532
+17 -21
View File
@@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
import functools import functools
import threading
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from contextlib import suppress from contextlib import suppress
from typing import TypedDict from typing import TypedDict
@@ -115,6 +116,7 @@ class ReplayBuffer:
self.size = 0 self.size = 0
self.initialized = False self.initialized = False
self.optimize_memory = optimize_memory self.optimize_memory = optimize_memory
self._lock = threading.Lock()
# Track episode boundaries for memory optimization # Track episode boundaries for memory optimization
self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device) self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device)
@@ -198,6 +200,7 @@ class ReplayBuffer:
complementary_info: dict[str, torch.Tensor] | None = None, complementary_info: dict[str, torch.Tensor] | None = None,
): ):
"""Saves a transition, ensuring tensors are stored on the designated storage device.""" """Saves a transition, ensuring tensors are stored on the designated storage device."""
with self._lock:
# Initialize storage if this is the first transition # Initialize storage if this is the first transition
if not self.initialized: if not self.initialized:
self._initialize_storage(state=state, action=action, complementary_info=complementary_info) self._initialize_storage(state=state, action=action, complementary_info=complementary_info)
@@ -217,7 +220,6 @@ class ReplayBuffer:
# Handle complementary_info if provided and storage is initialized # Handle complementary_info if provided and storage is initialized
if complementary_info is not None and self.has_complementary_info: if complementary_info is not None and self.has_complementary_info:
# Store the complementary_info
for key in self.complementary_info_keys: for key in self.complementary_info_keys:
if key in complementary_info: if key in complementary_info:
value = complementary_info[key] value = complementary_info[key]
@@ -234,32 +236,39 @@ class ReplayBuffer:
if not self.initialized: if not self.initialized:
raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.") raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.")
with self._lock:
batch_size = min(batch_size, self.size) batch_size = min(batch_size, self.size)
high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size
# Random indices for sampling - create on the same device as storage
idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device) idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
# Identify image keys that need augmentation
image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else [] image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else []
# Create batched state and next_state
batch_state = {} batch_state = {}
batch_next_state = {} batch_next_state = {}
# First pass: load all state tensors to target device
for key in self.states: for key in self.states:
batch_state[key] = self.states[key][idx].to(self.device) batch_state[key] = self.states[key][idx].to(self.device)
if not self.optimize_memory: if not self.optimize_memory:
# Standard approach - load next_states directly
batch_next_state[key] = self.next_states[key][idx].to(self.device) batch_next_state[key] = self.next_states[key][idx].to(self.device)
else: else:
# Memory-optimized approach - get next_state from the next index
next_idx = (idx + 1) % self.capacity next_idx = (idx + 1) % self.capacity
batch_next_state[key] = self.states[key][next_idx].to(self.device) batch_next_state[key] = self.states[key][next_idx].to(self.device)
# Apply image augmentation in a batched way if needed # Sample other tensors
batch_actions = self.actions[idx].to(self.device)
batch_rewards = self.rewards[idx].to(self.device)
batch_dones = self.dones[idx].to(self.device).float()
batch_truncateds = self.truncateds[idx].to(self.device).float()
# Sample complementary_info if available
batch_complementary_info = None
if self.has_complementary_info:
batch_complementary_info = {}
for key in self.complementary_info_keys:
batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
if self.use_drq and image_keys: if self.use_drq and image_keys:
# Concatenate all images from state and next_state # Concatenate all images from state and next_state
all_images = [] all_images = []
@@ -280,19 +289,6 @@ class ReplayBuffer:
# Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots # Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots
batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size] batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
# Sample other tensors
batch_actions = self.actions[idx].to(self.device)
batch_rewards = self.rewards[idx].to(self.device)
batch_dones = self.dones[idx].to(self.device).float()
batch_truncateds = self.truncateds[idx].to(self.device).float()
# Sample complementary_info if available
batch_complementary_info = None
if self.has_complementary_info:
batch_complementary_info = {}
for key in self.complementary_info_keys:
batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
return BatchTransition( return BatchTransition(
state=batch_state, state=batch_state,
action=batch_actions, action=batch_actions,