feat(annotate): parallelize episodes within each module phase

Saturates parallel_servers + client_concurrency. Previously the
executor processed one episode at a time, so each Module 1 episode's
3-5 dependent VLM calls hit a single server with the others idle. Now
defaults to 16 episodes in flight; configurable via
ExecutorConfig.episode_parallelism.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-04-28 23:59:02 +02:00
parent 994ad880ee
commit d0ad7ffb21
2 changed files with 40 additions and 10 deletions
@@ -162,6 +162,13 @@ class ExecutorConfig:
slurm_gpus: int = 1
slurm_time: str = "06:00:00"
workers: int = 1
episode_parallelism: int = 16
"""Number of episodes processed concurrently within each module phase.
Each in-flight episode sends 35 dependent VLM calls; bumping this is
how you actually saturate ``parallel_servers`` and ``client_concurrency``
— without it, the executor loops one episode at a time and the
inference servers sit ~90% idle. Set to ``1`` for strict serial
execution."""
@dataclass
@@ -136,25 +136,48 @@ class Executor:
module: Any,
) -> PhaseResult:
import time as _time # noqa: PLC0415
from concurrent.futures import ThreadPoolExecutor, as_completed # noqa: PLC0415
if not module.enabled:
print(f"[annotate] phase={name} skipped (module disabled)", flush=True)
return PhaseResult(name=name, episodes_processed=0, episodes_skipped=len(records))
n = len(records)
print(f"[annotate] phase={name} starting on {n} episode(s)", flush=True)
parallelism = max(1, min(self.config.executor.episode_parallelism, n))
print(
f"[annotate] phase={name} starting on {n} episode(s) "
f"(parallelism={parallelism})",
flush=True,
)
t0 = _time.time()
processed = 0
for i, record in enumerate(records, 1):
def _do(idx_record: tuple[int, EpisodeRecord]) -> tuple[int, int, float]:
i, record = idx_record
ep_start = _time.time()
staging = EpisodeStaging(staging_dir, record.episode_index)
module.run_episode(record, staging)
processed += 1
elapsed = _time.time() - ep_start
print(
f"[annotate] {name} episode {i}/{n} "
f"(idx={record.episode_index}) done in {elapsed:.1f}s",
flush=True,
)
return i, record.episode_index, _time.time() - ep_start
processed = 0
if parallelism == 1:
for i, record in enumerate(records, 1):
_, ep_idx, elapsed = _do((i, record))
processed += 1
print(
f"[annotate] {name} episode {i}/{n} "
f"(idx={ep_idx}) done in {elapsed:.1f}s",
flush=True,
)
else:
with ThreadPoolExecutor(max_workers=parallelism) as pool:
futures = [pool.submit(_do, (i, r)) for i, r in enumerate(records, 1)]
for fut in as_completed(futures):
i, ep_idx, elapsed = fut.result()
processed += 1
print(
f"[annotate] {name} episode {processed}/{n} "
f"(idx={ep_idx}, submit_order={i}) done in {elapsed:.1f}s",
flush=True,
)
total = _time.time() - t0
print(f"[annotate] phase={name} complete: {processed}/{n} in {total:.1f}s", flush=True)
return PhaseResult(name=name, episodes_processed=processed, episodes_skipped=0)