diff --git a/src/lerobot/annotations/steerable_pipeline/config.py b/src/lerobot/annotations/steerable_pipeline/config.py index 297839d06..ff9a11f4e 100644 --- a/src/lerobot/annotations/steerable_pipeline/config.py +++ b/src/lerobot/annotations/steerable_pipeline/config.py @@ -108,10 +108,17 @@ class VlmConfig: this command (if present) is substituted per-replica.""" parallel_servers: int = 1 """When >1, spawn this many independent inference servers (each pinned - to one GPU via ``CUDA_VISIBLE_DEVICES`` and listening on + to a GPU via ``CUDA_VISIBLE_DEVICES`` and listening on ``serve_port + i``) and round-robin client requests across them. Useful when DP/TP NCCL setup is broken on the node — single-GPU - replicas don't need cross-GPU communication.""" + replicas don't need cross-GPU communication. When + ``parallel_servers > num_gpus``, replicas are round-robin-assigned + to GPUs (e.g. 4 replicas on 2 GPUs → 0,1,0,1).""" + num_gpus: int = 0 + """How many physical GPUs are available for round-robin replica + placement. ``0`` means ``parallel_servers`` (one GPU per replica, + backward-compatible default). Set this to ``2`` with + ``parallel_servers=4`` to pack 2 replicas per GPU.""" client_concurrency: int = 16 """Maximum number of in-flight chat requests the client issues in parallel. vllm batches them internally for free, so bumping this diff --git a/src/lerobot/annotations/steerable_pipeline/vlm_client.py b/src/lerobot/annotations/steerable_pipeline/vlm_client.py index d2659321b..fd18110d4 100644 --- a/src/lerobot/annotations/steerable_pipeline/vlm_client.py +++ b/src/lerobot/annotations/steerable_pipeline/vlm_client.py @@ -456,10 +456,12 @@ def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]: f"--uvicorn-log-level warning" ) + num_gpus = config.num_gpus if config.num_gpus > 0 else n for i in range(n): port = config.serve_port + i + gpu = i % num_gpus env = _os.environ.copy() - env["CUDA_VISIBLE_DEVICES"] = str(i) + env["CUDA_VISIBLE_DEVICES"] = str(gpu) cmd = base_cmd if "{port}" in cmd: cmd = cmd.replace("{port}", str(port)) @@ -467,7 +469,7 @@ def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]: cmd = f"{cmd} --port {port}" api_base = f"http://localhost:{port}/v1" api_bases.append(api_base) - print(f"[server-{i}] launching on GPU {i} port {port}: {cmd}", flush=True) + print(f"[server-{i}] launching on GPU {gpu} port {port}: {cmd}", flush=True) proc = subprocess.Popen( shlex.split(cmd), stdout=subprocess.PIPE,