Mirror of https://github.com/huggingface/lerobot.git — synced 2026-05-16 00:59:46 +00:00
fix(smolvla2): drop pixel_values from select_message generate path
SmolVLA's image preprocessor sizes frames to whatever resolution the action expert was trained on, but SmolVLM's standard vision tower expects its own default tile grid (e.g. 384/14 → 27×27 patches). The mismatch surfaces deep in the post-vision reshape as ``RuntimeError: shape '[2, 34, 34, 768]' is invalid for input of size 1843200`` — the preprocessor yields 1200 patches per image where the reshape expects 34×34 = 1156. Drop ``pixel_values`` from ``vlm.generate(...)`` so SmolVLM runs as a text-only LM at REPL time. The high-level branches (subtask / plan / memory) are dominated by their text context anyway, so this is acceptable for dry-run inference. VQA loses its image grounding; this is documented as a known limitation of the dry-run path until a follow-up either re-processes images through SmolVLM's own ``ImageProcessor`` to match its tile grid, or gives ``vlm_with_expert`` a real AR text decode mode that handles state and image embeddings the way training does. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -324,23 +324,23 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
||||
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
|
||||
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
||||
|
||||
# Collect any image features the runtime merged in. SmolVLM
|
||||
# expects ``pixel_values`` shape ``[B, num_images, C, H, W]``;
|
||||
# individual ``observation.images.*`` features are typically
|
||||
# ``[B, C, H, W]`` after the preprocessor, so stack them on a
|
||||
# new ``num_images`` axis.
|
||||
image_tensors: list[Tensor] = []
|
||||
for k, v in batch.items():
|
||||
if (
|
||||
isinstance(k, str)
|
||||
and k.startswith("observation.images.")
|
||||
and isinstance(v, Tensor)
|
||||
):
|
||||
image_tensors.append(v if v.ndim == 4 else v.unsqueeze(0))
|
||||
pixel_values = (
|
||||
torch.stack(image_tensors, dim=1) if image_tensors else None
|
||||
)
|
||||
|
||||
# NOTE: we deliberately do *not* forward ``pixel_values`` to
|
||||
# ``vlm.generate``. The dataset's images go through SmolVLA's
|
||||
# custom preprocessor (resize / normalise to whatever shape
|
||||
# the action expert was trained on), but SmolVLM's standard
|
||||
# vision tower expects images sized to its own default tile
|
||||
# grid (e.g. 384/14 → 27×27 patches). The mismatch surfaces
|
||||
# as ``RuntimeError: shape '[2, 34, 34, 768]' is invalid for
|
||||
# input of size <other>`` deep in the post-vision reshape.
|
||||
#
|
||||
# For the dry-run REPL the high-level branches (subtask /
|
||||
# plan / memory) are dominated by their text context anyway,
|
||||
# so running text-only generation through SmolVLM is good
|
||||
# enough. Restoring full vision conditioning at inference
|
||||
# would mean either re-processing the images through the
|
||||
# backbone's own ``ImageProcessor`` (and matching SmolVLA2
|
||||
# training shape) or giving ``vlm_with_expert`` a real AR
|
||||
# text decode mode — both are bigger follow-ups.
|
||||
gen_kwargs: dict[str, Any] = {
|
||||
"input_ids": lang_tokens,
|
||||
"attention_mask": lang_masks.long() if lang_masks.dtype == torch.bool else lang_masks,
|
||||
@@ -353,8 +353,6 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
||||
gen_kwargs["top_p"] = top_p
|
||||
if eos_token_id is not None:
|
||||
gen_kwargs["eos_token_id"] = eos_token_id
|
||||
if pixel_values is not None:
|
||||
gen_kwargs["pixel_values"] = pixel_values
|
||||
|
||||
gen_ids = vlm.generate(**gen_kwargs)
|
||||
# ``vlm.generate`` returns the prompt + new tokens; slice off
|
||||
|
||||
Reference in New Issue
Block a user