diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py index d8df0feee..ead4f799a 100644 --- a/src/lerobot/configs/default.py +++ b/src/lerobot/configs/default.py @@ -86,6 +86,10 @@ class EvalConfig: # 0-indexed shard id for this process. Users usually leave this at 0. # Additional shards are launched automatically by `lerobot-eval` when instance_count > 1. instance_id: int = 0 + # Number of policy inference servers to run in parallel (Docker runtime only). + # Each server loads a copy of the model and listens on consecutive ports + # starting from eval.docker.port. Containers are round-robin assigned. + policy_servers: int = 1 def __post_init__(self) -> None: if self.runtime not in {"local", "docker"}: @@ -96,6 +100,8 @@ class EvalConfig: raise ValueError( f"eval.instance_id must be in [0, {self.instance_count - 1}] (got {self.instance_id})." ) + if self.policy_servers < 1: + raise ValueError("eval.policy_servers must be >= 1.") if self.batch_size > self.n_episodes: raise ValueError( "The eval batch size is greater than the number of eval episodes " diff --git a/src/lerobot/envs/docker_runtime.py b/src/lerobot/envs/docker_runtime.py index 59f6879f8..882914d14 100644 --- a/src/lerobot/envs/docker_runtime.py +++ b/src/lerobot/envs/docker_runtime.py @@ -23,8 +23,8 @@ server for action chunks. Architecture: host (GPU): 1. Load policy + preprocessors from EvalPipelineConfig. - 2. Start HTTP policy-inference server (one lock — serialises GPU calls). - 3. Spawn ``instance_count`` Docker containers (one per shard). + 2. Start ``policy_servers`` HTTP inference servers on consecutive ports. + 3. Spawn ``instance_count`` Docker containers, round-robin assigned to servers. 4. Wait; collect per-task JSON written to the mounted output volume. 5. Merge shards → aggregate → write eval_info.json. @@ -245,29 +245,38 @@ def run_eval_in_docker(cfg: EvalPipelineConfig) -> None: policy_cfg=cfg.policy, ) - # ── Start HTTP inference server ─────────────────────────────────────── - port = docker_cfg.port - server = _InferenceServer( - ("0.0.0.0", port), # nosec B104 — only alive for the duration of eval - policy=policy, - env_preprocessor=env_preprocessor, - preprocessor=preprocessor, - postprocessor=postprocessor, - ) - server_thread = threading.Thread(target=server.serve_forever, daemon=True) - server_thread.start() - logger.info("Policy inference server running on port %d", port) - + # ── Start HTTP inference server(s) ──────────────────────────────────── + n_policy_servers = cfg.eval.policy_servers + base_port = docker_cfg.port host_ip = _get_host_ip() - server_address = f"{host_ip}:{port}" instance_count = cfg.eval.instance_count env_argv = _env_argv() - # ── Spawn containers ────────────────────────────────────────────────── + servers: list[_InferenceServer] = [] + for s_idx in range(n_policy_servers): + port = base_port + s_idx + if s_idx > 0: + policy = make_policy(cfg=cfg.policy, env_cfg=cfg.env, rename_map=cfg.rename_map) + policy.eval() + srv = _InferenceServer( + ("0.0.0.0", port), # nosec B104 + policy=policy, + env_preprocessor=env_preprocessor, + preprocessor=preprocessor, + postprocessor=postprocessor, + ) + t = threading.Thread(target=srv.serve_forever, daemon=True) + t.start() + servers.append(srv) + logger.info("Policy inference server %d/%d running on port %d", s_idx + 1, n_policy_servers, port) + + # ── Spawn containers (round-robin across policy servers) ────────────── container_dirs: list[Path] = [] procs: list[subprocess.Popen] = [] try: for i in range(instance_count): + assigned_port = base_port + (i % n_policy_servers) + server_address = f"{host_ip}:{assigned_port}" shard_dir = output_dir / "shards" / str(i) container_dirs.append(shard_dir) proc = _spawn_container( @@ -293,7 +302,8 @@ def run_eval_in_docker(cfg: EvalPipelineConfig) -> None: raise RuntimeError(f"Docker eval containers failed (instance_id, exit_code): {failed}") finally: - server.shutdown() + for srv in servers: + srv.shutdown() # ── Collect and merge per-task results ─────────────────────────────── per_task: list[dict] = []