From 48269dddb3eee34e91a57454fa4059dbfe604e9f Mon Sep 17 00:00:00 2001
From: javadcc_mac
Date: Sun, 10 May 2026 11:29:23 +0800
Subject: [PATCH] fix(evo1): infer batch size after normalizing image dims

`_collect_image_batches` read `batch_size = batch[camera_keys[0]].shape[0]`
before normalizing per-camera tensors to `(B, C, H, W)`. For an unbatched
`(C, H, W)` input (which the function tries to support via the
`image.dim() == 3` branch), this picked up the channel count `C` instead of
the real batch size, making the subsequent per-sample loop iterate `C` times
and index out of bounds. Normalize each camera tensor up-front, then read
`batch_size` from the normalized batch dim.

Adds `test_collect_image_batches_handles_unbatched_chw` covering the
regression.

Reported by Copilot review on huggingface/lerobot#3545.
---
 src/lerobot/policies/evo1/modeling_evo1.py | 29 ++++++++++++++++++-----------
 tests/policies/evo1/test_evo1.py           | 17 +++++++++++++++++
 2 files changed, 35 insertions(+), 11 deletions(-)
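Note for reviewers (after the cut line, so not part of the commit message or
the diff): a minimal standalone sketch of the failure mode, using only plain
torch tensors and none of the EVO1 helpers; the shapes mirror the new test.
For an unbatched (C, H, W) image, shape[0] is the channel count, so reading
the batch size before normalizing dims walks a batch that is not there:

    import torch

    image = torch.rand(3, 16, 16)  # unbatched (C, H, W) camera frame

    # Old order: batch_size read before any dim normalization picks up C.
    batch_size = image.shape[0]  # == 3 (channels), not the real batch size

    # Inside the old per-sample loop the tensor *was* normalized to
    # (1, C, H, W), so indexing it with batch_index in range(3) fails
    # as soon as batch_index >= 1:
    normalized = image.unsqueeze(0)  # (1, 3, 16, 16)
    # normalized[1]  -> IndexError: index 1 is out of bounds for dim 0

    # New order: normalize first, then read the real batch dim.
    batch_size = normalized.shape[0]  # == 1
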
diff --git a/src/lerobot/policies/evo1/modeling_evo1.py b/src/lerobot/policies/evo1/modeling_evo1.py
index 474fd52a5..91459d722 100644
--- a/src/lerobot/policies/evo1/modeling_evo1.py
+++ b/src/lerobot/policies/evo1/modeling_evo1.py
@@ -299,23 +299,30 @@ class EVO1Policy(PreTrainedPolicy):
         camera_keys = self._camera_keys or sorted(key for key in batch if key.startswith(f"{OBS_IMAGES}."))
         if not camera_keys:
             raise ValueError("EVO1 requires at least one visual observation feature.")
-        batch_size = batch[camera_keys[0]].shape[0]
+
+        # Normalize each camera tensor to (B, C, H, W) up-front so that batch_size is read
+        # from a real batch dim and not from C in the unbatched (C, H, W) case.
+        normalized: dict[str, Tensor] = {}
+        for camera_key in camera_keys[: self.config.max_views]:
+            image = batch[camera_key]
+            if image.dim() == 3:
+                image = image.unsqueeze(0)
+            elif image.dim() == 5:
+                image = image[:, -1]
+            elif image.dim() != 4:
+                raise ValueError(
+                    f"Unsupported image tensor shape for EVO1: key={camera_key} shape={tuple(image.shape)}"
+                )
+            normalized[camera_key] = image
+
+        batch_size = normalized[camera_keys[0]].shape[0]
         image_batches: list[list[Tensor]] = []
         image_masks = torch.zeros(batch_size, self.config.max_views, dtype=torch.bool)
 
         for batch_index in range(batch_size):
             sample_images: list[Tensor] = []
             for camera_key in camera_keys[: self.config.max_views]:
-                image = batch[camera_key]
-                if image.dim() == 3:
-                    image = image.unsqueeze(0)
-                elif image.dim() == 5:
-                    image = image[:, -1]
-                elif image.dim() != 4:
-                    raise ValueError(
-                        f"Unsupported image tensor shape for EVO1: key={camera_key} shape={tuple(image.shape)}"
-                    )
-                sample_images.append(image[batch_index].detach().cpu())
+                sample_images.append(normalized[camera_key][batch_index].detach().cpu())
             if not sample_images:
                 raise ValueError("EVO1 received a batch without any image tensor.")
             while len(sample_images) < self.config.max_views:
diff --git a/tests/policies/evo1/test_evo1.py b/tests/policies/evo1/test_evo1.py
index 5bf170397..706c1903f 100644
--- a/tests/policies/evo1/test_evo1.py
+++ b/tests/policies/evo1/test_evo1.py
@@ -180,6 +180,23 @@ def test_evo1_policy_forward_and_inference_use_batched_embedding(monkeypatch):
     assert selected.shape == (2, ACTION_DIM)
 
 
+def test_collect_image_batches_handles_unbatched_chw(monkeypatch):
+    # Regression for an issue where batch_size was read from shape[0] before normalizing
+    # per-camera tensor dims, so an unbatched (C, H, W) input was treated as batch_size=C.
+    monkeypatch.setattr(modeling_evo1, "EVO1", DummyEVO1)
+    policy = modeling_evo1.EVO1Policy(make_config())
+    batch = {
+        OBS_STATE: torch.randn(1, STATE_DIM),
+        f"{OBS_IMAGES}.front": torch.rand(3, 16, 16),
+    }
+
+    image_batches, image_masks = policy._collect_image_batches(batch)
+
+    assert len(image_batches) == 1
+    assert len(image_batches[0]) == policy.config.max_views
+    assert image_masks.tolist() == [[True, False]]
+
+
 def test_evo1_action_mask_accepts_chunk_size_one(monkeypatch):
     monkeypatch.setattr(modeling_evo1, "EVO1", DummyEVO1)
     config = make_config(chunk_size=1, n_action_steps=1)
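
Note (below the diffs, so ignored when applying the patch): assuming the
repository's standard pytest setup, the new regression test can be run on
its own with:

    pytest tests/policies/evo1/test_evo1.py -k test_collect_image_batches_handles_unbatched_chw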