feat(annotate): pack multiple vllm replicas per GPU via num_gpus

Adds VlmConfig.num_gpus so parallel_servers can exceed the physical
GPU count. Replicas are round-robin-assigned to GPUs (e.g.
parallel_servers=4 + num_gpus=2 → replicas pinned to GPUs 0,1,0,1).
Backward-compatible: num_gpus=0 keeps the existing 1-replica-per-GPU
behavior.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Pepijn
2026-04-29 16:11:51 +02:00
parent bda829af2f
commit ac183f019e
2 changed files with 13 additions and 4 deletions
@@ -108,10 +108,17 @@ class VlmConfig:
     this command (if present) is substituted per-replica."""
     parallel_servers: int = 1
     """When >1, spawn this many independent inference servers (each pinned
-    to one GPU via ``CUDA_VISIBLE_DEVICES`` and listening on
+    to a GPU via ``CUDA_VISIBLE_DEVICES`` and listening on
     ``serve_port + i``) and round-robin client requests across them.
     Useful when DP/TP NCCL setup is broken on the node — single-GPU
-    replicas don't need cross-GPU communication."""
+    replicas don't need cross-GPU communication. When
+    ``parallel_servers > num_gpus``, replicas are round-robin-assigned
+    to GPUs (e.g. 4 replicas on 2 GPUs → 0,1,0,1)."""
+    num_gpus: int = 0
+    """How many physical GPUs are available for round-robin replica
+    placement. ``0`` means ``parallel_servers`` (one GPU per replica,
+    backward-compatible default). Set this to ``2`` with
+    ``parallel_servers=4`` to pack 2 replicas per GPU."""
     client_concurrency: int = 16
     """Maximum number of in-flight chat requests the client issues in
     parallel. vllm batches them internally for free, so bumping this
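
For clarity, here is how the two new fields interact once this hunk lands. A minimal standalone sketch, not the repo's code: the VlmConfig below is trimmed to the fields this diff touches, and the serve_port default of 8000 is assumed purely for illustration.

# Hypothetical sketch: how parallel_servers and num_gpus combine
# under the round-robin rule described in the docstrings above.
from dataclasses import dataclass

@dataclass
class VlmConfig:  # trimmed to the fields shown in this diff
    parallel_servers: int = 1
    num_gpus: int = 0
    serve_port: int = 8000  # assumed default, for illustration only

cfg = VlmConfig(parallel_servers=4, num_gpus=2)
num_gpus = cfg.num_gpus if cfg.num_gpus > 0 else cfg.parallel_servers
for i in range(cfg.parallel_servers):
    print(f"replica {i} -> GPU {i % num_gpus}, port {cfg.serve_port + i}")
# replica 0 -> GPU 0, port 8000
# replica 1 -> GPU 1, port 8001
# replica 2 -> GPU 0, port 8002
# replica 3 -> GPU 1, port 8003
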
@@ -456,10 +456,12 @@ def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]:
f"--uvicorn-log-level warning"
)
num_gpus = config.num_gpus if config.num_gpus > 0 else n
for i in range(n):
port = config.serve_port + i
gpu = i % num_gpus
env = _os.environ.copy()
env["CUDA_VISIBLE_DEVICES"] = str(i)
env["CUDA_VISIBLE_DEVICES"] = str(gpu)
cmd = base_cmd
if "{port}" in cmd:
cmd = cmd.replace("{port}", str(port))
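
A quick property check of the placement rule above (a hypothetical helper, not part of the commit): with n replicas on g GPUs, no GPU ever hosts more than ceil(n / g) replicas, and ports stay unique because they derive from the replica index i, not from gpu.

import math
from collections import Counter

def placements(n: int, g: int, base_port: int = 8000) -> list[tuple[int, int]]:
    # (gpu, port) per replica, mirroring gpu = i % num_gpus above
    return [(i % g, base_port + i) for i in range(n)]

pairs = placements(n=4, g=2)
per_gpu = Counter(gpu for gpu, _ in pairs)
assert max(per_gpu.values()) <= math.ceil(4 / 2)  # at most 2 replicas per GPU
assert len({port for _, port in pairs}) == 4      # no port collisions
print(per_gpu)  # Counter({0: 2, 1: 2})
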
@@ -467,7 +469,7 @@ def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]:
cmd = f"{cmd} --port {port}"
api_base = f"http://localhost:{port}/v1"
api_bases.append(api_base)
print(f"[server-{i}] launching on GPU {i} port {port}: {cmd}", flush=True)
print(f"[server-{i}] launching on GPU {gpu} port {port}: {cmd}", flush=True)
proc = subprocess.Popen(
shlex.split(cmd),
stdout=subprocess.PIPE,
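
The Popen call is truncated above. For context, a self-contained sketch of the same spawn-and-pin pattern; the command string, the stderr redirect, and the env= pass-through are assumptions here, not taken from the repo. CUDA renumbers visible devices, so each child addresses its pinned GPU as device 0.

# Standalone sketch (assumed command and flags, not the repo's exact code):
# pin each replica to one physical GPU before exec'ing the server.
import os
import shlex
import subprocess

def spawn_replica(cmd: str, gpu: int) -> subprocess.Popen:
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = str(gpu)  # child sees exactly one GPU, as device 0
    return subprocess.Popen(
        shlex.split(cmd),
        env=env,  # pinning only takes effect if the copied env is passed through
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )

# e.g. proc = spawn_replica("vllm serve my-model --port 8001", gpu=1)
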