Compare commits

...

2 Commits

Author SHA1 Message Date
Francesco Capuano 575b4b70aa wip(docker) 2025-07-10 18:47:28 +02:00
Francesco Capuano b6eb651bab add healthy route 2025-07-10 17:46:26 +02:00
4 changed files with 145 additions and 1 deletions
+32
View File
@@ -0,0 +1,32 @@
# Base image: official slim Python image for the requested version.
ARG PYTHON_VERSION=3.10
FROM python:${PYTHON_VERSION}-slim

# An ARG declared before FROM is out of scope after it; re-declare it so the
# RUN step below can still expand ${PYTHON_VERSION}.
ARG PYTHON_VERSION
ENV DEBIAN_FRONTEND=noninteractive
ENV MUJOCO_GL="egl"
# Put the virtualenv first on PATH so `python` / `pip` resolve to it.
ENV PATH="/opt/venv/bin:$PATH"

# Install system dependencies and create the virtualenv in a single layer.
# NOTE(review): the symlink targets /usr/bin/python${PYTHON_VERSION}; confirm
# that path exists in the base image, otherwise the link is dangling.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential cmake git \
    libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
    speech-dispatcher libgeos-dev \
    && ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python \
    && python -m venv /opt/venv \
    && apt-get clean && rm -rf /var/lib/apt/lists/* \
    && echo "source /opt/venv/bin/activate" >> /root/.bashrc

# Copy the build context (the repository) into the image and install LeRobot
# with the async + smolvla extras, using the CPU-only PyTorch wheel index.
COPY . /lerobot
WORKDIR /lerobot
RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
    && /opt/venv/bin/pip install --no-cache-dir ".[async, smolvla]" \
    --extra-index-url https://download.pytorch.org/whl/cpu

# 8080: policy server (see --port below); 8081: HTTP health check.
EXPOSE 8080 8081

# Only the last CMD in a Dockerfile takes effect, so exactly one is declared:
# start the policy server, listening on all interfaces.
CMD ["python", "-m", "lerobot.scripts.server.policy_server", "--host=0.0.0.0", "--port=8080"]
+4
View File
@@ -27,3 +27,7 @@ SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "pi0", "tdmpc", "vqbet"]
# TODO: Add all other robots
SUPPORTED_ROBOTS = [
    "so100_follower",
    "so101_follower",
]

"""Networking support"""
# TCP port the HTTP health-check server listens on.
HEALTH_CHECK_PORT = 8081
# Bind address for the health-check server: all interfaces.
HEALTH_SERVER_HOST = "0.0.0.0"  # nosec
+79
View File
@@ -13,11 +13,13 @@
# limitations under the License.
import io
import json
import logging
import logging.handlers
import os
import time
from dataclasses import dataclass
from http.server import BaseHTTPRequestHandler
from pathlib import Path
from threading import Event
from typing import Any
@@ -206,6 +208,83 @@ def get_logger(name: str, log_to_file: bool = True) -> logging.Logger:
return logging.getLogger(name)
def create_health_handler(policy_server):
    """Build a request-handler factory bound to ``policy_server``.

    The returned callable forwards every positional and keyword argument it
    receives to ``HealthHandler``, passing ``policy_server`` as the first
    argument.
    """

    def make_handler(*handler_args, **handler_kwargs):
        bound_handler = HealthHandler(policy_server, *handler_args, **handler_kwargs)
        return bound_handler

    return make_handler
class HealthHandler(BaseHTTPRequestHandler):
    """HTTP handler serving health and info endpoints for a policy server.

    Routes (GET only):
        ``/health``  JSON health report; 200 when healthy, 503 when not,
                     500 if the report itself could not be built.
        ``/``        JSON description of the service and its endpoints.
        otherwise    404.
    """

    def __init__(self, policy_server, *args, **kwargs):
        # Assign before calling the base constructor: the socketserver request
        # handler base class processes the request from inside __init__.
        self.policy_server = policy_server
        super().__init__(*args, **kwargs)

    def do_GET(self):  # noqa: N802
        """Dispatch a GET request to the matching endpoint."""
        # Drop any query string so probes like "/health?probe=1" still match.
        route = self.path.split("?", 1)[0]
        if route == "/health":
            self.send_health_response()
        elif route == "/":
            self.send_info_response()
        else:
            self.send_error(404, "Not Found")

    def _send_json(self, status_code, body):
        """Write a complete JSON response.

        Args:
            status_code: HTTP status code to send.
            body: Already-encoded JSON payload (bytes).
        """
        self.send_response(status_code)
        self.send_header("Content-type", "application/json")
        # Tell the client exactly where the body ends.
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def send_health_response(self):
        """Send the health report for the attached policy server.

        The server counts as healthy when it has a non-None
        ``_running_event`` attribute. ``server_running`` and
        ``policy_loaded`` fall back to ``False`` when the corresponding
        attribute is missing.
        """
        # Build and serialize the report before touching the socket, so a
        # failure here can never follow a partially written response.
        try:
            server = self.policy_server
            is_healthy = getattr(server, "_running_event", None) is not None
            status_code = 200 if is_healthy else 503
            report = {
                "status": "healthy" if is_healthy else "unhealthy",
                "timestamp": time.time(),
                "server_running": getattr(server, "running", False),
                "policy_loaded": getattr(server, "policy", None) is not None,
            }
            body = json.dumps(report).encode()
        except Exception as e:
            status_code = 500
            body = json.dumps({"status": "error", "message": str(e)}).encode()

        self._send_json(status_code, body)

    def send_info_response(self):
        """Send basic service information.

        ``grpc_port`` is read from ``policy_server.config.port`` and is
        reported as ``null`` when that attribute chain is unavailable.
        """
        config = getattr(self.policy_server, "config", None)
        response = {
            "service": "lerobot-policy-server",
            "version": "1.0.0",
            "endpoints": {"health": "/health", "grpc_port": getattr(config, "port", None)},
        }
        self._send_json(200, json.dumps(response).encode())

    def log_message(self, format, *args):
        """Route access-log lines to the policy server's logger at DEBUG level.

        Overrides the base-class implementation (which writes to stderr);
        the parameter name ``format`` is kept to match that signature. Lines
        are dropped when the policy server has no usable ``logger``.
        """
        logger = getattr(self.policy_server, "logger", None)
        if logger is not None:
            logger.debug("HTTP: %s", format % args)
@dataclass
class TimedData:
"""A data object with timestamp and timestep information.
+30 -1
View File
@@ -30,6 +30,7 @@ import threading
import time
from concurrent import futures
from dataclasses import asdict
from http.server import HTTPServer
from pprint import pformat
from queue import Empty, Queue
@@ -39,13 +40,14 @@ import torch
from lerobot.policies.factory import get_policy_class
from lerobot.scripts.server.configs import PolicyServerConfig
from lerobot.scripts.server.constants import SUPPORTED_POLICIES
from lerobot.scripts.server.constants import HEALTH_CHECK_PORT, HEALTH_SERVER_HOST, SUPPORTED_POLICIES
from lerobot.scripts.server.helpers import (
FPSTracker,
Observation,
RemotePolicyConfig,
TimedAction,
TimedObservation,
create_health_handler,
get_logger,
observations_similar,
raw_observation_to_observation,
@@ -82,6 +84,30 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
self.actions_per_chunk = None
self.policy = None
# HTTP health server
self.http_server = None
self.http_thread = None
def start_health_server(self):
    """Start the HTTP health-check server in a background daemon thread.

    The server binds to ``HEALTH_SERVER_HOST:HEALTH_CHECK_PORT``. Startup is
    best-effort: any failure is logged and leaves ``http_server`` and
    ``http_thread`` as ``None``. Calling this while a health server is already
    registered is a no-op.
    """
    if self.http_server is not None:
        self.logger.warning("Health server already started; ignoring start request")
        return

    server = None
    try:
        server = HTTPServer((HEALTH_SERVER_HOST, HEALTH_CHECK_PORT), create_health_handler(self))
        thread = threading.Thread(target=server.serve_forever, daemon=True)
        thread.start()
    except Exception as e:
        # Release the bound socket if the thread could not be started: a server
        # that is registered but never serving would make a later shutdown()
        # wait forever.
        if server is not None:
            server.server_close()
        self.logger.error(f"Failed to start health server: {e}")
        return

    # Publish the handles only once the server is actually serving.
    self.http_server = server
    self.http_thread = thread
    self.logger.info(f"Health server started on port {HEALTH_CHECK_PORT}")
def stop_health_server(self):
    """Stop the HTTP health server, if one is registered.

    Safe to call repeatedly: the server and thread handles are cleared on the
    first call, so later calls do nothing and log nothing.
    """
    server, thread = self.http_server, self.http_thread
    self.http_server = None
    self.http_thread = None
    if server is None and thread is None:
        return

    # shutdown() blocks until serve_forever() exits, so only call it when the
    # serving thread is actually running.
    serving = thread is not None and thread.is_alive()
    if server is not None:
        if serving:
            server.shutdown()
        server.server_close()
    if serving:
        thread.join(timeout=5)
    self.logger.info("Health server stopped")
@property
def running(self):
    """Report whether ``_running_event`` is currently set."""
    event = self._running_event
    return event.is_set()
@@ -372,6 +398,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
"""Stop the server"""
self._reset_server()
self.logger.info("Server stopping...")
self.stop_health_server()
@draccus.wrap()
@@ -385,6 +412,7 @@ def serve(cfg: PolicyServerConfig):
# Create the server instance first
policy_server = PolicyServer(cfg)
policy_server.start_health_server()
# Setup and start gRPC server
server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
@@ -397,6 +425,7 @@ def serve(cfg: PolicyServerConfig):
server.wait_for_termination()
policy_server.logger.info("Server terminated")
policy_server.stop_health_server()
if __name__ == "__main__":