mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
Revert ReplayBuffer storage initialization from torch.zeros back to torch.empty
This commit is contained in:
@@ -150,17 +150,18 @@ class ReplayBuffer:
|
||||
state_shapes = {key: val.squeeze(0).shape for key, val in state.items()}
|
||||
action_shape = action.squeeze(0).shape
|
||||
|
||||
# Pre-allocate tensors for storage
|
||||
self.states = {
|
||||
key: torch.zeros((self.capacity, *shape), device=self.storage_device)
|
||||
key: torch.empty((self.capacity, *shape), device=self.storage_device)
|
||||
for key, shape in state_shapes.items()
|
||||
}
|
||||
self.actions = torch.zeros((self.capacity, *action_shape), device=self.storage_device)
|
||||
self.rewards = torch.zeros((self.capacity,), device=self.storage_device)
|
||||
self.actions = torch.empty((self.capacity, *action_shape), device=self.storage_device)
|
||||
self.rewards = torch.empty((self.capacity,), device=self.storage_device)
|
||||
|
||||
if not self.optimize_memory:
|
||||
# Standard approach: store states and next_states separately
|
||||
self.next_states = {
|
||||
key: torch.zeros((self.capacity, *shape), device=self.storage_device)
|
||||
key: torch.empty((self.capacity, *shape), device=self.storage_device)
|
||||
for key, shape in state_shapes.items()
|
||||
}
|
||||
else:
|
||||
@@ -168,8 +169,8 @@ class ReplayBuffer:
|
||||
# Just create a reference to states for consistent API
|
||||
self.next_states = self.states # Just a reference for API consistency
|
||||
|
||||
self.terminated = torch.zeros((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
||||
self.truncated = torch.zeros((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
||||
self.terminated = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
||||
self.truncated = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
||||
|
||||
# Initialize storage for complementary_info
|
||||
self.has_complementary_info = complementary_info is not None
|
||||
@@ -182,12 +183,12 @@ class ReplayBuffer:
|
||||
for key, value in complementary_info.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
value_shape = value.squeeze(0).shape
|
||||
self.complementary_info[key] = torch.zeros(
|
||||
self.complementary_info[key] = torch.empty(
|
||||
(self.capacity, *value_shape), device=self.storage_device
|
||||
)
|
||||
elif isinstance(value, (int | float)):
|
||||
# Handle scalar values similar to reward
|
||||
self.complementary_info[key] = torch.zeros((self.capacity,), device=self.storage_device)
|
||||
self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
|
||||
else:
|
||||
raise ValueError(f"Unsupported type {type(value)} for complementary_info[{key}]")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user