mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
Compare commits
14 Commits
a81e23b0e9
...
aded6214ab
| Author | SHA1 | Date | |
|---|---|---|---|
| aded6214ab | |||
| e70277ba3e | |||
| 4930338c52 | |||
| 55879e4fb4 | |||
| 0b2f0d1d6a | |||
| a27972125b | |||
| 70bdec72ef | |||
| de50eabd3f | |||
| 23845218b6 | |||
| 01fc975eb5 | |||
| fc4f6d2502 | |||
| e21996f23b | |||
| 10fa65a996 | |||
| 8f125a5ec1 |
@@ -35,6 +35,18 @@ class Module1Config:
|
||||
max_video_frames: int = 32
|
||||
min_subtask_seconds: float = 1.5
|
||||
plan_max_steps: int = 8
|
||||
use_video_url: bool = False
|
||||
"""When True (and backend supports it, e.g. ``openai``), Module 1
|
||||
sends a ``video_url`` content block pointing at the episode's mp4
|
||||
file instead of pre-decoded frames. Lets the server sample frames at
|
||||
its own ``fps`` — no in-process conv3d cost. The video file is
|
||||
extracted as a per-episode subclip to ``staging/.video_clips/`` so
|
||||
the model sees only this episode's frames."""
|
||||
use_video_url_fps: float = 1.0
|
||||
"""Frame-rate hint to send to the server (mm_processor_kwargs.fps).
|
||||
Only used when ``use_video_url=True``. ``1.0`` = sample 1 frame per
|
||||
second, which is plenty for subtask-boundary detection on most
|
||||
manipulation episodes."""
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -61,16 +73,50 @@ class VlmConfig:
|
||||
"""Shared Qwen-VL client configuration."""
|
||||
|
||||
backend: str = "vllm"
|
||||
"""One of ``vllm``, ``transformers``, or ``stub`` (tests only)."""
|
||||
"""One of ``vllm``, ``transformers``, ``openai``, or ``stub`` (tests only).
|
||||
|
||||
The ``openai`` backend talks to any OpenAI-compatible server — works
|
||||
with ``vllm serve``, ``transformers serve``, ``ktransformers serve``,
|
||||
or hosted endpoints. Set ``api_base`` and (optionally) ``api_key``."""
|
||||
model_id: str = "Qwen/Qwen3.6-27B-FP8"
|
||||
api_base: str = "http://localhost:8000/v1"
|
||||
"""Base URL for the ``openai`` backend."""
|
||||
api_key: str = "EMPTY"
|
||||
"""API key for the ``openai`` backend; ``EMPTY`` works for local servers."""
|
||||
auto_serve: bool = True
|
||||
"""When True with ``backend=openai``, the CLI probes ``api_base``
|
||||
first; if no server answers, it spawns one (default:
|
||||
``transformers serve``), waits for it to be ready, runs the
|
||||
pipeline, and tears it down on exit. Default ``True`` so a single
|
||||
``lerobot-annotate`` call can drive the whole flow. Set to ``False``
|
||||
if you want to fail fast when no server is reachable (e.g. you're
|
||||
pointing at a remote endpoint that should already be up)."""
|
||||
serve_port: int = 8000
|
||||
"""Port the auto-spawned server binds to. Sets ``api_base`` automatically."""
|
||||
serve_command: str | None = None
|
||||
"""Override the auto-serve command (full shell command). When ``None``,
|
||||
we run ``transformers serve <model_id> --port <serve_port> --continuous-batching``."""
|
||||
serve_ready_timeout_s: float = 600.0
|
||||
"""Max seconds to wait for the server to start serving requests."""
|
||||
max_new_tokens: int = 512
|
||||
temperature: float = 0.2
|
||||
json_mode: bool = True
|
||||
batch_size: int = 4
|
||||
tensor_parallel_size: int = 1
|
||||
trust_remote_code: bool = True
|
||||
"""Pass ``trust_remote_code`` to HF auto-classes. Required for many
|
||||
newer VL checkpoints (Qwen3.x FP8, etc.) that ship custom loader code."""
|
||||
gpu_memory_utilization: float = 0.9
|
||||
"""Fraction of GPU memory vllm allocates for weights + KV cache.
|
||||
Lower (e.g. 0.7) when the vision encoder needs cuDNN workspace, or to
|
||||
avoid CUDNN_STATUS_NOT_INITIALIZED on tight VRAM (30B BF16 on 80 GB)."""
|
||||
max_model_len: int | None = None
|
||||
"""Cap context length. ``None`` keeps the model's default; on H100 80 GB
|
||||
a 30B BF16 model often needs ``max_model_len=8192`` or smaller to leave
|
||||
room for KV cache."""
|
||||
trust_remote_code: bool = False
|
||||
"""Pass ``trust_remote_code`` to HF auto-classes. Default ``False`` —
|
||||
only enable for models that actually ship custom code in their repo
|
||||
(rare for first-class VL releases). On Qwen3-VL it triggers an
|
||||
std::bad_alloc post-load even though the official transformers class
|
||||
is sufficient, so leaving this off is safest."""
|
||||
camera_key: str | None = None
|
||||
"""Override the camera stream used for keyframe attachment. ``None`` picks
|
||||
the first ``observation.images.*`` key the dataset declares."""
|
||||
|
||||
@@ -109,7 +109,10 @@ class VideoFrameProvider:
|
||||
|
||||
if misses:
|
||||
decoded = self._decode(record.episode_index, misses)
|
||||
for i, img in zip(miss_indices, decoded, strict=True):
|
||||
# decoder may return fewer frames than requested when some
|
||||
# timestamps fall outside the video; pair what we have and
|
||||
# leave the rest as None to be filtered below.
|
||||
for i, img in zip(miss_indices, decoded):
|
||||
out[i] = img
|
||||
key = (record.episode_index, round(float(timestamps[i]), 6))
|
||||
if len(self._cache) >= self.cache_size:
|
||||
@@ -119,6 +122,8 @@ class VideoFrameProvider:
|
||||
return [img for img in out if img is not None]
|
||||
|
||||
def _decode(self, episode_index: int, timestamps: list[float]) -> list[Any]:
|
||||
import os as _os # noqa: PLC0415
|
||||
|
||||
from PIL import Image # noqa: PLC0415
|
||||
|
||||
from lerobot.datasets.video_utils import decode_video_frames # noqa: PLC0415
|
||||
@@ -127,11 +132,17 @@ class VideoFrameProvider:
|
||||
from_timestamp = ep[f"videos/{self.camera_key}/from_timestamp"]
|
||||
shifted = [from_timestamp + ts for ts in timestamps]
|
||||
video_path = self.root / self._meta.get_video_file_path(episode_index, self.camera_key)
|
||||
# ``torchcodec`` import currently bad-allocs on cu128/torch-2.8 in
|
||||
# some environments; default to ``pyav`` (always available via
|
||||
# the ``av`` package) and let users override with
|
||||
# LEROBOT_VIDEO_BACKEND=torchcodec when their stack supports it.
|
||||
backend = _os.environ.get("LEROBOT_VIDEO_BACKEND", "pyav")
|
||||
try:
|
||||
frames = decode_video_frames(
|
||||
video_path,
|
||||
shifted,
|
||||
self.tolerance_s,
|
||||
backend=backend,
|
||||
return_uint8=True,
|
||||
)
|
||||
except Exception:
|
||||
@@ -192,3 +203,62 @@ def to_video_block(images: list[Any]) -> list[dict[str, Any]]:
|
||||
if not images:
|
||||
return []
|
||||
return [{"type": "video", "video": list(images)}]
|
||||
|
||||
|
||||
def to_video_url_block(url: str | None, fps: float = 2.0) -> list[dict[str, Any]]:
|
||||
"""Wrap a video file URL as one ``video_url`` block.
|
||||
|
||||
Used by the ``openai`` backend (transformers serve / vllm serve /
|
||||
ktransformers serve), where the server handles frame sampling.
|
||||
Returns ``[]`` when ``url`` is ``None`` so the caller can splat.
|
||||
"""
|
||||
if not url:
|
||||
return []
|
||||
return [{"type": "video_url", "video_url": {"url": url}, "fps": fps}]
|
||||
|
||||
|
||||
def episode_clip_path(
|
||||
record: EpisodeRecord,
|
||||
provider: "VideoFrameProvider",
|
||||
cache_dir: Path,
|
||||
) -> Path | None:
|
||||
"""Extract the episode's subclip to ``cache_dir/ep_{idx:06d}.mp4``.
|
||||
|
||||
Returns ``None`` if the dataset has no video tracks. Skips re-extract
|
||||
when the cached clip already exists. Uses ``ffmpeg`` via subprocess
|
||||
with stream-copy where possible (no re-encode) for speed.
|
||||
"""
|
||||
import subprocess # noqa: PLC0415
|
||||
|
||||
if provider.camera_key is None:
|
||||
return None
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_path = cache_dir / f"ep_{record.episode_index:06d}.mp4"
|
||||
if out_path.exists() and out_path.stat().st_size > 0:
|
||||
return out_path
|
||||
ep = provider._meta.episodes[record.episode_index]
|
||||
from_timestamp = float(ep[f"videos/{provider.camera_key}/from_timestamp"])
|
||||
to_timestamp = float(ep[f"videos/{provider.camera_key}/to_timestamp"])
|
||||
src = provider.root / provider._meta.get_video_file_path(
|
||||
record.episode_index, provider.camera_key
|
||||
)
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-loglevel",
|
||||
"error",
|
||||
"-ss",
|
||||
f"{from_timestamp:.3f}",
|
||||
"-to",
|
||||
f"{to_timestamp:.3f}",
|
||||
"-i",
|
||||
str(src),
|
||||
"-c",
|
||||
"copy",
|
||||
str(out_path),
|
||||
]
|
||||
try:
|
||||
subprocess.run(cmd, check=True, timeout=120)
|
||||
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError):
|
||||
return None
|
||||
return out_path if out_path.exists() and out_path.stat().st_size > 0 else None
|
||||
|
||||
@@ -21,8 +21,17 @@ from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from ..config import Module1Config
|
||||
from ..frames import FrameProvider, null_provider, to_video_block
|
||||
from ..frames import (
|
||||
FrameProvider,
|
||||
VideoFrameProvider,
|
||||
episode_clip_path,
|
||||
null_provider,
|
||||
to_video_block,
|
||||
to_video_url_block,
|
||||
)
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord
|
||||
from ..staging import EpisodeStaging
|
||||
@@ -151,14 +160,26 @@ class PlanSubtasksMemoryModule:
|
||||
if record.row_count == 0 or not record.frame_timestamps:
|
||||
return []
|
||||
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
|
||||
video_frames = self.frame_provider.video_for_episode(record, self.config.max_video_frames)
|
||||
prompt = load_prompt("module_1_subtasks").format(
|
||||
episode_task=record.episode_task,
|
||||
min_subtask_seconds=self.config.min_subtask_seconds,
|
||||
max_steps=self.config.plan_max_steps,
|
||||
episode_duration=f"{episode_duration:.3f}",
|
||||
)
|
||||
content = [*to_video_block(video_frames), {"type": "text", "text": prompt}]
|
||||
if self.config.use_video_url and isinstance(self.frame_provider, VideoFrameProvider):
|
||||
cache_dir = Path(self.frame_provider.root) / ".annotate_staging" / ".video_clips"
|
||||
clip = episode_clip_path(record, self.frame_provider, cache_dir)
|
||||
video_block = (
|
||||
to_video_url_block(f"file://{clip}", fps=self.config.use_video_url_fps)
|
||||
if clip is not None
|
||||
else []
|
||||
)
|
||||
else:
|
||||
video_frames = self.frame_provider.video_for_episode(
|
||||
record, self.config.max_video_frames
|
||||
)
|
||||
video_block = to_video_block(video_frames)
|
||||
content = [*video_block, {"type": "text", "text": prompt}]
|
||||
messages = [{"role": "user", "content": content}]
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
spans = result.get("subtasks") if isinstance(result, dict) else None
|
||||
|
||||
@@ -138,6 +138,8 @@ def make_vlm_client(config: VlmConfig) -> VlmClient:
|
||||
return _make_vllm_client(config)
|
||||
if config.backend == "transformers":
|
||||
return _make_transformers_client(config)
|
||||
if config.backend == "openai":
|
||||
return _make_openai_client(config)
|
||||
raise ValueError(f"Unknown VLM backend: {config.backend!r}")
|
||||
|
||||
|
||||
@@ -148,16 +150,35 @@ def _make_vllm_client(config: VlmConfig) -> VlmClient:
|
||||
raise ImportError(
|
||||
"vllm is required for backend='vllm'. Install with `pip install lerobot[annotations]`."
|
||||
) from exc
|
||||
llm = LLM(model=config.model_id, tensor_parallel_size=config.tensor_parallel_size)
|
||||
# Workaround for cuDNN 9.x + torch 2.8 conv3d regression that surfaces
|
||||
# as CUDNN_STATUS_NOT_INITIALIZED in Qwen-VL vision-tower patch
|
||||
# embedders. Setting LEROBOT_DISABLE_CUDNN=1 forces native PyTorch
|
||||
# convolution kernels — slower but functional.
|
||||
import os as _os # noqa: PLC0415
|
||||
|
||||
if _os.environ.get("LEROBOT_DISABLE_CUDNN", "").lower() in {"1", "true", "yes"}:
|
||||
import torch as _torch # noqa: PLC0415
|
||||
|
||||
_torch.backends.cudnn.enabled = False
|
||||
llm_kwargs: dict[str, Any] = {
|
||||
"model": config.model_id,
|
||||
"tensor_parallel_size": config.tensor_parallel_size,
|
||||
"gpu_memory_utilization": config.gpu_memory_utilization,
|
||||
"trust_remote_code": config.trust_remote_code,
|
||||
}
|
||||
if config.max_model_len is not None:
|
||||
llm_kwargs["max_model_len"] = config.max_model_len
|
||||
llm = LLM(**llm_kwargs)
|
||||
|
||||
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
|
||||
params = SamplingParams(
|
||||
max_tokens=max_tok,
|
||||
temperature=temp,
|
||||
guided_decoding={"json": {}} if config.json_mode else None,
|
||||
)
|
||||
prompts = [_messages_to_prompt(m) for m in batch]
|
||||
outputs = llm.generate(prompts, params)
|
||||
# ``guided_decoding`` would speed up parsing but its API differs across
|
||||
# vllm releases (dict vs GuidedDecodingParams). The _GenericTextClient
|
||||
# wrapper already has a one-retry JSON-recovery path, so we skip it.
|
||||
params = SamplingParams(max_tokens=max_tok, temperature=temp)
|
||||
# ``llm.chat`` handles chat-template application + multimodal input
|
||||
# extraction (image/video blocks) internally, which ``llm.generate``
|
||||
# does not.
|
||||
outputs = llm.chat([list(m) for m in batch], params)
|
||||
return [o.outputs[0].text for o in outputs]
|
||||
|
||||
return _GenericTextClient(_gen, config)
|
||||
@@ -183,20 +204,32 @@ def _make_transformers_client(config: VlmConfig) -> VlmClient:
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
config.model_id, trust_remote_code=config.trust_remote_code
|
||||
)
|
||||
# ``low_cpu_mem_usage=True`` avoids a transformers-internal staging
|
||||
# buffer that has caused std::bad_alloc on Qwen3-line architectures
|
||||
# even on hosts with TBs of RAM (the failing alloc is in the
|
||||
# post-load tensor-placement path, not a real OOM).
|
||||
# ``device_map='auto'`` then streams shards directly to the GPU.
|
||||
# ``trust_remote_code`` is required for many newer VL releases
|
||||
# (Qwen3.6-FP8, etc.) that ship a custom loader in the repo.
|
||||
model = auto_cls.from_pretrained(
|
||||
config.model_id,
|
||||
torch_dtype="auto",
|
||||
device_map="auto",
|
||||
low_cpu_mem_usage=True,
|
||||
trust_remote_code=config.trust_remote_code,
|
||||
)
|
||||
import os as _os # noqa: PLC0415
|
||||
|
||||
use_accelerate = _os.environ.get("LEROBOT_TRANSFORMERS_DEVICE_MAP", "manual") != "manual"
|
||||
# ``device_map='auto'`` triggers a known std::bad_alloc on the Qwen3-VL
|
||||
# post-load dispatch path (the alloc fails in accelerate's hook setup
|
||||
# even with TBs of host RAM). Default to manual: load on CPU with
|
||||
# ``low_cpu_mem_usage=True``, then ``.to("cuda")``. Set
|
||||
# ``LEROBOT_TRANSFORMERS_DEVICE_MAP=auto`` to opt back into the old path.
|
||||
if use_accelerate:
|
||||
model = auto_cls.from_pretrained(
|
||||
config.model_id,
|
||||
torch_dtype="auto",
|
||||
device_map="auto",
|
||||
low_cpu_mem_usage=True,
|
||||
trust_remote_code=config.trust_remote_code,
|
||||
)
|
||||
else:
|
||||
import torch as _torch # noqa: PLC0415
|
||||
|
||||
model = auto_cls.from_pretrained(
|
||||
config.model_id,
|
||||
torch_dtype=_torch.bfloat16,
|
||||
low_cpu_mem_usage=True,
|
||||
trust_remote_code=config.trust_remote_code,
|
||||
)
|
||||
model = model.to("cuda")
|
||||
model.eval()
|
||||
|
||||
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
|
||||
@@ -220,6 +253,196 @@ def _make_transformers_client(config: VlmConfig) -> VlmClient:
|
||||
return _GenericTextClient(_gen, config)
|
||||
|
||||
|
||||
def _make_openai_client(config: VlmConfig) -> VlmClient:
|
||||
"""Backend that talks to any OpenAI-compatible server.
|
||||
|
||||
Compatible with ``vllm serve``, ``transformers serve``,
|
||||
``ktransformers serve``, and hosted endpoints. By default the server
|
||||
is expected to be already running. Set ``auto_serve=True`` to have
|
||||
this client spawn one (default: ``transformers serve``), wait until
|
||||
it's ready, and tear it down on process exit.
|
||||
|
||||
Image blocks ``{"type":"image", "image":<PIL.Image>}`` are
|
||||
auto-converted to ``image_url`` data-URLs. Video blocks
|
||||
``{"type":"video", "video":[<PIL>...]}`` are forwarded as
|
||||
multi-frame ``video_url`` items where supported.
|
||||
"""
|
||||
try:
|
||||
from openai import OpenAI # type: ignore[import-not-found]
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"openai package is required for backend='openai'. "
|
||||
"Install with `pip install openai`."
|
||||
) from exc
|
||||
|
||||
api_base = config.api_base
|
||||
print(
|
||||
f"[lerobot-annotate] backend=openai model={config.model_id} "
|
||||
f"api_base={api_base} auto_serve={config.auto_serve}",
|
||||
flush=True,
|
||||
)
|
||||
if config.auto_serve:
|
||||
if _server_is_up(api_base):
|
||||
print(f"[lerobot-annotate] reusing server already up at {api_base}", flush=True)
|
||||
else:
|
||||
print("[lerobot-annotate] no server reachable; spawning one", flush=True)
|
||||
api_base = _spawn_inference_server(config)
|
||||
print(f"[lerobot-annotate] server ready at {api_base}", flush=True)
|
||||
|
||||
client = OpenAI(base_url=api_base, api_key=config.api_key)
|
||||
|
||||
def _gen(
|
||||
batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float
|
||||
) -> list[str]:
|
||||
outs: list[str] = []
|
||||
for messages in batch:
|
||||
api_messages = [_to_openai_message(m) for m in messages]
|
||||
response = client.chat.completions.create(
|
||||
model=config.model_id,
|
||||
messages=api_messages,
|
||||
max_tokens=max_tok,
|
||||
temperature=temp,
|
||||
)
|
||||
outs.append(response.choices[0].message.content or "")
|
||||
return outs
|
||||
|
||||
return _GenericTextClient(_gen, config)
|
||||
|
||||
|
||||
def _server_is_up(api_base: str) -> bool:
|
||||
"""Return True if ``api_base/models`` answers 200 within 2 seconds."""
|
||||
import urllib.request # noqa: PLC0415
|
||||
|
||||
url = api_base.rstrip("/") + "/models"
|
||||
try:
|
||||
with urllib.request.urlopen(url, timeout=2) as resp:
|
||||
return resp.status == 200
|
||||
except Exception: # noqa: BLE001
|
||||
return False
|
||||
|
||||
|
||||
def _spawn_inference_server(config: VlmConfig) -> str:
|
||||
"""Spawn ``transformers serve`` (or ``serve_command``), wait until it
|
||||
accepts ``/v1/models``, and register a shutdown hook.
|
||||
|
||||
Streams the server's stdout/stderr to the parent terminal in
|
||||
real-time on a background thread so users can see model-load
|
||||
progress and errors as they happen.
|
||||
|
||||
Returns the full ``api_base`` URL the OpenAI client should use.
|
||||
"""
|
||||
import atexit # noqa: PLC0415
|
||||
import shlex # noqa: PLC0415
|
||||
import signal # noqa: PLC0415
|
||||
import subprocess # noqa: PLC0415
|
||||
import sys # noqa: PLC0415
|
||||
import threading # noqa: PLC0415
|
||||
import time # noqa: PLC0415
|
||||
import urllib.request # noqa: PLC0415
|
||||
|
||||
cmd = config.serve_command
|
||||
if not cmd:
|
||||
cmd = (
|
||||
f"transformers serve {shlex.quote(config.model_id)} "
|
||||
f"--port {config.serve_port} --continuous-batching"
|
||||
)
|
||||
api_base = f"http://localhost:{config.serve_port}/v1"
|
||||
print(f"[server] launching: {cmd}", flush=True)
|
||||
proc = subprocess.Popen(
|
||||
shlex.split(cmd),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
)
|
||||
|
||||
# Watch the server output for the uvicorn readiness banner. This is
|
||||
# more reliable than polling /v1/models because transformers serve
|
||||
# rescans its cache on every model-list request, which can exceed
|
||||
# the urllib timeout and trigger an infinite probe loop.
|
||||
ready_event = threading.Event()
|
||||
ready_markers = ("Uvicorn running", "Application startup complete")
|
||||
|
||||
def _stream_output() -> None:
|
||||
assert proc.stdout is not None
|
||||
for line in proc.stdout:
|
||||
sys.stdout.write(f"[server] {line}")
|
||||
sys.stdout.flush()
|
||||
if any(marker in line for marker in ready_markers):
|
||||
ready_event.set()
|
||||
|
||||
threading.Thread(target=_stream_output, daemon=True).start()
|
||||
|
||||
def _shutdown() -> None:
|
||||
if proc.poll() is None:
|
||||
print(f"[server] stopping pid={proc.pid}", flush=True)
|
||||
proc.send_signal(signal.SIGINT)
|
||||
try:
|
||||
proc.wait(timeout=15)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
proc.wait(timeout=5)
|
||||
|
||||
atexit.register(_shutdown)
|
||||
|
||||
deadline = time.monotonic() + config.serve_ready_timeout_s
|
||||
while time.monotonic() < deadline:
|
||||
if proc.poll() is not None:
|
||||
raise RuntimeError(
|
||||
f"[server] inference server exited unexpectedly with rc={proc.returncode}. "
|
||||
f"See [server] log lines above for the cause."
|
||||
)
|
||||
if ready_event.wait(timeout=2):
|
||||
return api_base
|
||||
proc.terminate()
|
||||
raise RuntimeError(
|
||||
f"[server] did not become ready within {config.serve_ready_timeout_s}s"
|
||||
)
|
||||
|
||||
|
||||
def _to_openai_message(message: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Convert an internal message dict to OpenAI chat format.
|
||||
|
||||
Internal image/video blocks (using PIL.Image objects) become
|
||||
OpenAI ``image_url``/``video_url`` items via base64 data URLs.
|
||||
"""
|
||||
content = message.get("content")
|
||||
if not isinstance(content, list):
|
||||
return {"role": message["role"], "content": content}
|
||||
out_blocks: list[dict[str, Any]] = []
|
||||
for block in content:
|
||||
block_type = block.get("type") if isinstance(block, dict) else None
|
||||
if block_type == "text":
|
||||
out_blocks.append({"type": "text", "text": block.get("text", "")})
|
||||
elif block_type == "image":
|
||||
out_blocks.append(
|
||||
{"type": "image_url", "image_url": {"url": _pil_to_data_url(block["image"])}}
|
||||
)
|
||||
elif block_type == "video":
|
||||
frames = block.get("video", [])
|
||||
for img in frames:
|
||||
out_blocks.append(
|
||||
{"type": "image_url", "image_url": {"url": _pil_to_data_url(img)}}
|
||||
)
|
||||
elif block_type == "video_url":
|
||||
# Pass through to the OpenAI-compatible server unchanged.
|
||||
out_blocks.append({"type": "video_url", "video_url": block["video_url"]})
|
||||
else:
|
||||
out_blocks.append(block)
|
||||
return {"role": message["role"], "content": out_blocks}
|
||||
|
||||
|
||||
def _pil_to_data_url(image: Any) -> str:
|
||||
"""Encode a PIL.Image as a base64 data URL."""
|
||||
import base64 # noqa: PLC0415
|
||||
import io # noqa: PLC0415
|
||||
|
||||
buf = io.BytesIO()
|
||||
image.save(buf, format="PNG")
|
||||
b64 = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
return f"data:image/png;base64,{b64}"
|
||||
|
||||
|
||||
def _messages_to_prompt(messages: Sequence[dict[str, Any]]) -> Any:
|
||||
"""Pass-through hook used by the vllm backend.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user