From 47f2ea17bb6ceecbad53a89f5662c8e805155248 Mon Sep 17 00:00:00 2001
From: Pepijn
Date: Tue, 28 Apr 2026 23:06:20 +0200
Subject: [PATCH] feat(annotate): parallel_servers spawns N independent vllm
 replicas

Adds --vlm.parallel_servers=N. Spawns N independent vllm processes (each
pinned to GPU i via CUDA_VISIBLE_DEVICES, listening on serve_port+i) and
round-robins requests across them. Sidesteps DP/TP NCCL setup failures on
nodes with restricted P2P/SHM.

Default serve_command for parallel mode: vllm serve --tensor-parallel-size 1
--max-model-len 32768 --uvicorn-log-level warning. Override via
--vlm.serve_command (use {port} placeholder).

Co-Authored-By: Claude Opus 4.7 (1M context)
---
 .../annotations/steerable_pipeline/config.py  |  11 +-
 .../steerable_pipeline/vlm_client.py          | 128 +++++++++++++++++-
 2 files changed, 135 insertions(+), 4 deletions(-)

diff --git a/src/lerobot/annotations/steerable_pipeline/config.py b/src/lerobot/annotations/steerable_pipeline/config.py
index b6c463ea6..0aa515174 100644
--- a/src/lerobot/annotations/steerable_pipeline/config.py
+++ b/src/lerobot/annotations/steerable_pipeline/config.py
@@ -111,7 +111,16 @@ class VlmConfig:
     """Port the auto-spawned server binds to. Sets ``api_base`` automatically."""
     serve_command: str | None = None
     """Override the auto-serve command (full shell command). When ``None``,
-    we run ``transformers serve --port --continuous-batching``."""
+    we run ``transformers serve --port --continuous-batching``.
+
+    When ``parallel_servers > 1``, the literal ``{port}`` placeholder in
+    this command (if present) is substituted per-replica."""
+    parallel_servers: int = 1
+    """When >1, spawn this many independent inference servers (each pinned
+    to one GPU via ``CUDA_VISIBLE_DEVICES`` and listening on
+    ``serve_port + i``) and round-robin client requests across them.
+    Useful when DP/TP NCCL setup is broken on the node — single-GPU
+    replicas don't need cross-GPU communication."""
     serve_ready_timeout_s: float = 600.0
     """Max seconds to wait for the server to start serving requests."""
     max_new_tokens: int = 512
diff --git a/src/lerobot/annotations/steerable_pipeline/vlm_client.py b/src/lerobot/annotations/steerable_pipeline/vlm_client.py
index a06631c2f..9151c081c 100644
--- a/src/lerobot/annotations/steerable_pipeline/vlm_client.py
+++ b/src/lerobot/annotations/steerable_pipeline/vlm_client.py
@@ -333,6 +333,7 @@ def _make_openai_client(config: VlmConfig) -> VlmClient:
     api_base = config.api_base
     api_key = config.api_key
     auto_serve = config.auto_serve
+    api_bases: list[str] = []
 
     if config.use_hf_inference_providers:
         api_base = "https://router.huggingface.co/v1"
@@ -363,14 +364,24 @@
         flush=True,
     )
     if auto_serve:
-        if _server_is_up(api_base):
+        if config.parallel_servers > 1:
+            print(
+                f"[lerobot-annotate] spawning {config.parallel_servers} parallel servers",
+                flush=True,
+            )
+            api_bases = _spawn_parallel_inference_servers(config)
+        elif _server_is_up(api_base):
             print(f"[lerobot-annotate] reusing server already up at {api_base}", flush=True)
         else:
             print("[lerobot-annotate] no server reachable; spawning one", flush=True)
             api_base = _spawn_inference_server(config)
+            api_bases = [api_base]
             print(f"[lerobot-annotate] server ready at {api_base}", flush=True)
 
-    client = OpenAI(base_url=api_base, api_key=api_key)
+    clients = [OpenAI(base_url=base, api_key=api_key) for base in (api_bases or [api_base])]
+    client = clients[0]
+    # round-robin counter for parallel mode
+    rr_counter = {"i": 0}
 
     # ``mm_processor_kwargs`` is a vllm-specific extra; transformers serve
     # rejects it with HTTP 422. Send it only when explicitly opted in via
@@ -395,13 +406,124 @@
             kwargs["extra_body"] = {
                 "mm_processor_kwargs": {**mm_kwargs, "do_sample_frames": True}
             }
-            response = client.chat.completions.create(**kwargs)
+            chosen = clients[rr_counter["i"] % len(clients)]
+            rr_counter["i"] += 1
+            response = chosen.chat.completions.create(**kwargs)
             outs.append(response.choices[0].message.content or "")
         return outs
 
     return _GenericTextClient(_gen, config)
 
 
+def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]:
+    """Spawn ``config.parallel_servers`` independent vllm replicas.
+
+    Each replica:
+      - is pinned to a single GPU via ``CUDA_VISIBLE_DEVICES``
+      - listens on ``serve_port + i``
+      - is shut down at exit via an atexit hook, mirroring the single-server path
+
+    Returns the list of ``api_base`` URLs the client should round-robin
+    across.
+ """ + import atexit # noqa: PLC0415 + import os as _os # noqa: PLC0415 + import shlex # noqa: PLC0415 + import signal # noqa: PLC0415 + import subprocess # noqa: PLC0415 + import sys # noqa: PLC0415 + import threading # noqa: PLC0415 + import time # noqa: PLC0415 + + n = config.parallel_servers + api_bases: list[str] = [] + procs: list[subprocess.Popen] = [] + ready_events: list[threading.Event] = [] + ready_markers = ("Uvicorn running", "Application startup complete") + + base_cmd = config.serve_command or ( + f"vllm serve {shlex.quote(config.model_id)} " + f"--tensor-parallel-size 1 " + f"--max-model-len {config.max_model_len or 32768} " + f"--uvicorn-log-level warning" + ) + + for i in range(n): + port = config.serve_port + i + env = _os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = str(i) + cmd = base_cmd + if "{port}" in cmd: + cmd = cmd.replace("{port}", str(port)) + else: + cmd = f"{cmd} --port {port}" + api_base = f"http://localhost:{port}/v1" + api_bases.append(api_base) + print(f"[server-{i}] launching on GPU {i} port {port}: {cmd}", flush=True) + proc = subprocess.Popen( + shlex.split(cmd), + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + env=env, + ) + procs.append(proc) + ready = threading.Event() + ready_events.append(ready) + + def _stream(idx: int, p: subprocess.Popen, ev: threading.Event) -> None: + assert p.stdout is not None + buf = "" + tag_emitted = False + while True: + ch = p.stdout.read(1) + if ch == "": + return + if not tag_emitted: + sys.stdout.write(f"[server-{idx}] ") + tag_emitted = True + sys.stdout.write(ch) + sys.stdout.flush() + buf += ch + if ch in ("\n", "\r"): + if any(m in buf for m in ready_markers): + ev.set() + buf = "" + tag_emitted = False + + threading.Thread(target=_stream, args=(i, proc, ready), daemon=True).start() + + def _shutdown() -> None: + for i, p in enumerate(procs): + if p.poll() is None: + print(f"[server-{i}] stopping pid={p.pid}", flush=True) + p.send_signal(signal.SIGINT) + for p in procs: + try: + p.wait(timeout=15) + except subprocess.TimeoutExpired: + p.kill() + p.wait(timeout=5) + + atexit.register(_shutdown) + + deadline = time.monotonic() + config.serve_ready_timeout_s + while any(not ev.is_set() for ev in ready_events) and time.monotonic() < deadline: + for i, p in enumerate(procs): + if p.poll() is not None: + raise RuntimeError( + f"[server-{i}] inference server exited unexpectedly with rc={p.returncode}" + ) + time.sleep(2) + if any(not ev.is_set() for ev in ready_events): + raise RuntimeError( + f"[server] not all replicas became ready within {config.serve_ready_timeout_s}s" + ) + print(f"[lerobot-annotate] all {n} servers ready: {api_bases}", flush=True) + return api_bases + + def _server_is_up(api_base: str) -> bool: """Return True if ``api_base/models`` answers 200 within 2 seconds.""" import urllib.request # noqa: PLC0415