From c36de3a3e81b55093e41f4c1b9b7ca71df732bb1 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Tue, 12 May 2026 15:27:23 +0200 Subject: [PATCH] fix(smolvla2): enqueue full chunk via predict_action_chunk MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ``LowLevelForward`` was calling ``select_action()`` once per ``chunk_hz`` tick. SmolVLA's ``select_action`` is a thin queue-pop: it returns one action per call and only re-runs the expensive flow-matching forward when its private internal queue empties. Result: we got one action back per chunk_hz tick (1Hz default), ``DispatchAction`` at ctrl_hz=30 popped it instantly, then the queue sat empty for ~1s waiting for the next tick. Net throughput was 1 dispatched action/sec instead of the 30 we wanted. Switch to ``predict_action_chunk`` and enqueue every step of the returned ``(batch, n_action_steps, action_dim)`` chunk. Refresh only when the queue has drained to half a chunk or less so we don't burn one flow-matching forward per chunk_hz tick — saves ~5x inference cost on this hot path. At ctrl_hz=30, chunk_size=50, a full chunk takes ~1.7s (50 / 30) to drain, giving smooth dispatch at the control rate the robot was trained on. Side effect: ``state['last_chunk_size']`` records how many actions the most recent chunk produced — useful for the panel later if we want to surface "chunks generated" alongside "dispatched". 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .../policies/smolvla2/inference/steps.py | 61 ++++++++++++++----- 1 file changed, 47 insertions(+), 14 deletions(-) diff --git a/src/lerobot/policies/smolvla2/inference/steps.py b/src/lerobot/policies/smolvla2/inference/steps.py index fd285c419..29cb7be83 100644 --- a/src/lerobot/policies/smolvla2/inference/steps.py +++ b/src/lerobot/policies/smolvla2/inference/steps.py @@ -92,17 +92,28 @@ class LowLevelForward(InferenceStep): if self.policy is None or self.observation_provider is None: return None if not state.get("task"): - # No task yet → nothing useful to condition on. return None + + # SmolVLA produces *action chunks* (typically 50 steps via + # flow-matching). The expensive part is the chunk forward; + # popping one action per dispatch tick is essentially free. + # Only regenerate when the queue is low so we don't burn one + # full chunk forward per chunk_hz tick when most of the + # previous chunk is still buffered. + queue = state.setdefault("action_queue", []) + chunk_size = getattr(self.policy.config, "chunk_size", None) or getattr( + self.policy.config, "n_action_steps", 50 + ) + # Refresh threshold: keep at least half a chunk buffered. + if len(queue) > max(1, chunk_size // 2): + return None + observation = self.observation_provider() if observation is None: return None - # SmolVLA's ``select_action`` expects the full preprocessed - # batch, including ``OBS_LANGUAGE_TOKENS`` / - # ``OBS_LANGUAGE_ATTENTION_MASK``. The observation provider - # only returns image / state features (the runtime drives - # messages itself), so build a low-level prompt from current - # runtime state and tokenize it inline. + + # Same prompt construction as before — task + plan + memory, + # optional current subtask — then merge into the obs batch. 
ctx = _control_context_messages(state) if state.get("current_subtask"): ctx = ctx + [{"role": "assistant", "content": state["current_subtask"]}] @@ -115,16 +126,38 @@ class LowLevelForward(InferenceStep): observation = dict(observation) observation[OBS_LANGUAGE_TOKENS] = text_batch["lang_tokens"] observation[OBS_LANGUAGE_ATTENTION_MASK] = text_batch["lang_masks"] + try: - action = self.policy.select_action(observation) + # ``predict_action_chunk`` returns the *full* chunk shape + # ``(batch, n_action_steps, action_dim)``. Enqueue every + # step so DispatchAction at ctrl_hz can drain them + # smoothly until the next refresh. + chunk = self.policy.predict_action_chunk(observation) except Exception as exc: # noqa: BLE001 - logger.warning("select_action failed: %s", exc, exc_info=logger.isEnabledFor(logging.DEBUG)) - push_log(state, f" [warn] select_action failed: {type(exc).__name__}: {exc}") + logger.warning( + "predict_action_chunk failed: %s", + exc, + exc_info=logger.isEnabledFor(logging.DEBUG), + ) + push_log( + state, + f" [warn] predict_action_chunk failed: " + f"{type(exc).__name__}: {exc}", + ) return None - # SmolVLA returns a single action; if the underlying policy - # streams chunks, split per-step here. For v1 we just enqueue - # the result. - state.setdefault("action_queue", []).append(action) + + # ``chunk`` shape: ``(batch, n_action_steps, action_dim)``. Push + # each step as a ``(1, action_dim)`` tensor so the existing + # action executor's batch-squeeze logic works unchanged. + if chunk.ndim == 3: + chunk_iter = chunk[0] # ``(n_action_steps, action_dim)`` + elif chunk.ndim == 2: + chunk_iter = chunk + else: + chunk_iter = chunk.unsqueeze(0) + for step in chunk_iter: + queue.append(step.unsqueeze(0)) + state["last_chunk_size"] = int(chunk_iter.shape[0]) return None