fix(evo1): infer batch size after normalizing image dims

`_collect_image_batches` read `batch_size = batch[camera_keys[0]].shape[0]`
before normalizing per-camera tensors to `(B, C, H, W)`. For an unbatched
`(C, H, W)` input (which the function tries to support via the `image.dim() == 3`
branch), this picked up the channel count `C` instead of the real batch size,
making the subsequent per-sample loop iterate `C` times and indexing go
out of bounds.

Normalize each camera tensor up-front, then read `batch_size` from the
normalized batch dim. Adds `test_collect_image_batches_handles_unbatched_chw`
covering the regression.

Reported by Copilot review on huggingface/lerobot#3545.
This commit is contained in:
javadcc_mac
2026-05-10 11:29:23 +08:00
parent 8df8d3d866
commit 48269dddb3
2 changed files with 35 additions and 11 deletions
+17
View File
@@ -180,6 +180,23 @@ def test_evo1_policy_forward_and_inference_use_batched_embedding(monkeypatch):
assert selected.shape == (2, ACTION_DIM)
def test_collect_image_batches_handles_unbatched_chw(monkeypatch):
# Regression for an issue where batch_size was read from shape[0] before normalizing
# per-camera tensor dims, so an unbatched (C, H, W) input was treated as batch_size=C.
monkeypatch.setattr(modeling_evo1, "EVO1", DummyEVO1)
policy = modeling_evo1.EVO1Policy(make_config())
batch = {
OBS_STATE: torch.randn(1, STATE_DIM),
f"{OBS_IMAGES}.front": torch.rand(3, 16, 16),
}
image_batches, image_masks = policy._collect_image_batches(batch)
assert len(image_batches) == 1
assert len(image_batches[0]) == policy.config.max_views
assert image_masks.tolist() == [[True, False]]
def test_evo1_action_mask_accepts_chunk_size_one(monkeypatch):
monkeypatch.setattr(modeling_evo1, "EVO1", DummyEVO1)
config = make_config(chunk_size=1, n_action_steps=1)