fix(annotate): expose gpu_memory_utilization and max_model_len for vllm

Large VL models (Qwen3-VL-30B-A3B BF16) take ~58 GB of an 80 GB H100,
leaving only ~22 GB for KV cache + cuDNN workspace. The vision tower's
3D conv then fails with CUDNN_STATUS_NOT_INITIALIZED because cuDNN
can't allocate a large enough workspace.
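
For orientation, a rough sketch of the split gpu_memory_utilization implies
(illustrative numbers only; the exact budget depends on the model and vLLM
version):

    # vLLM pre-claims gpu_memory_utilization * total VRAM for weights + KV cache;
    # the CUDA context, torch buffers and cuDNN workspaces must fit in the rest.
    total_vram_gb = 80
    for util in (0.9, 0.7):
        claimed = util * total_vram_gb       # weights + KV cache budget
        headroom = total_vram_gb - claimed   # cuDNN workspace, CUDA context, ...
        print(f"util={util}: claimed ~{claimed:.0f} GB, headroom ~{headroom:.0f} GB")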

- vlm.gpu_memory_utilization (default 0.9) — drop to 0.7 when the vision
  encoder needs more cuDNN workspace.
- vlm.max_model_len — cap context to free KV cache memory; the 262k
  default for Qwen3 is wildly more than annotation prompts need.
- vlm.trust_remote_code — already plumbed; now also passed to LLM().
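
Minimal usage sketch (assumes VlmConfig is a dataclass as in the diff below;
the checkpoint name is illustrative):

    config = VlmConfig(
        model_id="Qwen/Qwen3-VL-30B-A3B-Instruct",  # hypothetical checkpoint
        gpu_memory_utilization=0.7,  # leave headroom for the vision tower's cuDNN workspace
        max_model_len=8192,          # annotation prompts are short; shrink the KV cache
    )
    client = _make_vllm_client(config)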

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Pepijn
2026-04-28 11:29:22 +02:00
parent a81e23b0e9
commit 8f125a5ec1
2 changed files with 17 additions and 1 deletion
@@ -68,6 +68,14 @@ class VlmConfig:
     json_mode: bool = True
     batch_size: int = 4
     tensor_parallel_size: int = 1
+    gpu_memory_utilization: float = 0.9
+    """Fraction of GPU memory vllm allocates for weights + KV cache.
+    Lower (e.g. 0.7) when the vision encoder needs cuDNN workspace, or to
+    avoid CUDNN_STATUS_NOT_INITIALIZED on tight VRAM (30B BF16 on 80 GB)."""
+    max_model_len: int | None = None
+    """Cap context length. ``None`` keeps the model's default; on an H100 80 GB
+    a 30B BF16 model often needs ``max_model_len=8192`` or smaller to leave
+    room for KV cache."""
     trust_remote_code: bool = True
     """Pass ``trust_remote_code`` to HF auto-classes. Required for many
     newer VL checkpoints (Qwen3.x FP8, etc.) that ship custom loader code."""
@@ -148,7 +148,15 @@ def _make_vllm_client(config: VlmConfig) -> VlmClient:
         raise ImportError(
             "vllm is required for backend='vllm'. Install with `pip install lerobot[annotations]`."
         ) from exc
-    llm = LLM(model=config.model_id, tensor_parallel_size=config.tensor_parallel_size)
+    llm_kwargs: dict[str, Any] = {
+        "model": config.model_id,
+        "tensor_parallel_size": config.tensor_parallel_size,
+        "gpu_memory_utilization": config.gpu_memory_utilization,
+        "trust_remote_code": config.trust_remote_code,
+    }
+    if config.max_model_len is not None:
+        llm_kwargs["max_model_len"] = config.max_model_len
+    llm = LLM(**llm_kwargs)
 
     def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
         params = SamplingParams(