mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 21:50:03 +00:00
fix(smolvla2): enqueue full chunk via predict_action_chunk
``LowLevelForward`` was calling ``select_action()`` once per ``chunk_hz`` tick. SmolVLA's ``select_action`` is a thin queue-pop: it returns one action per call and only re-runs the expensive flow-matching forward when its private internal queue empties. Result: we got one action back per chunk_hz tick (1Hz default), ``DispatchAction`` at ctrl_hz=30 popped it instantly, then queue sat empty for ~1s waiting for the next tick. Net throughput was 1 dispatched action/sec instead of the 30 we wanted. Switch to ``predict_action_chunk`` and enqueue every step of the returned ``(batch, n_action_steps, action_dim)`` chunk. Refresh only when the queue is below half a chunk so we don't burn one flow-matching forward per chunk_hz tick — saves ~5x inference cost on this hot path. At ctrl_hz=30, chunk_size=50, the queue drains in ~1.7s before the next refresh, giving smooth dispatch at the control rate the robot was trained on. Side effect: ``state['last_chunk_size']`` records how many actions the most recent chunk produced — useful for the panel later if we want to surface "chunks generated" alongside "dispatched". Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -92,17 +92,28 @@ class LowLevelForward(InferenceStep):
|
||||
if self.policy is None or self.observation_provider is None:
|
||||
return None
|
||||
if not state.get("task"):
|
||||
# No task yet → nothing useful to condition on.
|
||||
return None
|
||||
|
||||
# SmolVLA produces *action chunks* (typically 50 steps via
|
||||
# flow-matching). The expensive part is the chunk forward;
|
||||
# popping one action per dispatch tick is essentially free.
|
||||
# Only regenerate when the queue is low so we don't burn one
|
||||
# full chunk forward per chunk_hz tick when most of the
|
||||
# previous chunk is still buffered.
|
||||
queue = state.setdefault("action_queue", [])
|
||||
chunk_size = getattr(self.policy.config, "chunk_size", None) or getattr(
|
||||
self.policy.config, "n_action_steps", 50
|
||||
)
|
||||
# Refresh threshold: keep at least half a chunk buffered.
|
||||
if len(queue) > max(1, chunk_size // 2):
|
||||
return None
|
||||
|
||||
observation = self.observation_provider()
|
||||
if observation is None:
|
||||
return None
|
||||
# SmolVLA's ``select_action`` expects the full preprocessed
|
||||
# batch, including ``OBS_LANGUAGE_TOKENS`` /
|
||||
# ``OBS_LANGUAGE_ATTENTION_MASK``. The observation provider
|
||||
# only returns image / state features (the runtime drives
|
||||
# messages itself), so build a low-level prompt from current
|
||||
# runtime state and tokenize it inline.
|
||||
|
||||
# Same prompt construction as before — task + plan + memory,
|
||||
# optional current subtask — then merge into the obs batch.
|
||||
ctx = _control_context_messages(state)
|
||||
if state.get("current_subtask"):
|
||||
ctx = ctx + [{"role": "assistant", "content": state["current_subtask"]}]
|
||||
@@ -115,16 +126,38 @@ class LowLevelForward(InferenceStep):
|
||||
observation = dict(observation)
|
||||
observation[OBS_LANGUAGE_TOKENS] = text_batch["lang_tokens"]
|
||||
observation[OBS_LANGUAGE_ATTENTION_MASK] = text_batch["lang_masks"]
|
||||
|
||||
try:
|
||||
action = self.policy.select_action(observation)
|
||||
# ``predict_action_chunk`` returns the *full* chunk shape
|
||||
# ``(batch, n_action_steps, action_dim)``. Enqueue every
|
||||
# step so DispatchAction at ctrl_hz can drain them
|
||||
# smoothly until the next refresh.
|
||||
chunk = self.policy.predict_action_chunk(observation)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("select_action failed: %s", exc, exc_info=logger.isEnabledFor(logging.DEBUG))
|
||||
push_log(state, f" [warn] select_action failed: {type(exc).__name__}: {exc}")
|
||||
logger.warning(
|
||||
"predict_action_chunk failed: %s",
|
||||
exc,
|
||||
exc_info=logger.isEnabledFor(logging.DEBUG),
|
||||
)
|
||||
push_log(
|
||||
state,
|
||||
f" [warn] predict_action_chunk failed: "
|
||||
f"{type(exc).__name__}: {exc}",
|
||||
)
|
||||
return None
|
||||
# SmolVLA returns a single action; if the underlying policy
|
||||
# streams chunks, split per-step here. For v1 we just enqueue
|
||||
# the result.
|
||||
state.setdefault("action_queue", []).append(action)
|
||||
|
||||
# ``chunk`` shape: ``(batch, n_action_steps, action_dim)``. Push
|
||||
# each step as a ``(1, action_dim)`` tensor so the existing
|
||||
# action executor's batch-squeeze logic works unchanged.
|
||||
if chunk.ndim == 3:
|
||||
chunk_iter = chunk[0] # ``(n_action_steps, action_dim)``
|
||||
elif chunk.ndim == 2:
|
||||
chunk_iter = chunk
|
||||
else:
|
||||
chunk_iter = chunk.unsqueeze(0)
|
||||
for step in chunk_iter:
|
||||
queue.append(step.unsqueeze(0))
|
||||
state["last_chunk_size"] = int(chunk_iter.shape[0])
|
||||
return None
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user