fix: add thread synchronization to ReplayBuffer to prevent race condition between add() and sample() (#3372)

This commit is contained in:
Khalil Meftah
2026-04-14 13:16:45 +02:00
committed by GitHub
parent b3e76a92f2
commit d57c58a532
+54 -58
View File
@@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
import functools import functools
import threading
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from contextlib import suppress from contextlib import suppress
from typing import TypedDict from typing import TypedDict
@@ -115,6 +116,7 @@ class ReplayBuffer:
self.size = 0 self.size = 0
self.initialized = False self.initialized = False
self.optimize_memory = optimize_memory self.optimize_memory = optimize_memory
self._lock = threading.Lock()
# Track episode boundaries for memory optimization # Track episode boundaries for memory optimization
self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device) self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device)
@@ -198,68 +200,75 @@ class ReplayBuffer:
complementary_info: dict[str, torch.Tensor] | None = None, complementary_info: dict[str, torch.Tensor] | None = None,
): ):
"""Saves a transition, ensuring tensors are stored on the designated storage device.""" """Saves a transition, ensuring tensors are stored on the designated storage device."""
# Initialize storage if this is the first transition with self._lock:
if not self.initialized: # Initialize storage if this is the first transition
self._initialize_storage(state=state, action=action, complementary_info=complementary_info) if not self.initialized:
self._initialize_storage(state=state, action=action, complementary_info=complementary_info)
# Store the transition in pre-allocated tensors # Store the transition in pre-allocated tensors
for key in self.states: for key in self.states:
self.states[key][self.position].copy_(state[key].squeeze(dim=0)) self.states[key][self.position].copy_(state[key].squeeze(dim=0))
if not self.optimize_memory: if not self.optimize_memory:
# Only store next_states if not optimizing memory # Only store next_states if not optimizing memory
self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0)) self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))
self.actions[self.position].copy_(action.squeeze(dim=0)) self.actions[self.position].copy_(action.squeeze(dim=0))
self.rewards[self.position] = reward self.rewards[self.position] = reward
self.dones[self.position] = done self.dones[self.position] = done
self.truncateds[self.position] = truncated self.truncateds[self.position] = truncated
# Handle complementary_info if provided and storage is initialized # Handle complementary_info if provided and storage is initialized
if complementary_info is not None and self.has_complementary_info: if complementary_info is not None and self.has_complementary_info:
# Store the complementary_info for key in self.complementary_info_keys:
for key in self.complementary_info_keys: if key in complementary_info:
if key in complementary_info: value = complementary_info[key]
value = complementary_info[key] if isinstance(value, torch.Tensor):
if isinstance(value, torch.Tensor): self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
self.complementary_info[key][self.position].copy_(value.squeeze(dim=0)) elif isinstance(value, (int | float)):
elif isinstance(value, (int | float)): self.complementary_info[key][self.position] = value
self.complementary_info[key][self.position] = value
self.position = (self.position + 1) % self.capacity self.position = (self.position + 1) % self.capacity
self.size = min(self.size + 1, self.capacity) self.size = min(self.size + 1, self.capacity)
def sample(self, batch_size: int) -> BatchTransition: def sample(self, batch_size: int) -> BatchTransition:
"""Sample a random batch of transitions and collate them into batched tensors.""" """Sample a random batch of transitions and collate them into batched tensors."""
if not self.initialized: if not self.initialized:
raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.") raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.")
batch_size = min(batch_size, self.size) with self._lock:
high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size batch_size = min(batch_size, self.size)
high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size
# Random indices for sampling - create on the same device as storage idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
# Identify image keys that need augmentation image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else []
image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else []
# Create batched state and next_state batch_state = {}
batch_state = {} batch_next_state = {}
batch_next_state = {}
# First pass: load all state tensors to target device for key in self.states:
for key in self.states: batch_state[key] = self.states[key][idx].to(self.device)
batch_state[key] = self.states[key][idx].to(self.device)
if not self.optimize_memory: if not self.optimize_memory:
# Standard approach - load next_states directly batch_next_state[key] = self.next_states[key][idx].to(self.device)
batch_next_state[key] = self.next_states[key][idx].to(self.device) else:
else: next_idx = (idx + 1) % self.capacity
# Memory-optimized approach - get next_state from the next index batch_next_state[key] = self.states[key][next_idx].to(self.device)
next_idx = (idx + 1) % self.capacity
batch_next_state[key] = self.states[key][next_idx].to(self.device) # Sample other tensors
batch_actions = self.actions[idx].to(self.device)
batch_rewards = self.rewards[idx].to(self.device)
batch_dones = self.dones[idx].to(self.device).float()
batch_truncateds = self.truncateds[idx].to(self.device).float()
# Sample complementary_info if available
batch_complementary_info = None
if self.has_complementary_info:
batch_complementary_info = {}
for key in self.complementary_info_keys:
batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
# Apply image augmentation in a batched way if needed
if self.use_drq and image_keys: if self.use_drq and image_keys:
# Concatenate all images from state and next_state # Concatenate all images from state and next_state
all_images = [] all_images = []
@@ -280,19 +289,6 @@ class ReplayBuffer:
# Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots # Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots
batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size] batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
# Sample other tensors
batch_actions = self.actions[idx].to(self.device)
batch_rewards = self.rewards[idx].to(self.device)
batch_dones = self.dones[idx].to(self.device).float()
batch_truncateds = self.truncateds[idx].to(self.device).float()
# Sample complementary_info if available
batch_complementary_info = None
if self.has_complementary_info:
batch_complementary_info = {}
for key in self.complementary_info_keys:
batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
return BatchTransition( return BatchTransition(
state=batch_state, state=batch_state,
action=batch_actions, action=batch_actions,