diff --git a/src/lerobot/policies/smolvla2/inference/steps.py b/src/lerobot/policies/smolvla2/inference/steps.py index fd285c419..29cb7be83 100644 --- a/src/lerobot/policies/smolvla2/inference/steps.py +++ b/src/lerobot/policies/smolvla2/inference/steps.py @@ -92,17 +92,28 @@ class LowLevelForward(InferenceStep): if self.policy is None or self.observation_provider is None: return None if not state.get("task"): - # No task yet → nothing useful to condition on. return None + + # SmolVLA produces *action chunks* (typically 50 steps via + # flow-matching). The expensive part is the chunk forward; + # popping one action per dispatch tick is essentially free. + # Only regenerate when the queue is low so we don't burn one + # full chunk forward per chunk_hz tick when most of the + # previous chunk is still buffered. + queue = state.setdefault("action_queue", []) + chunk_size = getattr(self.policy.config, "chunk_size", None) or getattr( + self.policy.config, "n_action_steps", 50 + ) + # Refresh threshold: keep at least half a chunk buffered. + if len(queue) > max(1, chunk_size // 2): + return None + observation = self.observation_provider() if observation is None: return None - # SmolVLA's ``select_action`` expects the full preprocessed - # batch, including ``OBS_LANGUAGE_TOKENS`` / - # ``OBS_LANGUAGE_ATTENTION_MASK``. The observation provider - # only returns image / state features (the runtime drives - # messages itself), so build a low-level prompt from current - # runtime state and tokenize it inline. + + # Same prompt construction as before — task + plan + memory, + # optional current subtask — then merge into the obs batch. ctx = _control_context_messages(state) if state.get("current_subtask"): ctx = ctx + [{"role": "assistant", "content": state["current_subtask"]}] @@ -115,16 +126,38 @@ class LowLevelForward(InferenceStep): observation = dict(observation) observation[OBS_LANGUAGE_TOKENS] = text_batch["lang_tokens"] observation[OBS_LANGUAGE_ATTENTION_MASK] = text_batch["lang_masks"] + try: - action = self.policy.select_action(observation) + # ``predict_action_chunk`` returns the *full* chunk shape + # ``(batch, n_action_steps, action_dim)``. Enqueue every + # step so DispatchAction at ctrl_hz can drain them + # smoothly until the next refresh. + chunk = self.policy.predict_action_chunk(observation) except Exception as exc: # noqa: BLE001 - logger.warning("select_action failed: %s", exc, exc_info=logger.isEnabledFor(logging.DEBUG)) - push_log(state, f" [warn] select_action failed: {type(exc).__name__}: {exc}") + logger.warning( + "predict_action_chunk failed: %s", + exc, + exc_info=logger.isEnabledFor(logging.DEBUG), + ) + push_log( + state, + f" [warn] predict_action_chunk failed: " + f"{type(exc).__name__}: {exc}", + ) return None - # SmolVLA returns a single action; if the underlying policy - # streams chunks, split per-step here. For v1 we just enqueue - # the result. - state.setdefault("action_queue", []).append(action) + + # ``chunk`` shape: ``(batch, n_action_steps, action_dim)``. Push + # each step as a ``(1, action_dim)`` tensor so the existing + # action executor's batch-squeeze logic works unchanged. + if chunk.ndim == 3: + chunk_iter = chunk[0] # ``(n_action_steps, action_dim)`` + elif chunk.ndim == 2: + chunk_iter = chunk + else: + chunk_iter = chunk.unsqueeze(0) + for step in chunk_iter: + queue.append(step.unsqueeze(0)) + state["last_chunk_size"] = int(chunk_iter.shape[0]) return None