feat(runtime): --subtask_chunks_per_gen throttles HL gen vs action chunks

Adds a per-chunk-boundary counter to HighLevelSubtaskFwd: subtask gen
fires only once every N chunk boundaries (default 1 = current
behavior). Lets the operator run e.g. 5 flow-matching action chunks
per LM-head subtask gen so the subtask doesn't churn every 1.7s while
the previous one is still being executed — saves compute and avoids
re-planning the action trajectory mid-grasp.

  --subtask_chunks_per_gen=5    # 5 chunks per subtask refresh

The counter starts at 0 so the very first chunk boundary fires
immediately (no startup delay). Trigger is rearmed when skipping so
a low high_level_hz doesn't lose slots.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-25 12:34:59 +02:00
parent db927ab40b
commit 793c7c4ddd
2 changed files with 39 additions and 0 deletions
@@ -515,6 +515,25 @@ class HighLevelSubtaskFwd(InferenceStep):
if hasattr(self.trigger, "rearm"):
self.trigger.rearm()
return None
# Per-chunk-boundary throttle: at each "queue empty" moment we
# consult a countdown of boundaries left to skip; subtask gen only
# fires when that countdown is 0, i.e. once every
# ``subtask_chunks_per_gen`` boundaries. Lets the operator run e.g.
# 5 action chunks per subtask-gen so the LM head doesn't churn
# every 1.7 s (a fresh subtask while the previous one is still
# being executed is wasted compute *and* causes the action
# expert's flow trajectory to be re-planned mid-grasp).
chunks_per_gen = max(1, int(state.get("subtask_chunks_per_gen", 1) or 1))
# Initialise so the first chunk boundary fires immediately
# (countdown starts at 0; each generation resets it to
# chunks_per_gen - 1, and each skipped boundary decrements it
# until it is back at 0).
if "_hl_chunks_until_gen" not in state:
state["_hl_chunks_until_gen"] = 0
if state["_hl_chunks_until_gen"] > 0:
state["_hl_chunks_until_gen"] -= 1
if hasattr(self.trigger, "rearm"):
self.trigger.rearm()
return None
state["_hl_chunks_until_gen"] = chunks_per_gen - 1
ctx = _msgs_for_subtask(state)
observation = _maybe_observation(self.observation_provider)
# Default: greedy argmax, no min_new_tokens, no special-token
@@ -272,6 +272,18 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
default=1.0,
help="High-level subtask generation rate.",
)
p.add_argument(
"--subtask_chunks_per_gen",
type=int,
default=1,
help=(
"Throttle subtask gen to once every N action-chunk boundaries. "
"Default 1 = regenerate the subtask on every chunk refresh. "
"Set to 5 to run ~5 flow-matching action chunks per LM-head "
"subtask gen — saves compute and avoids re-planning trajectories "
"mid-grasp when a subtask is still valid across multiple chunks."
),
)
p.add_argument(
"--max_ticks",
type=int,
@@ -1514,6 +1526,14 @@ def main(argv: list[str] | None = None) -> int:
# robot actually receive" in one log line.
runtime.state["_postprocessor"] = postprocessor
runtime.state["text_gen_top_p"] = float(getattr(args, "text_top_p", 1.0) or 1.0)
# Subtask throttle: HighLevelSubtaskFwd fires only once every N
# action-chunk boundaries. Lets you run N action chunks per LM-head
# subtask gen (e.g. ``--subtask_chunks_per_gen=5`` ≈ 5 flow-matching
# chunks per subtask refresh) so the subtask doesn't churn while
# the previous one is still being executed.
runtime.state["subtask_chunks_per_gen"] = max(
1, int(getattr(args, "subtask_chunks_per_gen", 1) or 1)
)
# Apply the startup mode chosen above the task picker.
runtime.state["mode"] = startup_mode
if args.task: