mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 06:29:47 +00:00
revert initialization to empty
This commit is contained in:
@@ -150,17 +150,18 @@ class ReplayBuffer:
|
|||||||
state_shapes = {key: val.squeeze(0).shape for key, val in state.items()}
|
state_shapes = {key: val.squeeze(0).shape for key, val in state.items()}
|
||||||
action_shape = action.squeeze(0).shape
|
action_shape = action.squeeze(0).shape
|
||||||
|
|
||||||
|
# Pre-allocate tensors for storage
|
||||||
self.states = {
|
self.states = {
|
||||||
key: torch.zeros((self.capacity, *shape), device=self.storage_device)
|
key: torch.empty((self.capacity, *shape), device=self.storage_device)
|
||||||
for key, shape in state_shapes.items()
|
for key, shape in state_shapes.items()
|
||||||
}
|
}
|
||||||
self.actions = torch.zeros((self.capacity, *action_shape), device=self.storage_device)
|
self.actions = torch.empty((self.capacity, *action_shape), device=self.storage_device)
|
||||||
self.rewards = torch.zeros((self.capacity,), device=self.storage_device)
|
self.rewards = torch.empty((self.capacity,), device=self.storage_device)
|
||||||
|
|
||||||
if not self.optimize_memory:
|
if not self.optimize_memory:
|
||||||
# Standard approach: store states and next_states separately
|
# Standard approach: store states and next_states separately
|
||||||
self.next_states = {
|
self.next_states = {
|
||||||
key: torch.zeros((self.capacity, *shape), device=self.storage_device)
|
key: torch.empty((self.capacity, *shape), device=self.storage_device)
|
||||||
for key, shape in state_shapes.items()
|
for key, shape in state_shapes.items()
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
@@ -168,8 +169,8 @@ class ReplayBuffer:
|
|||||||
# Just create a reference to states for consistent API
|
# Just create a reference to states for consistent API
|
||||||
self.next_states = self.states # Just a reference for API consistency
|
self.next_states = self.states # Just a reference for API consistency
|
||||||
|
|
||||||
self.terminated = torch.zeros((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
self.terminated = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
||||||
self.truncated = torch.zeros((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
self.truncated = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
||||||
|
|
||||||
# Initialize storage for complementary_info
|
# Initialize storage for complementary_info
|
||||||
self.has_complementary_info = complementary_info is not None
|
self.has_complementary_info = complementary_info is not None
|
||||||
@@ -182,12 +183,12 @@ class ReplayBuffer:
|
|||||||
for key, value in complementary_info.items():
|
for key, value in complementary_info.items():
|
||||||
if isinstance(value, torch.Tensor):
|
if isinstance(value, torch.Tensor):
|
||||||
value_shape = value.squeeze(0).shape
|
value_shape = value.squeeze(0).shape
|
||||||
self.complementary_info[key] = torch.zeros(
|
self.complementary_info[key] = torch.empty(
|
||||||
(self.capacity, *value_shape), device=self.storage_device
|
(self.capacity, *value_shape), device=self.storage_device
|
||||||
)
|
)
|
||||||
elif isinstance(value, (int | float)):
|
elif isinstance(value, (int | float)):
|
||||||
# Handle scalar values similar to reward
|
# Handle scalar values similar to reward
|
||||||
self.complementary_info[key] = torch.zeros((self.capacity,), device=self.storage_device)
|
self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported type {type(value)} for complementary_info[{key}]")
|
raise ValueError(f"Unsupported type {type(value)} for complementary_info[{key}]")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user