mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 14:39:43 +00:00
fix(smolvla2): enqueue full chunk via predict_action_chunk
``LowLevelForward`` was calling ``select_action()`` once per ``chunk_hz`` tick. SmolVLA's ``select_action`` is a thin queue-pop: it returns one action per call and only re-runs the expensive flow-matching forward when its private internal queue empties. Result: we got one action back per chunk_hz tick (1Hz default), ``DispatchAction`` at ctrl_hz=30 popped it instantly, and then the queue sat empty for ~1s waiting for the next tick. Net throughput was 1 dispatched action/sec instead of the 30 we wanted. Switch to ``predict_action_chunk`` and enqueue every step of the returned ``(batch, n_action_steps, action_dim)`` chunk. Refresh only when the queue has drained to half a chunk or less, so we don't burn one flow-matching forward per chunk_hz tick — this saves ~5x inference cost on this hot path. At ctrl_hz=30, chunk_size=50, the queue drains in ~1.7s before the next refresh, giving smooth dispatch at the control rate the robot was trained on. Side effect: ``state['last_chunk_size']`` records how many actions the most recent chunk produced — useful for the panel later if we want to surface "chunks generated" alongside "dispatched". Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -92,17 +92,28 @@ class LowLevelForward(InferenceStep):
|
|||||||
if self.policy is None or self.observation_provider is None:
|
if self.policy is None or self.observation_provider is None:
|
||||||
return None
|
return None
|
||||||
if not state.get("task"):
|
if not state.get("task"):
|
||||||
# No task yet → nothing useful to condition on.
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# SmolVLA produces *action chunks* (typically 50 steps via
|
||||||
|
# flow-matching). The expensive part is the chunk forward;
|
||||||
|
# popping one action per dispatch tick is essentially free.
|
||||||
|
# Only regenerate when the queue is low so we don't burn one
|
||||||
|
# full chunk forward per chunk_hz tick when most of the
|
||||||
|
# previous chunk is still buffered.
|
||||||
|
queue = state.setdefault("action_queue", [])
|
||||||
|
chunk_size = getattr(self.policy.config, "chunk_size", None) or getattr(
|
||||||
|
self.policy.config, "n_action_steps", 50
|
||||||
|
)
|
||||||
|
# Refresh threshold: keep at least half a chunk buffered.
|
||||||
|
if len(queue) > max(1, chunk_size // 2):
|
||||||
|
return None
|
||||||
|
|
||||||
observation = self.observation_provider()
|
observation = self.observation_provider()
|
||||||
if observation is None:
|
if observation is None:
|
||||||
return None
|
return None
|
||||||
# SmolVLA's ``select_action`` expects the full preprocessed
|
|
||||||
# batch, including ``OBS_LANGUAGE_TOKENS`` /
|
# Same prompt construction as before — task + plan + memory,
|
||||||
# ``OBS_LANGUAGE_ATTENTION_MASK``. The observation provider
|
# optional current subtask — then merge into the obs batch.
|
||||||
# only returns image / state features (the runtime drives
|
|
||||||
# messages itself), so build a low-level prompt from current
|
|
||||||
# runtime state and tokenize it inline.
|
|
||||||
ctx = _control_context_messages(state)
|
ctx = _control_context_messages(state)
|
||||||
if state.get("current_subtask"):
|
if state.get("current_subtask"):
|
||||||
ctx = ctx + [{"role": "assistant", "content": state["current_subtask"]}]
|
ctx = ctx + [{"role": "assistant", "content": state["current_subtask"]}]
|
||||||
@@ -115,16 +126,38 @@ class LowLevelForward(InferenceStep):
|
|||||||
observation = dict(observation)
|
observation = dict(observation)
|
||||||
observation[OBS_LANGUAGE_TOKENS] = text_batch["lang_tokens"]
|
observation[OBS_LANGUAGE_TOKENS] = text_batch["lang_tokens"]
|
||||||
observation[OBS_LANGUAGE_ATTENTION_MASK] = text_batch["lang_masks"]
|
observation[OBS_LANGUAGE_ATTENTION_MASK] = text_batch["lang_masks"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
action = self.policy.select_action(observation)
|
# ``predict_action_chunk`` returns the *full* chunk shape
|
||||||
|
# ``(batch, n_action_steps, action_dim)``. Enqueue every
|
||||||
|
# step so DispatchAction at ctrl_hz can drain them
|
||||||
|
# smoothly until the next refresh.
|
||||||
|
chunk = self.policy.predict_action_chunk(observation)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
logger.warning("select_action failed: %s", exc, exc_info=logger.isEnabledFor(logging.DEBUG))
|
logger.warning(
|
||||||
push_log(state, f" [warn] select_action failed: {type(exc).__name__}: {exc}")
|
"predict_action_chunk failed: %s",
|
||||||
|
exc,
|
||||||
|
exc_info=logger.isEnabledFor(logging.DEBUG),
|
||||||
|
)
|
||||||
|
push_log(
|
||||||
|
state,
|
||||||
|
f" [warn] predict_action_chunk failed: "
|
||||||
|
f"{type(exc).__name__}: {exc}",
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
# SmolVLA returns a single action; if the underlying policy
|
|
||||||
# streams chunks, split per-step here. For v1 we just enqueue
|
# ``chunk`` shape: ``(batch, n_action_steps, action_dim)``. Push
|
||||||
# the result.
|
# each step as a ``(1, action_dim)`` tensor so the existing
|
||||||
state.setdefault("action_queue", []).append(action)
|
# action executor's batch-squeeze logic works unchanged.
|
||||||
|
if chunk.ndim == 3:
|
||||||
|
chunk_iter = chunk[0] # ``(n_action_steps, action_dim)``
|
||||||
|
elif chunk.ndim == 2:
|
||||||
|
chunk_iter = chunk
|
||||||
|
else:
|
||||||
|
chunk_iter = chunk.unsqueeze(0)
|
||||||
|
for step in chunk_iter:
|
||||||
|
queue.append(step.unsqueeze(0))
|
||||||
|
state["last_chunk_size"] = int(chunk_iter.shape[0])
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user