mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 21:50:03 +00:00
feat(eval): implement docker runtime with HTTP policy inference server
Add docker_runtime.py (host-side) and lerobot_eval_worker.py (container-side) for --eval.runtime=docker. Policy loads once on the host GPU; Docker containers run env-only workers that call back via HTTP for action chunks, maximising GPU utilisation across parallel benchmark tasks. - _InferenceServer: HTTP server wrapping predict_action_chunk with a single lock - run_eval_in_docker: spawns instance_count containers, collects + merges per-task JSON, writes eval_info.json compatible with _aggregate_eval_from_per_task - lerobot-eval-worker CLI: make_env → shard tasks → run episodes → write JSON - EvalDockerConfig: add port field (default 50051) - pyproject.toml: add lerobot-eval-worker entry point Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -247,6 +247,7 @@ lerobot-replay="lerobot.scripts.lerobot_replay:main"
|
|||||||
lerobot-setup-motors="lerobot.scripts.lerobot_setup_motors:main"
|
lerobot-setup-motors="lerobot.scripts.lerobot_setup_motors:main"
|
||||||
lerobot-teleoperate="lerobot.scripts.lerobot_teleoperate:main"
|
lerobot-teleoperate="lerobot.scripts.lerobot_teleoperate:main"
|
||||||
lerobot-eval="lerobot.scripts.lerobot_eval:main"
|
lerobot-eval="lerobot.scripts.lerobot_eval:main"
|
||||||
|
lerobot-eval-worker="lerobot.scripts.lerobot_eval_worker:main"
|
||||||
lerobot-train="lerobot.scripts.lerobot_train:main"
|
lerobot-train="lerobot.scripts.lerobot_train:main"
|
||||||
lerobot-train-tokenizer="lerobot.scripts.lerobot_train_tokenizer:main"
|
lerobot-train-tokenizer="lerobot.scripts.lerobot_train_tokenizer:main"
|
||||||
lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main"
|
lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main"
|
||||||
|
|||||||
@@ -64,6 +64,8 @@ class EvalDockerConfig:
|
|||||||
gpus: str | None = "all"
|
gpus: str | None = "all"
|
||||||
# Docker --shm-size value (increase when using larger eval.batch_size values).
|
# Docker --shm-size value (increase when using larger eval.batch_size values).
|
||||||
shm_size: str = "8g"
|
shm_size: str = "8g"
|
||||||
|
# Port on which the host HTTP policy inference server listens.
|
||||||
|
port: int = 50051
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -0,0 +1,312 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Docker runtime for lerobot-eval.
|
||||||
|
|
||||||
|
The policy stays on the host GPU; gym environments run inside Docker containers.
|
||||||
|
Each container runs `lerobot-eval-worker`, which calls back to a host HTTP inference
|
||||||
|
server for action chunks.
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
host (GPU):
|
||||||
|
1. Load policy + preprocessors from EvalPipelineConfig.
|
||||||
|
2. Start HTTP policy-inference server (one lock — serialises GPU calls).
|
||||||
|
3. Spawn ``instance_count`` Docker containers (one per shard).
|
||||||
|
4. Wait; collect per-task JSON written to the mounted output volume.
|
||||||
|
5. Merge shards → aggregate → write eval_info.json.
|
||||||
|
|
||||||
|
container (env-only worker, no policy; GPUs are passed through only if eval.docker.gpus is set):
|
||||||
|
1. make_env(cfg.env) → shard tasks by (instance_id, instance_count).
|
||||||
|
2. For each task: run n_episodes, POST obs to /predict_chunk, step env.
|
||||||
|
3. Write per-task JSON to /results/worker_{instance_id}.json.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import pickle # nosec B403 — internal serialisation only
|
||||||
|
import platform
|
||||||
|
import subprocess # nosec B404
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.envs.factory import make_env_pre_post_processors
|
||||||
|
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||||
|
from lerobot.utils.utils import get_safe_torch_device
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from lerobot.configs.eval import EvalPipelineConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# HTTP inference server (host side)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _PolicyInferenceHandler(BaseHTTPRequestHandler):
    """Handle ``POST /predict_chunk`` requests for the policy inference server.

    The request body is a pickled ``{"obs_t": <observation dict>}`` payload and
    the response body is the pickled return value of ``server._predict``.

    NOTE(security): ``pickle.loads`` can execute arbitrary code embedded in the
    payload. This endpoint is intended to be reached only by the eval containers
    spawned by this module; never expose the port to untrusted clients.
    """

    server: _InferenceServer

    def do_POST(self) -> None:
        """Decode the observation, run inference under the server lock, and reply.

        Responds with:
            404 for any path other than ``/predict_chunk``;
            400 for a missing/invalid ``Content-Length`` or an undecodable payload;
            500 when inference (or serialising its result) raises;
            200 with the pickled action chunk otherwise.
        """
        if self.path != "/predict_chunk":
            self.send_error(404)
            return

        # ``self.headers[...]`` yields None when the header is absent, so both
        # TypeError (None) and ValueError (non-numeric) must be handled.
        try:
            length = int(self.headers["Content-Length"])
        except (TypeError, ValueError):
            self.send_error(400, "Missing or invalid Content-Length header")
            return
        if length < 0:
            # A negative size would make ``rfile.read`` block until EOF.
            self.send_error(400, "Missing or invalid Content-Length header")
            return

        body = self.rfile.read(length)
        try:
            payload: dict = pickle.loads(body)  # nosec B301 — trusted workers only
            obs_t: dict = payload["obs_t"]
        except Exception:  # boundary: unpickling can raise almost anything
            logging.getLogger(__name__).exception("Could not decode /predict_chunk payload")
            self.send_error(400, "Malformed request payload")
            return

        try:
            # One lock serialises every call into the policy.
            with self.server._lock:
                chunk_np = self.server._predict(obs_t)
            resp = pickle.dumps(chunk_np)  # nosec B301
        except Exception:  # boundary: report the failure instead of dropping the socket
            logging.getLogger(__name__).exception("Policy inference failed")
            self.send_error(500, "Policy inference failed")
            return

        self.send_response(200)
        self.send_header("Content-Type", "application/octet-stream")
        self.send_header("Content-Length", str(len(resp)))
        self.end_headers()
        self.wfile.write(resp)

    def log_message(self, fmt: str, *args: Any) -> None:  # noqa: ANN401
        """Suppress the default per-request access log."""
        pass
|
||||||
|
|
||||||
|
|
||||||
|
class _InferenceServer(HTTPServer):
    """HTTP server that exposes an already-loaded policy to eval workers.

    Holds the policy, the three processing pipelines, the target device and the
    lock that request handlers take before calling :meth:`_predict`.
    """

    def __init__(
        self,
        addr: tuple[str, int],
        policy: Any,
        env_preprocessor: Any,
        preprocessor: Any,
        postprocessor: Any,
    ) -> None:
        """Bind to ``addr`` and remember the policy and its processors."""
        super().__init__(addr, _PolicyInferenceHandler)
        self._lock = threading.Lock()
        self._policy = policy
        self._device = torch.device(str(policy.config.device))
        self._env_preprocessor = env_preprocessor
        self._preprocessor = preprocessor
        self._postprocessor = postprocessor

    def _predict(self, obs_t: dict) -> np.ndarray:
        """Run env + policy preprocessing, inference and postprocessing.

        Returns the first batch element's chunk as a ``(T, A)`` numpy array.
        """
        processed = self._preprocessor(self._env_preprocessor(obs_t))

        # Move tensors to the policy device; leave every other value untouched.
        on_device: dict = {}
        for key, value in processed.items():
            if isinstance(value, torch.Tensor):
                value = value.to(self._device)
            on_device[key] = value

        with torch.no_grad():
            chunk: torch.Tensor = self._policy.predict_action_chunk(on_device)  # (B, T, A)

        # The postprocessor is fed (B*T, A): every timestep is treated as its own
        # batch element. Unpacking the shape also insists on a 3-D chunk.
        n_batch, horizon, n_dims = chunk.shape
        flat = self._postprocessor(chunk.reshape(n_batch * horizon, n_dims))  # (B*T, A)

        # Rows [0, T) belong to the first env — each container uses batch_size=1.
        first_env = flat[:horizon]
        return first_env.cpu().numpy()  # (T, A)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _get_host_ip() -> str:
    """Return the address containers should use to reach the host.

    Every container started by ``_spawn_container`` is launched with
    ``--add-host=host.docker.internal:host-gateway``, so the name
    ``host.docker.internal`` resolves to the host on Linux as well as on
    Docker Desktop (macOS / Windows). Returning the name instead of the literal
    default-bridge gateway ``172.17.0.1`` keeps the workers reachable when the
    Docker bridge uses a non-default subnet.
    """
    return "host.docker.internal"
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_image(cfg: EvalPipelineConfig) -> str:
    """Pick the Docker image for the env containers.

    An explicitly configured ``cfg.eval.docker.image`` wins; otherwise the name
    is derived from the env type as ``lerobot-benchmark-<env.type>``.
    """
    explicit_image = cfg.eval.docker.image
    if not explicit_image:
        return f"lerobot-benchmark-{cfg.env.type}"
    return explicit_image
|
||||||
|
|
||||||
|
|
||||||
|
def _env_argv() -> list[str]:
    """Collect the ``--env.*`` arguments from ``sys.argv`` to forward to the worker.

    Both spellings are forwarded verbatim:

    * ``--env.key=value`` — a single token;
    * ``--env.key value`` — the option token plus the value token that follows
      it (any next token that does not itself start with ``--``).

    Returns:
        The forwarded tokens, in their original command-line order.
    """
    cli_args = sys.argv[1:]
    forwarded: list[str] = []
    for idx, arg in enumerate(cli_args):
        if not arg.startswith("--env."):
            continue
        forwarded.append(arg)
        # Space-separated form: the value lives in the next token; without it
        # the worker would receive an option with no value.
        has_inline_value = "=" in arg
        if not has_inline_value and idx + 1 < len(cli_args):
            candidate = cli_args[idx + 1]
            if not candidate.startswith("--"):
                forwarded.append(candidate)
    return forwarded
|
||||||
|
|
||||||
|
|
||||||
|
def _spawn_container(
    *,
    image: str,
    instance_id: int,
    instance_count: int,
    server_address: str,
    n_episodes: int,
    seed: int,
    output_dir: Path,
    docker_cfg: Any,
    env_argv: list[str],
) -> subprocess.Popen:
    """Start one ``docker run`` process executing ``lerobot-eval-worker``.

    Args:
        image: Docker image to run.
        instance_id: 0-indexed shard id handed to the worker.
        instance_count: Total number of shards handed to the worker.
        server_address: ``host:port`` of the policy inference server.
        n_episodes: Episodes per task, forwarded to the worker.
        seed: Base seed, forwarded to the worker.
        output_dir: Host directory (created if missing) mounted at ``/results``.
        docker_cfg: Object providing ``gpus`` and ``shm_size``.
        env_argv: ``--env.*`` CLI tokens forwarded verbatim to the worker.

    Returns:
        The ``Popen`` handle of the (non-detached) ``docker run`` process.
    """
    # Local import: shlex is only needed to render the command for the log.
    import shlex

    output_dir.mkdir(parents=True, exist_ok=True)
    container_results = "/results"

    cmd: list[str] = ["docker", "run", "--rm"]
    if docker_cfg.gpus:
        cmd.append(f"--gpus={docker_cfg.gpus}")
    cmd += [
        f"--shm-size={docker_cfg.shm_size}",
        "-v",
        f"{output_dir.resolve()}:{container_results}",
        # Allow containers on Linux to resolve host.docker.internal.
        "--add-host=host.docker.internal:host-gateway",
        image,
    ]

    # Everything after the image name is the command run inside the container.
    cmd += [
        "lerobot-eval-worker",
        *env_argv,
        f"--server_address={server_address}",
        f"--n_episodes={n_episodes}",
        f"--seed={seed}",
        f"--instance_id={instance_id}",
        f"--instance_count={instance_count}",
        f"--output_path={container_results}/worker_{instance_id}.json",
    ]

    # shlex.join quotes arguments so the logged line can be pasted into a shell.
    logger.info(
        "Spawning container %d/%d: %s",
        instance_id + 1,
        instance_count,
        shlex.join(cmd),
    )
    return subprocess.Popen(cmd)  # nosec B603 B607
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Public entry point
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def run_eval_in_docker(cfg: EvalPipelineConfig) -> None:
    """Run eval with env in Docker containers and policy on the host GPU.

    Steps: optionally pull the image, load the policy and its processors, start
    the HTTP inference server, spawn ``cfg.eval.instance_count`` worker
    containers, wait for them, merge their per-task JSON and write
    ``eval_info.json`` to ``cfg.output_dir``. Called by
    ``lerobot_eval._run_eval_worker`` when ``eval.runtime == "docker"``.

    Raises:
        RuntimeError: If any container exits with a non-zero code.
        FileNotFoundError: If a container exited cleanly but wrote no result file.
        subprocess.CalledProcessError: If ``docker pull`` fails.
    """
    # Import here to avoid circular import at module level.
    from lerobot.scripts.lerobot_eval import _aggregate_eval_from_per_task

    start_t = time.time()
    output_dir = Path(cfg.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    docker_cfg = cfg.eval.docker

    # Optionally pull the image before starting.
    image = _resolve_image(cfg)
    if docker_cfg.pull:
        logger.info("Pulling Docker image: %s", image)
        subprocess.run(["docker", "pull", image], check=True)  # nosec B603 B607

    # ── Load policy + all preprocessors on the host GPU ──────────────────
    device = get_safe_torch_device(cfg.policy.device, log=True)
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    policy = make_policy(cfg=cfg.policy, env_cfg=cfg.env, rename_map=cfg.rename_map)
    policy.eval()

    preprocessor_overrides: dict = {
        "device_processor": {"device": str(device)},
        "rename_observations_processor": {"rename_map": cfg.rename_map},
    }
    preprocessor, postprocessor = make_pre_post_processors(
        policy_cfg=cfg.policy,
        pretrained_path=cfg.policy.pretrained_path,
        preprocessor_overrides=preprocessor_overrides,
    )
    env_preprocessor, _env_postprocessor = make_env_pre_post_processors(
        env_cfg=cfg.env,
        policy_cfg=cfg.policy,
    )

    # ── Start HTTP inference server ───────────────────────────────────────
    port = docker_cfg.port
    server = _InferenceServer(
        ("0.0.0.0", port),  # nosec B104 — only alive for the duration of eval
        policy=policy,
        env_preprocessor=env_preprocessor,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
    )
    server_thread = threading.Thread(target=server.serve_forever, daemon=True)
    server_thread.start()
    logger.info("Policy inference server running on port %d", port)

    host_ip = _get_host_ip()
    server_address = f"{host_ip}:{port}"
    instance_count = cfg.eval.instance_count
    env_argv = _env_argv()

    # ── Spawn containers ──────────────────────────────────────────────────
    container_dirs: list[Path] = []
    procs: list[subprocess.Popen] = []
    try:
        for i in range(instance_count):
            shard_dir = output_dir / "shards" / str(i)
            container_dirs.append(shard_dir)
            proc = _spawn_container(
                image=image,
                instance_id=i,
                instance_count=instance_count,
                server_address=server_address,
                n_episodes=cfg.eval.n_episodes,
                seed=cfg.seed,
                output_dir=shard_dir,
                docker_cfg=docker_cfg,
                env_argv=env_argv,
            )
            procs.append(proc)

        failed: list[tuple[int, int]] = []
        for i, proc in enumerate(procs):
            rc = proc.wait()
            if rc != 0:
                failed.append((i, rc))
                logger.error("Container %d/%d exited with code %d", i + 1, instance_count, rc)
        if failed:
            raise RuntimeError(f"Docker eval containers failed (instance_id, exit_code): {failed}")

    finally:
        try:
            # On an error or interrupt some `docker run` processes may still be
            # alive; stop them so no containers are left running. On the normal
            # path every process has already been waited on and this is a no-op.
            leftover = [proc for proc in procs if proc.poll() is None]
            for proc in leftover:
                proc.terminate()
            for proc in leftover:
                try:
                    proc.wait(timeout=30)
                except subprocess.TimeoutExpired:
                    proc.kill()
                    proc.wait()
        finally:
            # shutdown() only stops serve_forever(); server_close() releases the
            # listening socket so the port is free for a subsequent run.
            server.shutdown()
            server.server_close()

    # ── Collect and merge per-task results ───────────────────────────────
    per_task: list[dict] = []
    for i, shard_dir in enumerate(container_dirs):
        result_file = shard_dir / f"worker_{i}.json"
        if not result_file.is_file():
            raise FileNotFoundError(
                f"Container {i + 1}/{instance_count} exited successfully but wrote no results to {result_file}"
            )
        with open(result_file, encoding="utf-8") as f:
            shard_data: dict = json.load(f)
        per_task.extend(shard_data.get("per_task", []))

    # Deterministic ordering regardless of how tasks were sharded.
    per_task.sort(key=lambda x: (x["task_group"], x["task_id"]))

    info = _aggregate_eval_from_per_task(per_task, total_eval_s=time.time() - start_t)
    with open(output_dir / "eval_info.json", "w", encoding="utf-8") as f:
        json.dump(info, f, indent=2)

    logger.info("Docker eval complete. Results: %s/eval_info.json", output_dir)
|
||||||
@@ -0,0 +1,192 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Docker eval worker — runs inside a benchmark container.
|
||||||
|
|
||||||
|
Runs gym episodes for a sharded subset of the configured env's tasks, calling
|
||||||
|
a remote HTTP policy inference server (running on the host GPU) for action chunks.
|
||||||
|
|
||||||
|
Usage (normally invoked by docker_runtime.run_eval_in_docker, not directly):
|
||||||
|
lerobot-eval-worker \\
|
||||||
|
--env.type=libero_plus \\
|
||||||
|
--server_address=host.docker.internal:50051 \\
|
||||||
|
--n_episodes=5 \\
|
||||||
|
--seed=1000 \\
|
||||||
|
--instance_id=0 \\
|
||||||
|
--instance_count=2 \\
|
||||||
|
--output_path=/results/worker_0.json
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import pickle # nosec B403 — internal serialisation only
|
||||||
|
import urllib.request
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from lerobot import envs # noqa: F401 — registers all env subclasses
|
||||||
|
from lerobot.configs import parser
|
||||||
|
from lerobot.envs.configs import EnvConfig
|
||||||
|
from lerobot.envs.factory import make_env
|
||||||
|
from lerobot.envs.utils import add_envs_task, preprocess_observation
|
||||||
|
from lerobot.utils.utils import init_logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class EvalWorkerConfig:
    """Settings for one sharded eval worker process."""

    # Environment configuration handed to ``make_env``.
    env: EnvConfig
    # ``host:port`` of the HTTP policy inference server running on the host.
    server_address: str = "host.docker.internal:50051"
    # How many episodes to roll out for every assigned task.
    n_episodes: int = 1
    # Base random seed: episode ``i`` of a task is reset with ``seed + i``.
    seed: int = 0
    # Shard id of this worker (0-indexed).
    instance_id: int = 0
    # Number of shards, i.e. workers, the task list is split across.
    instance_count: int = 1
    # Where, inside the container, the per-task JSON results are written.
    output_path: Path = field(default_factory=lambda: Path("/results/worker.json"))
    # Per-request timeout, in seconds, for calls to the policy server.
    server_timeout: float = 120.0
|
||||||
|
|
||||||
|
|
||||||
|
def _call_server(server_address: str, obs_t: dict, timeout: float) -> np.ndarray:
    """Ask the policy server for an action chunk.

    Pickles ``{"obs_t": obs_t}``, POSTs it to ``/predict_chunk`` on
    ``server_address`` and returns the unpickled response body.
    """
    request = urllib.request.Request(
        f"http://{server_address}/predict_chunk",
        data=pickle.dumps({"obs_t": obs_t}),  # nosec B301
        headers={"Content-Type": "application/octet-stream"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:  # nosec B310
        raw_chunk = response.read()
        return pickle.loads(raw_chunk)  # nosec B301
|
||||||
|
|
||||||
|
|
||||||
|
def _run_episode(cfg: EvalWorkerConfig, env, ep_idx: int) -> tuple[list[float], bool]:
    """Roll out one episode; return (per-step mean rewards, success flag)."""
    # Local import: only the action queue below needs it.
    from collections import deque

    obs, _info = env.reset(seed=[cfg.seed + ep_idx])
    obs_t = preprocess_observation(obs)
    obs_t = add_envs_task(env, obs_t)

    # Actions of the current chunk that are still to be executed; each element
    # is a (1, action_dim) slice.
    action_buffer: deque = deque()
    ep_rewards: list[float] = []
    ep_success = False
    done = np.zeros(1, dtype=bool)

    while not np.all(done):
        if not action_buffer:
            # Buffer exhausted: request a fresh chunk for the latest observation.
            chunk_np = _call_server(cfg.server_address, obs_t, cfg.server_timeout)
            if chunk_np.shape[0] == 0:
                raise RuntimeError("Policy server returned an empty action chunk")
            # chunk_np: (T, action_dim) — split into per-step slices of shape (1, action_dim)
            action_buffer.extend(chunk_np[i : i + 1] for i in range(chunk_np.shape[0]))

        action_np = action_buffer.popleft()  # (1, action_dim)
        obs, reward, terminated, truncated, info = env.step(action_np)

        # Sticky flag: once the env has finished it stays finished.
        done = terminated | truncated | done
        ep_rewards.append(float(np.mean(reward)))

        if "final_info" in info:
            final_info = info["final_info"]
            if isinstance(final_info, dict) and "is_success" in final_info:
                ep_success = bool(final_info["is_success"][0])

        if not np.all(done):
            obs_t = preprocess_observation(obs)
            obs_t = add_envs_task(env, obs_t)

    return ep_rewards, ep_success


def _run_task(cfg: EvalWorkerConfig, task_group, task_id, env) -> dict:
    """Run ``cfg.n_episodes`` episodes of one task and build its result entry."""
    sum_rewards: list[float] = []
    max_rewards: list[float] = []
    successes: list[bool] = []

    for ep_idx in range(cfg.n_episodes):
        ep_rewards, ep_success = _run_episode(cfg, env, ep_idx)
        sum_rewards.append(float(np.sum(ep_rewards)))
        max_rewards.append(float(np.max(ep_rewards)) if ep_rewards else 0.0)
        successes.append(ep_success)
        logger.info(
            "Task %s[%d] ep %d/%d — success=%s",
            task_group,
            task_id,
            ep_idx + 1,
            cfg.n_episodes,
            ep_success,
        )

    return {
        "task_group": task_group,
        "task_id": task_id,
        "metrics": {
            "sum_rewards": sum_rewards,
            "max_rewards": max_rewards,
            "successes": successes,
            "video_paths": [],
        },
    }


def run_worker(cfg: EvalWorkerConfig) -> dict:
    """Run cfg.n_episodes episodes per assigned task. Returns per-task results dict.

    Tasks are the flattened ``(task_group, task_id)`` entries produced by
    ``make_env``; with ``instance_count > 1`` this worker handles those whose
    index satisfies ``index % instance_count == instance_id``.

    Returns:
        ``{"per_task": [...]}`` with one entry per assigned task, each holding
        ``task_group``, ``task_id`` and a ``metrics`` dict (``sum_rewards``,
        ``max_rewards``, ``successes``, ``video_paths``).
    """
    # Build envs: {task_group: {task_id: vec_env}}
    envs_dict = make_env(cfg.env, n_envs=1)

    # Flatten to list of (task_group, task_id, env)
    tasks = [
        (task_group, task_id, vec)
        for task_group, group in envs_dict.items()
        for task_id, vec in group.items()
    ]

    # Shard: this worker handles tasks where index % instance_count == instance_id
    if cfg.instance_count > 1:
        total = len(tasks)
        assigned = []
        for idx, task in enumerate(tasks):
            if idx % cfg.instance_count == cfg.instance_id:
                assigned.append(task)
            else:
                # Built by make_env but owned by another shard: release it now
                # instead of keeping it open for the whole run.
                task[2].close()
        tasks = assigned
        logger.info(
            "Shard %d/%d: %d/%d tasks assigned.",
            cfg.instance_id + 1,
            cfg.instance_count,
            len(tasks),
            total,
        )

    per_task: list[dict] = []
    for task_group, task_id, env in tasks:
        try:
            per_task.append(_run_task(cfg, task_group, task_id, env))
        finally:
            # Close the env even when an episode raises.
            env.close()

    return {"per_task": per_task}
|
||||||
|
|
||||||
|
|
||||||
|
def worker_main(cfg: EvalWorkerConfig) -> None:
    """Run the worker and dump its results as indented JSON to ``cfg.output_path``.

    The parent directory of the output file is created when it does not exist.
    """
    destination = Path(cfg.output_path)
    serialized = json.dumps(run_worker(cfg), indent=2)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_text(serialized)
    logger.info("Worker %d wrote results to %s", cfg.instance_id, destination)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """CLI entry point: set up logging, parse the worker config, run the worker."""
    init_logging()
    worker_cfg = parser.parse(EvalWorkerConfig)
    worker_main(worker_cfg)


if __name__ == "__main__":
    main()
|
||||||
Reference in New Issue
Block a user