mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 10:10:08 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Michel Aractingi
parent
3a2308d86f
commit
88d26ae976
@@ -248,7 +248,7 @@ class ReplayBuffer:
|
||||
# Initialize complementary_info storage
|
||||
self.complementary_info_keys = []
|
||||
self.complementary_info_storage = {}
|
||||
|
||||
|
||||
self.initialized = True
|
||||
|
||||
def __len__(self):
|
||||
@@ -291,11 +291,9 @@ class ReplayBuffer:
|
||||
if isinstance(value, torch.Tensor):
|
||||
shape = value.shape if value.ndim > 0 else (1,)
|
||||
self.complementary_info_storage[key] = torch.zeros(
|
||||
(self.capacity, *shape),
|
||||
dtype=value.dtype,
|
||||
device=self.storage_device
|
||||
(self.capacity, *shape), dtype=value.dtype, device=self.storage_device
|
||||
)
|
||||
|
||||
|
||||
# Store the value
|
||||
if key in self.complementary_info_storage:
|
||||
if isinstance(value, torch.Tensor):
|
||||
@@ -304,7 +302,7 @@ class ReplayBuffer:
|
||||
# For non-tensor values (like grasp_penalty)
|
||||
self.complementary_info_storage[key][self.position] = torch.tensor(
|
||||
value, device=self.storage_device
|
||||
)
|
||||
)
|
||||
|
||||
self.position = (self.position + 1) % self.capacity
|
||||
self.size = min(self.size + 1, self.capacity)
|
||||
@@ -366,7 +364,7 @@ class ReplayBuffer:
|
||||
|
||||
# Add complementary_info to batch if it exists
|
||||
batch_complementary_info = {}
|
||||
if hasattr(self, 'complementary_info_keys') and self.complementary_info_keys:
|
||||
if hasattr(self, "complementary_info_keys") and self.complementary_info_keys:
|
||||
for key in self.complementary_info_keys:
|
||||
if key in self.complementary_info_storage:
|
||||
batch_complementary_info[key] = self.complementary_info_storage[key][idx].to(self.device)
|
||||
|
||||
Reference in New Issue
Block a user