mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-14 08:09:45 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 575b4b70aa | |||
| b6eb651bab |
+32
@@ -0,0 +1,32 @@
|
||||
# Configure image
ARG PYTHON_VERSION=3.10
FROM python:${PYTHON_VERSION}-slim

# Configure environment variables
# An ARG declared before FROM is out of scope after it, so re-declare it here
# to make ${PYTHON_VERSION} usable in the RUN instruction below.
ARG PYTHON_VERSION
ENV DEBIAN_FRONTEND=noninteractive
ENV MUJOCO_GL="egl"
ENV PATH="/opt/venv/bin:$PATH"

# Install dependencies and set up Python in a single layer
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential cmake git \
    libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
    speech-dispatcher libgeos-dev \
    && ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python \
    && python -m venv /opt/venv \
    && apt-get clean && rm -rf /var/lib/apt/lists/* \
    && echo "source /opt/venv/bin/activate" >> /root/.bashrc

# Copy the repository into the image and install LeRobot in a single layer
COPY . /lerobot
WORKDIR /lerobot
RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
    && /opt/venv/bin/pip install --no-cache-dir ".[async, smolvla]" \
    --extra-index-url https://download.pytorch.org/whl/cpu

# 8080: gRPC policy server, 8081: HTTP health-check endpoint
EXPOSE 8080 8081

# Start the policy server by default. A Dockerfile honours only its last CMD,
# so this is the single CMD; use `docker run ... /bin/bash` for a shell.
CMD ["python", "-m", "lerobot.scripts.server.policy_server", "--host=0.0.0.0", "--port=8080"]
|
||||
@@ -27,3 +27,7 @@ SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "pi0", "tdmpc", "vqbet"]
|
||||
|
||||
# TODO: Add all other robots
SUPPORTED_ROBOTS = [
    "so100_follower",
    "so101_follower",
]

"""Networking support"""
# Port the policy server's HTTP health-check endpoint binds to.
HEALTH_CHECK_PORT = 8081
# Bind address for the health endpoint: all interfaces (hence the `nosec` marker).
HEALTH_SERVER_HOST = "0.0.0.0"  # nosec
|
||||
|
||||
@@ -13,11 +13,13 @@
|
||||
# limitations under the License.
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from http.server import BaseHTTPRequestHandler
|
||||
from pathlib import Path
|
||||
from threading import Event
|
||||
from typing import Any
|
||||
@@ -206,6 +208,83 @@ def get_logger(name: str, log_to_file: bool = True) -> logging.Logger:
|
||||
return logging.getLogger(name)
|
||||
|
||||
|
||||
def create_health_handler(policy_server):
    """Return a handler factory that binds ``policy_server`` to ``HealthHandler``.

    The returned callable accepts whatever positional and keyword arguments the
    HTTP server passes when instantiating a request handler and forwards them
    to ``HealthHandler`` after prepending ``policy_server``.
    """

    def build_handler(*handler_args, **handler_kwargs):
        bound_handler = HealthHandler(policy_server, *handler_args, **handler_kwargs)
        return bound_handler

    return build_handler
|
||||
|
||||
|
||||
class HealthHandler(BaseHTTPRequestHandler):
    """HTTP handler exposing health and info endpoints for a policy server.

    Routes (the query string of the request target is ignored when routing):
        ``GET /health``: JSON health report. 200 when healthy, 503 when
            unhealthy, 500 when the report itself could not be built.
        ``GET /``: JSON description of the service and its endpoints.
        Anything else: 404.
    """

    def __init__(self, policy_server, *args, **kwargs):
        # The base-class constructor handles the request before returning, so
        # the server reference must be stored before calling it.
        self.policy_server = policy_server
        super().__init__(*args, **kwargs)

    def do_GET(self):  # noqa: N802
        """Dispatch GET requests to the health or info endpoint."""
        # Probes commonly append a query string (e.g. ``/health?probe=1``);
        # route on the path component only.
        route = self.path.partition("?")[0]
        if route == "/health":
            self.send_health_response()
        elif route == "/":
            self.send_info_response()
        else:
            self.send_error(404, "Not Found")

    def send_health_response(self):
        """Send the health report as JSON.

        Only building and serialising the report is guarded: a failure there
        yields a 500 JSON error. Socket errors while writing are not caught,
        so a half-written response is never followed by a second status line.
        """
        try:
            status_code, report = self._build_health_report()
            body = json.dumps(report).encode()
        except Exception as e:  # boundary: the probe must always get an answer
            status_code = 500
            body = json.dumps({"status": "error", "message": str(e)}).encode()

        self._send_json(status_code, body)

    def _build_health_report(self):
        """Return ``(status_code, report_dict)`` describing the policy server."""
        server = self.policy_server
        # Healthy means the server object carries a running event; whether the
        # event is currently set is reported separately as "server_running".
        is_healthy = getattr(server, "_running_event", None) is not None
        report = {
            "status": "healthy" if is_healthy else "unhealthy",
            "timestamp": time.time(),
            "server_running": getattr(server, "running", False),
            "policy_loaded": getattr(server, "policy", None) is not None,
        }
        return (200 if is_healthy else 503), report

    def send_info_response(self):
        """Send basic service information as JSON."""
        # Guarded like the health report: a server without a config reports a
        # null gRPC port instead of dropping the connection.
        config = getattr(self.policy_server, "config", None)
        response = {
            "service": "lerobot-policy-server",
            "version": "1.0.0",
            "endpoints": {"health": "/health", "grpc_port": getattr(config, "port", None)},
        }
        self._send_json(200, json.dumps(response).encode())

    def _send_json(self, status_code, body):
        """Write a complete JSON response with the given status and encoded body."""
        self.send_response(status_code)
        self.send_header("Content-type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, format, *args):
        """Route access logs to the policy server's logger instead of stderr.

        The parameter name ``format`` is dictated by the base-class signature.
        Messages are dropped when the policy server has no logger.
        """
        logger = getattr(self.policy_server, "logger", None)
        if logger is not None:
            logger.debug(f"HTTP: {format % args}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimedData:
|
||||
"""A data object with timestamp and timestep information.
|
||||
|
||||
@@ -30,6 +30,7 @@ import threading
|
||||
import time
|
||||
from concurrent import futures
|
||||
from dataclasses import asdict
|
||||
from http.server import HTTPServer
|
||||
from pprint import pformat
|
||||
from queue import Empty, Queue
|
||||
|
||||
@@ -39,13 +40,14 @@ import torch
|
||||
|
||||
from lerobot.policies.factory import get_policy_class
|
||||
from lerobot.scripts.server.configs import PolicyServerConfig
|
||||
from lerobot.scripts.server.constants import SUPPORTED_POLICIES
|
||||
from lerobot.scripts.server.constants import HEALTH_CHECK_PORT, HEALTH_SERVER_HOST, SUPPORTED_POLICIES
|
||||
from lerobot.scripts.server.helpers import (
|
||||
FPSTracker,
|
||||
Observation,
|
||||
RemotePolicyConfig,
|
||||
TimedAction,
|
||||
TimedObservation,
|
||||
create_health_handler,
|
||||
get_logger,
|
||||
observations_similar,
|
||||
raw_observation_to_observation,
|
||||
@@ -82,6 +84,30 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
|
||||
self.actions_per_chunk = None
|
||||
self.policy = None
|
||||
|
||||
# HTTP health server
|
||||
self.http_server = None
|
||||
self.http_thread = None
|
||||
|
||||
def start_health_server(self):
    """Start the HTTP health-check server in a daemon thread.

    The server binds to ``(HEALTH_SERVER_HOST, HEALTH_CHECK_PORT)``. Start-up
    is best-effort: any failure is logged and swallowed so that the policy
    server can still run without its health endpoint. Calling this while the
    health thread is already alive is a no-op.
    """
    if self.http_thread is not None and self.http_thread.is_alive():
        # A second bind on the same port would fail; keep the running server.
        self.logger.warning("Health server already running; ignoring start request")
        return

    server = None
    try:
        health_handler = create_health_handler(self)
        server = HTTPServer((HEALTH_SERVER_HOST, HEALTH_CHECK_PORT), health_handler)
        thread = threading.Thread(target=server.serve_forever, daemon=True)
        thread.start()
    except Exception as e:  # deliberate best-effort boundary
        if server is not None:
            # Bind succeeded but the thread did not start: release the socket.
            server.server_close()
        self.logger.error(f"Failed to start health server: {e}")
        return

    # Publish the handles only once the server is actually serving, so that
    # stop_health_server never sees a half-started server.
    self.http_server = server
    self.http_thread = thread
    self.logger.info(f"Health server started on port {HEALTH_CHECK_PORT}")
|
||||
|
||||
def stop_health_server(self):
    """Stop the HTTP health server, if one was started.

    Safe to call repeatedly: the server and thread references are cleared on
    the first call, so later calls return without logging or touching the
    socket again.
    """
    server, thread = self.http_server, self.http_thread
    if server is None:
        return

    # Clear the handles up front so a repeated call is a no-op.
    self.http_server = None
    self.http_thread = None

    # shutdown() blocks until serve_forever() exits, and join() raises on a
    # thread that was never started, so both are only valid for a live thread.
    serving = thread is not None and thread.is_alive()
    if serving:
        server.shutdown()
    server.server_close()
    if serving:
        thread.join(timeout=5)
    self.logger.info("Health server stopped")
|
||||
|
||||
@property
def running(self):
    """Report whether the server's running event is currently set."""
    running_event = self._running_event
    return running_event.is_set()
|
||||
@@ -372,6 +398,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
|
||||
"""Stop the server"""
|
||||
self._reset_server()
|
||||
self.logger.info("Server stopping...")
|
||||
self.stop_health_server()
|
||||
|
||||
|
||||
@draccus.wrap()
|
||||
@@ -385,6 +412,7 @@ def serve(cfg: PolicyServerConfig):
|
||||
|
||||
# Create the server instance first
|
||||
policy_server = PolicyServer(cfg)
|
||||
policy_server.start_health_server()
|
||||
|
||||
# Setup and start gRPC server
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
|
||||
@@ -397,6 +425,7 @@ def serve(cfg: PolicyServerConfig):
|
||||
server.wait_for_termination()
|
||||
|
||||
policy_server.logger.info("Server terminated")
|
||||
policy_server.stop_health_server()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user