diff --git a/src/lerobot/rl/buffer.py b/src/lerobot/rl/buffer.py index 625f4bf6a..cf4a1dae2 100644 --- a/src/lerobot/rl/buffer.py +++ b/src/lerobot/rl/buffer.py @@ -150,17 +150,18 @@ class ReplayBuffer: state_shapes = {key: val.squeeze(0).shape for key, val in state.items()} action_shape = action.squeeze(0).shape + # Pre-allocate tensors for storage self.states = { - key: torch.zeros((self.capacity, *shape), device=self.storage_device) + key: torch.empty((self.capacity, *shape), device=self.storage_device) for key, shape in state_shapes.items() } - self.actions = torch.zeros((self.capacity, *action_shape), device=self.storage_device) - self.rewards = torch.zeros((self.capacity,), device=self.storage_device) + self.actions = torch.empty((self.capacity, *action_shape), device=self.storage_device) + self.rewards = torch.empty((self.capacity,), device=self.storage_device) if not self.optimize_memory: # Standard approach: store states and next_states separately self.next_states = { - key: torch.zeros((self.capacity, *shape), device=self.storage_device) + key: torch.empty((self.capacity, *shape), device=self.storage_device) for key, shape in state_shapes.items() } else: @@ -168,8 +169,8 @@ class ReplayBuffer: # Just create a reference to states for consistent API self.next_states = self.states # Just a reference for API consistency - self.terminated = torch.zeros((self.capacity,), dtype=torch.bool, device=self.storage_device) - self.truncated = torch.zeros((self.capacity,), dtype=torch.bool, device=self.storage_device) + self.terminated = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) + self.truncated = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) # Initialize storage for complementary_info self.has_complementary_info = complementary_info is not None @@ -182,12 +183,12 @@ class ReplayBuffer: for key, value in complementary_info.items(): if isinstance(value, torch.Tensor): value_shape = value.squeeze(0).shape - self.complementary_info[key] = torch.zeros( + self.complementary_info[key] = torch.empty( (self.capacity, *value_shape), device=self.storage_device ) elif isinstance(value, (int | float)): # Handle scalar values similar to reward - self.complementary_info[key] = torch.zeros((self.capacity,), device=self.storage_device) + self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device) else: raise ValueError(f"Unsupported type {type(value)} for complementary_info[{key}]")