Revert ReplayBuffer storage pre-allocation from torch.zeros back to torch.empty

This commit is contained in:
Michel Aractingi
2025-12-17 16:36:59 +01:00
parent 9014f9a7c5
commit e3539cb78e
+9 -8
View File
@@ -150,17 +150,18 @@ class ReplayBuffer:
state_shapes = {key: val.squeeze(0).shape for key, val in state.items()}
action_shape = action.squeeze(0).shape
# Pre-allocate tensors for storage
self.states = {
key: torch.zeros((self.capacity, *shape), device=self.storage_device)
key: torch.empty((self.capacity, *shape), device=self.storage_device)
for key, shape in state_shapes.items()
}
self.actions = torch.zeros((self.capacity, *action_shape), device=self.storage_device)
self.rewards = torch.zeros((self.capacity,), device=self.storage_device)
self.actions = torch.empty((self.capacity, *action_shape), device=self.storage_device)
self.rewards = torch.empty((self.capacity,), device=self.storage_device)
if not self.optimize_memory:
# Standard approach: store states and next_states separately
self.next_states = {
key: torch.zeros((self.capacity, *shape), device=self.storage_device)
key: torch.empty((self.capacity, *shape), device=self.storage_device)
for key, shape in state_shapes.items()
}
else:
@@ -168,8 +169,8 @@ class ReplayBuffer:
# Just create a reference to states for consistent API
self.next_states = self.states # Just a reference for API consistency
self.terminated = torch.zeros((self.capacity,), dtype=torch.bool, device=self.storage_device)
self.truncated = torch.zeros((self.capacity,), dtype=torch.bool, device=self.storage_device)
self.terminated = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
self.truncated = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
# Initialize storage for complementary_info
self.has_complementary_info = complementary_info is not None
@@ -182,12 +183,12 @@ class ReplayBuffer:
for key, value in complementary_info.items():
if isinstance(value, torch.Tensor):
value_shape = value.squeeze(0).shape
self.complementary_info[key] = torch.zeros(
self.complementary_info[key] = torch.empty(
(self.capacity, *value_shape), device=self.storage_device
)
elif isinstance(value, (int | float)):
# Handle scalar values similar to reward
self.complementary_info[key] = torch.zeros((self.capacity,), device=self.storage_device)
self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
else:
raise ValueError(f"Unsupported type {type(value)} for complementary_info[{key}]")