feat(eval): multi-policy-server support for Docker eval

Add an eval.policy_servers parameter (default 1) that spawns N independent
policy inference servers on consecutive ports. Containers are assigned to
servers round-robin, enabling parallel GPU inference for small models
like SmolVLA (~1.4GB each).

Usage: --eval.policy_servers=4 --eval.instance_count=20
  → 4 model copies on GPU, 20 containers distributed across them.
Made-with: Cursor
This commit is contained in:
Pepijn Kooijmans
2026-03-24 20:28:58 +01:00
parent b97ea8999f
commit b3c2592ace
2 changed files with 34 additions and 18 deletions
+6
View File
@@ -86,6 +86,10 @@ class EvalConfig:
# 0-indexed shard id for this process. Users usually leave this at 0.
# Additional shards are launched automatically by `lerobot-eval` when instance_count > 1.
instance_id: int = 0
# Number of policy inference servers to run in parallel (Docker runtime only).
# Each server loads its own copy of the model; server k (0-indexed) listens on
# port eval.docker.port + k. Containers are assigned to servers round-robin.
policy_servers: int = 1
def __post_init__(self) -> None:
if self.runtime not in {"local", "docker"}:
@@ -96,6 +100,8 @@ class EvalConfig:
raise ValueError(
f"eval.instance_id must be in [0, {self.instance_count - 1}] (got {self.instance_id})."
)
if self.policy_servers < 1:
raise ValueError("eval.policy_servers must be >= 1.")
if self.batch_size > self.n_episodes:
raise ValueError(
"The eval batch size is greater than the number of eval episodes "
+28 -18
View File
@@ -23,8 +23,8 @@ server for action chunks.
Architecture:
host (GPU):
1. Load policy + preprocessors from EvalPipelineConfig.
2. Start HTTP policy-inference server (one lock — serialises GPU calls).
3. Spawn ``instance_count`` Docker containers (one per shard).
2. Start ``policy_servers`` HTTP inference servers on consecutive ports.
3. Spawn ``instance_count`` Docker containers, round-robin assigned to servers.
4. Wait; collect per-task JSON written to the mounted output volume.
5. Merge shards → aggregate → write eval_info.json.
@@ -245,29 +245,38 @@ def run_eval_in_docker(cfg: EvalPipelineConfig) -> None:
policy_cfg=cfg.policy,
)
# ── Start HTTP inference server ───────────────────────────────────────
port = docker_cfg.port
server = _InferenceServer(
("0.0.0.0", port), # nosec B104 — only alive for the duration of eval
policy=policy,
env_preprocessor=env_preprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
)
server_thread = threading.Thread(target=server.serve_forever, daemon=True)
server_thread.start()
logger.info("Policy inference server running on port %d", port)
# ── Start HTTP inference server(s) ────────────────────────────────────
n_policy_servers = cfg.eval.policy_servers
base_port = docker_cfg.port
host_ip = _get_host_ip()
server_address = f"{host_ip}:{port}"
instance_count = cfg.eval.instance_count
env_argv = _env_argv()
# ── Spawn containers ──────────────────────────────────────────────────
servers: list[_InferenceServer] = []
for s_idx in range(n_policy_servers):
port = base_port + s_idx
if s_idx > 0:
policy = make_policy(cfg=cfg.policy, env_cfg=cfg.env, rename_map=cfg.rename_map)
policy.eval()
srv = _InferenceServer(
("0.0.0.0", port), # nosec B104
policy=policy,
env_preprocessor=env_preprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
)
t = threading.Thread(target=srv.serve_forever, daemon=True)
t.start()
servers.append(srv)
logger.info("Policy inference server %d/%d running on port %d", s_idx + 1, n_policy_servers, port)
# ── Spawn containers (round-robin across policy servers) ──────────────
container_dirs: list[Path] = []
procs: list[subprocess.Popen] = []
try:
for i in range(instance_count):
assigned_port = base_port + (i % n_policy_servers)
server_address = f"{host_ip}:{assigned_port}"
shard_dir = output_dir / "shards" / str(i)
container_dirs.append(shard_dir)
proc = _spawn_container(
@@ -293,7 +302,8 @@ def run_eval_in_docker(cfg: EvalPipelineConfig) -> None:
raise RuntimeError(f"Docker eval containers failed (instance_id, exit_code): {failed}")
finally:
server.shutdown()
for srv in servers:
srv.shutdown()
# ── Collect and merge per-task results ───────────────────────────────
per_task: list[dict] = []