mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
fix(evo1): infer batch size after normalizing image dims
`_collect_image_batches` read `batch_size = batch[camera_keys[0]].shape[0]` before normalizing per-camera tensors to `(B, C, H, W)`. For an unbatched `(C, H, W)` input (which the function tries to support via the `image.dim() == 3` branch), this picked up the channel count `C` instead of the real batch size, making the subsequent per-sample loop iterate `C` times and indexing go out of bounds. Normalize each camera tensor up-front, then read `batch_size` from the normalized batch dim. Adds `test_collect_image_batches_handles_unbatched_chw` covering the regression. Reported by Copilot review on huggingface/lerobot#3545.
This commit is contained in:
@@ -299,23 +299,30 @@ class EVO1Policy(PreTrainedPolicy):
         camera_keys = self._camera_keys or sorted(key for key in batch if key.startswith(f"{OBS_IMAGES}."))
         if not camera_keys:
             raise ValueError("EVO1 requires at least one visual observation feature.")
 
-        batch_size = batch[camera_keys[0]].shape[0]
+        # Normalize each camera tensor to (B, C, H, W) up-front so that batch_size is read
+        # from a real batch dim and not from C in the unbatched (C, H, W) case.
+        normalized: dict[str, Tensor] = {}
+        for camera_key in camera_keys[: self.config.max_views]:
+            image = batch[camera_key]
+            if image.dim() == 3:
+                image = image.unsqueeze(0)
+            elif image.dim() == 5:
+                image = image[:, -1]
+            elif image.dim() != 4:
+                raise ValueError(
+                    f"Unsupported image tensor shape for EVO1: key={camera_key} shape={tuple(image.shape)}"
+                )
+            normalized[camera_key] = image
+
+        batch_size = normalized[camera_keys[0]].shape[0]
         image_batches: list[list[Tensor]] = []
         image_masks = torch.zeros(batch_size, self.config.max_views, dtype=torch.bool)
 
         for batch_index in range(batch_size):
             sample_images: list[Tensor] = []
             for camera_key in camera_keys[: self.config.max_views]:
-                image = batch[camera_key]
-                if image.dim() == 3:
-                    image = image.unsqueeze(0)
-                elif image.dim() == 5:
-                    image = image[:, -1]
-                elif image.dim() != 4:
-                    raise ValueError(
-                        f"Unsupported image tensor shape for EVO1: key={camera_key} shape={tuple(image.shape)}"
-                    )
-                sample_images.append(image[batch_index].detach().cpu())
+                sample_images.append(normalized[camera_key][batch_index].detach().cpu())
             if not sample_images:
                 raise ValueError("EVO1 received a batch without any image tensor.")
             while len(sample_images) < self.config.max_views:
||||
@@ -180,6 +180,23 @@ def test_evo1_policy_forward_and_inference_use_batched_embedding(monkeypatch):
     assert selected.shape == (2, ACTION_DIM)
 
 
+def test_collect_image_batches_handles_unbatched_chw(monkeypatch):
+    # Regression for an issue where batch_size was read from shape[0] before normalizing
+    # per-camera tensor dims, so an unbatched (C, H, W) input was treated as batch_size=C.
+    monkeypatch.setattr(modeling_evo1, "EVO1", DummyEVO1)
+    policy = modeling_evo1.EVO1Policy(make_config())
+    batch = {
+        OBS_STATE: torch.randn(1, STATE_DIM),
+        f"{OBS_IMAGES}.front": torch.rand(3, 16, 16),
+    }
+
+    image_batches, image_masks = policy._collect_image_batches(batch)
+
+    assert len(image_batches) == 1
+    assert len(image_batches[0]) == policy.config.max_views
+    assert image_masks.tolist() == [[True, False]]
+
+
 def test_evo1_action_mask_accepts_chunk_size_one(monkeypatch):
     monkeypatch.setattr(modeling_evo1, "EVO1", DummyEVO1)
     config = make_config(chunk_size=1, n_action_steps=1)
||||
Reference in New Issue
Block a user