mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
fix: add thread synchronization to ReplayBuffer to prevent race condition between add() and sample() (#3372)
This commit is contained in:
+17
-21
@@ -15,6 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
|
import threading
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from typing import TypedDict
|
from typing import TypedDict
|
||||||
@@ -115,6 +116,7 @@ class ReplayBuffer:
|
|||||||
self.size = 0
|
self.size = 0
|
||||||
self.initialized = False
|
self.initialized = False
|
||||||
self.optimize_memory = optimize_memory
|
self.optimize_memory = optimize_memory
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
# Track episode boundaries for memory optimization
|
# Track episode boundaries for memory optimization
|
||||||
self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device)
|
self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device)
|
||||||
@@ -198,6 +200,7 @@ class ReplayBuffer:
|
|||||||
complementary_info: dict[str, torch.Tensor] | None = None,
|
complementary_info: dict[str, torch.Tensor] | None = None,
|
||||||
):
|
):
|
||||||
"""Saves a transition, ensuring tensors are stored on the designated storage device."""
|
"""Saves a transition, ensuring tensors are stored on the designated storage device."""
|
||||||
|
with self._lock:
|
||||||
# Initialize storage if this is the first transition
|
# Initialize storage if this is the first transition
|
||||||
if not self.initialized:
|
if not self.initialized:
|
||||||
self._initialize_storage(state=state, action=action, complementary_info=complementary_info)
|
self._initialize_storage(state=state, action=action, complementary_info=complementary_info)
|
||||||
@@ -217,7 +220,6 @@ class ReplayBuffer:
|
|||||||
|
|
||||||
# Handle complementary_info if provided and storage is initialized
|
# Handle complementary_info if provided and storage is initialized
|
||||||
if complementary_info is not None and self.has_complementary_info:
|
if complementary_info is not None and self.has_complementary_info:
|
||||||
# Store the complementary_info
|
|
||||||
for key in self.complementary_info_keys:
|
for key in self.complementary_info_keys:
|
||||||
if key in complementary_info:
|
if key in complementary_info:
|
||||||
value = complementary_info[key]
|
value = complementary_info[key]
|
||||||
@@ -234,32 +236,39 @@ class ReplayBuffer:
|
|||||||
if not self.initialized:
|
if not self.initialized:
|
||||||
raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.")
|
raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.")
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
batch_size = min(batch_size, self.size)
|
batch_size = min(batch_size, self.size)
|
||||||
high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size
|
high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size
|
||||||
|
|
||||||
# Random indices for sampling - create on the same device as storage
|
|
||||||
idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
|
idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
|
||||||
|
|
||||||
# Identify image keys that need augmentation
|
|
||||||
image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else []
|
image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else []
|
||||||
|
|
||||||
# Create batched state and next_state
|
|
||||||
batch_state = {}
|
batch_state = {}
|
||||||
batch_next_state = {}
|
batch_next_state = {}
|
||||||
|
|
||||||
# First pass: load all state tensors to target device
|
|
||||||
for key in self.states:
|
for key in self.states:
|
||||||
batch_state[key] = self.states[key][idx].to(self.device)
|
batch_state[key] = self.states[key][idx].to(self.device)
|
||||||
|
|
||||||
if not self.optimize_memory:
|
if not self.optimize_memory:
|
||||||
# Standard approach - load next_states directly
|
|
||||||
batch_next_state[key] = self.next_states[key][idx].to(self.device)
|
batch_next_state[key] = self.next_states[key][idx].to(self.device)
|
||||||
else:
|
else:
|
||||||
# Memory-optimized approach - get next_state from the next index
|
|
||||||
next_idx = (idx + 1) % self.capacity
|
next_idx = (idx + 1) % self.capacity
|
||||||
batch_next_state[key] = self.states[key][next_idx].to(self.device)
|
batch_next_state[key] = self.states[key][next_idx].to(self.device)
|
||||||
|
|
||||||
# Apply image augmentation in a batched way if needed
|
# Sample other tensors
|
||||||
|
batch_actions = self.actions[idx].to(self.device)
|
||||||
|
batch_rewards = self.rewards[idx].to(self.device)
|
||||||
|
batch_dones = self.dones[idx].to(self.device).float()
|
||||||
|
batch_truncateds = self.truncateds[idx].to(self.device).float()
|
||||||
|
|
||||||
|
# Sample complementary_info if available
|
||||||
|
batch_complementary_info = None
|
||||||
|
if self.has_complementary_info:
|
||||||
|
batch_complementary_info = {}
|
||||||
|
for key in self.complementary_info_keys:
|
||||||
|
batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
|
||||||
|
|
||||||
if self.use_drq and image_keys:
|
if self.use_drq and image_keys:
|
||||||
# Concatenate all images from state and next_state
|
# Concatenate all images from state and next_state
|
||||||
all_images = []
|
all_images = []
|
||||||
@@ -280,19 +289,6 @@ class ReplayBuffer:
|
|||||||
# Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots
|
# Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots
|
||||||
batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
|
batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
|
||||||
|
|
||||||
# Sample other tensors
|
|
||||||
batch_actions = self.actions[idx].to(self.device)
|
|
||||||
batch_rewards = self.rewards[idx].to(self.device)
|
|
||||||
batch_dones = self.dones[idx].to(self.device).float()
|
|
||||||
batch_truncateds = self.truncateds[idx].to(self.device).float()
|
|
||||||
|
|
||||||
# Sample complementary_info if available
|
|
||||||
batch_complementary_info = None
|
|
||||||
if self.has_complementary_info:
|
|
||||||
batch_complementary_info = {}
|
|
||||||
for key in self.complementary_info_keys:
|
|
||||||
batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
|
|
||||||
|
|
||||||
return BatchTransition(
|
return BatchTransition(
|
||||||
state=batch_state,
|
state=batch_state,
|
||||||
action=batch_actions,
|
action=batch_actions,
|
||||||
|
|||||||
Reference in New Issue
Block a user