diff --git a/src/lerobot/policies/smolvla2/modeling_smolvla2.py b/src/lerobot/policies/smolvla2/modeling_smolvla2.py index f4d0ebf6b..1c65a26ec 100644 --- a/src/lerobot/policies/smolvla2/modeling_smolvla2.py +++ b/src/lerobot/policies/smolvla2/modeling_smolvla2.py @@ -324,23 +324,23 @@ class SmolVLA2Policy(SmolVLAPolicy): lang_tokens = batch[OBS_LANGUAGE_TOKENS] lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK] - # Collect any image features the runtime merged in. SmolVLM - # expects ``pixel_values`` shape ``[B, num_images, C, H, W]``; - # individual ``observation.images.*`` features are typically - # ``[B, C, H, W]`` after the preprocessor, so stack them on a - # new ``num_images`` axis. - image_tensors: list[Tensor] = [] - for k, v in batch.items(): - if ( - isinstance(k, str) - and k.startswith("observation.images.") - and isinstance(v, Tensor) - ): - image_tensors.append(v if v.ndim == 4 else v.unsqueeze(0)) - pixel_values = ( - torch.stack(image_tensors, dim=1) if image_tensors else None - ) - + # NOTE: we deliberately do *not* forward ``pixel_values`` to + # ``vlm.generate``. The dataset's images go through SmolVLA's + # custom preprocessor (resize / normalise to whatever shape + # the action expert was trained on), but SmolVLM's standard + # vision tower expects images sized to its own default tile + # grid (e.g. 384/14 → 27×27 patches). The mismatch surfaces + # as ``RuntimeError: shape '[2, 34, 34, 768]' is invalid for + # input of size N`` deep in the post-vision reshape. + # + # For the dry-run REPL the high-level branches (subtask / + # plan / memory) are dominated by their text context anyway, + # so running text-only generation through SmolVLM is good + # enough. Restoring full vision conditioning at inference + # would mean either re-processing the images through the + # backbone's own ``ImageProcessor`` (and matching SmolVLA2 + # training shape) or giving ``vlm_with_expert`` a real AR + # text decode mode — both are bigger follow-ups. 
gen_kwargs: dict[str, Any] = { "input_ids": lang_tokens, "attention_mask": lang_masks.long() if lang_masks.dtype == torch.bool else lang_masks, @@ -353,8 +353,6 @@ class SmolVLA2Policy(SmolVLAPolicy): gen_kwargs["top_p"] = top_p if eos_token_id is not None: gen_kwargs["eos_token_id"] = eos_token_id - if pixel_values is not None: - gen_kwargs["pixel_values"] = pixel_values gen_ids = vlm.generate(**gen_kwargs) # ``vlm.generate`` returns the prompt + new tokens; slice off