mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-19 01:07:18 +00:00
feat(eval): multi-policy-server support for Docker eval
Add an `eval.policy_servers` parameter (default 1) that spawns N independent policy inference servers on consecutive ports. Containers are assigned to the servers round-robin, enabling parallel GPU inference for small models like SmolVLA (~1.4 GB each).

Usage: `--eval.policy_servers=4 --eval.instance_count=20` → 4 model copies on the GPU, with 20 containers distributed across them.

Made-with: Cursor
This commit is contained in:
@@ -86,6 +86,10 @@ class EvalConfig:
|
||||
# 0-indexed shard id for this process. Users usually leave this at 0.
|
||||
# Additional shards are launched automatically by `lerobot-eval` when instance_count > 1.
|
||||
instance_id: int = 0
|
||||
# Number of policy inference servers to run in parallel (Docker runtime only).
|
||||
# Each server loads a copy of the model and listens on consecutive ports
|
||||
# starting from eval.docker.port. Containers are round-robin assigned.
|
||||
policy_servers: int = 1
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.runtime not in {"local", "docker"}:
|
||||
@@ -96,6 +100,8 @@ class EvalConfig:
|
||||
raise ValueError(
|
||||
f"eval.instance_id must be in [0, {self.instance_count - 1}] (got {self.instance_id})."
|
||||
)
|
||||
if self.policy_servers < 1:
|
||||
raise ValueError("eval.policy_servers must be >= 1.")
|
||||
if self.batch_size > self.n_episodes:
|
||||
raise ValueError(
|
||||
"The eval batch size is greater than the number of eval episodes "
|
||||
|
||||
@@ -23,8 +23,8 @@ server for action chunks.
|
||||
Architecture:
|
||||
host (GPU):
|
||||
1. Load policy + preprocessors from EvalPipelineConfig.
|
||||
2. Start HTTP policy-inference server (one lock — serialises GPU calls).
|
||||
3. Spawn ``instance_count`` Docker containers (one per shard).
|
||||
2. Start ``policy_servers`` HTTP inference servers on consecutive ports.
|
||||
3. Spawn ``instance_count`` Docker containers, round-robin assigned to servers.
|
||||
4. Wait; collect per-task JSON written to the mounted output volume.
|
||||
5. Merge shards → aggregate → write eval_info.json.
|
||||
|
||||
@@ -245,29 +245,38 @@ def run_eval_in_docker(cfg: EvalPipelineConfig) -> None:
|
||||
policy_cfg=cfg.policy,
|
||||
)
|
||||
|
||||
# ── Start HTTP inference server ───────────────────────────────────────
|
||||
port = docker_cfg.port
|
||||
server = _InferenceServer(
|
||||
("0.0.0.0", port), # nosec B104 — only alive for the duration of eval
|
||||
policy=policy,
|
||||
env_preprocessor=env_preprocessor,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
)
|
||||
server_thread = threading.Thread(target=server.serve_forever, daemon=True)
|
||||
server_thread.start()
|
||||
logger.info("Policy inference server running on port %d", port)
|
||||
|
||||
# ── Start HTTP inference server(s) ────────────────────────────────────
|
||||
n_policy_servers = cfg.eval.policy_servers
|
||||
base_port = docker_cfg.port
|
||||
host_ip = _get_host_ip()
|
||||
server_address = f"{host_ip}:{port}"
|
||||
instance_count = cfg.eval.instance_count
|
||||
env_argv = _env_argv()
|
||||
|
||||
# ── Spawn containers ──────────────────────────────────────────────────
|
||||
servers: list[_InferenceServer] = []
|
||||
for s_idx in range(n_policy_servers):
|
||||
port = base_port + s_idx
|
||||
if s_idx > 0:
|
||||
policy = make_policy(cfg=cfg.policy, env_cfg=cfg.env, rename_map=cfg.rename_map)
|
||||
policy.eval()
|
||||
srv = _InferenceServer(
|
||||
("0.0.0.0", port), # nosec B104
|
||||
policy=policy,
|
||||
env_preprocessor=env_preprocessor,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
)
|
||||
t = threading.Thread(target=srv.serve_forever, daemon=True)
|
||||
t.start()
|
||||
servers.append(srv)
|
||||
logger.info("Policy inference server %d/%d running on port %d", s_idx + 1, n_policy_servers, port)
|
||||
|
||||
# ── Spawn containers (round-robin across policy servers) ──────────────
|
||||
container_dirs: list[Path] = []
|
||||
procs: list[subprocess.Popen] = []
|
||||
try:
|
||||
for i in range(instance_count):
|
||||
assigned_port = base_port + (i % n_policy_servers)
|
||||
server_address = f"{host_ip}:{assigned_port}"
|
||||
shard_dir = output_dir / "shards" / str(i)
|
||||
container_dirs.append(shard_dir)
|
||||
proc = _spawn_container(
|
||||
@@ -293,7 +302,8 @@ def run_eval_in_docker(cfg: EvalPipelineConfig) -> None:
|
||||
raise RuntimeError(f"Docker eval containers failed (instance_id, exit_code): {failed}")
|
||||
|
||||
finally:
|
||||
server.shutdown()
|
||||
for srv in servers:
|
||||
srv.shutdown()
|
||||
|
||||
# ── Collect and merge per-task results ───────────────────────────────
|
||||
per_task: list[dict] = []
|
||||
|
||||
Reference in New Issue
Block a user