only call .detach().cpu() once per caemra instead of once per image

This commit is contained in:
Martino Russi
2026-07-01 17:45:16 +00:00
parent 9423deda02
commit 1df19ae468
+6 -1
View File
@@ -342,10 +342,15 @@ class EVO1Policy(PreTrainedPolicy):
image_batches: list[list[Tensor]] = []
image_masks = torch.zeros(batch_size, self.config.max_views, dtype=torch.bool)
cpu_images: dict[str, Tensor] = {
camera_key: normalized[camera_key].detach().cpu()
for camera_key in camera_keys[: self.config.max_views]
}
for batch_index in range(batch_size):
sample_images: list[Tensor] = []
for camera_key in camera_keys[: self.config.max_views]:
sample_images.append(normalized[camera_key][batch_index].detach().cpu())
sample_images.append(cpu_images[camera_key][batch_index])
if not sample_images:
raise ValueError("EVO1 received a batch without any image tensor.")
while len(sample_images) < self.config.max_views: