mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-14 16:19:45 +00:00
fix(evo1): infer batch size after normalizing image dims
`_collect_image_batches` read `batch_size = batch[camera_keys[0]].shape[0]` before normalizing per-camera tensors to `(B, C, H, W)`. For an unbatched `(C, H, W)` input (which the function tries to support via the `image.dim() == 3` branch), this picked up the channel count `C` instead of the real batch size, making the subsequent per-sample loop iterate `C` times and indexing go out of bounds. Normalize each camera tensor up-front, then read `batch_size` from the normalized batch dim. Adds `test_collect_image_batches_handles_unbatched_chw` covering the regression. Reported by Copilot review on huggingface/lerobot#3545.
This commit is contained in:
@@ -180,6 +180,23 @@ def test_evo1_policy_forward_and_inference_use_batched_embedding(monkeypatch):
|
||||
assert selected.shape == (2, ACTION_DIM)
|
||||
|
||||
|
||||
def test_collect_image_batches_handles_unbatched_chw(monkeypatch):
    """Regression test: an unbatched ``(C, H, W)`` camera tensor must be
    treated as batch size 1, not batch size ``C``.

    ``_collect_image_batches`` previously read ``batch_size`` from
    ``shape[0]`` before normalizing per-camera tensor dims, so a
    ``(C, H, W)`` input was treated as ``batch_size == C``, and the
    per-sample loop then indexed out of bounds.
    """
    monkeypatch.setattr(modeling_evo1, "EVO1", DummyEVO1)
    policy = modeling_evo1.EVO1Policy(make_config())

    front_key = f"{OBS_IMAGES}.front"
    batch = {
        OBS_STATE: torch.randn(1, STATE_DIM),
        # Deliberately unbatched (C, H, W) — the shape under regression.
        front_key: torch.rand(3, 16, 16),
    }

    image_batches, image_masks = policy._collect_image_batches(batch)

    # One sample in, one per-sample entry out; views padded to max_views,
    # with only the single provided camera marked present in the mask.
    assert len(image_batches) == 1
    assert len(image_batches[0]) == policy.config.max_views
    assert image_masks.tolist() == [[True, False]]
|
||||
|
||||
|
||||
def test_evo1_action_mask_accepts_chunk_size_one(monkeypatch):
|
||||
monkeypatch.setattr(modeling_evo1, "EVO1", DummyEVO1)
|
||||
config = make_config(chunk_size=1, n_action_steps=1)
|
||||
|
||||
Reference in New Issue
Block a user