mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
fix(annotate): use device_map='auto' for transformers backend
Without device_map, transformers stages the full FP8 checkpoint in CPU RAM before any GPU placement, OOMing the host on 27B+ models even when the GPU has enough VRAM. device_map='auto' streams shards directly to GPU memory. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -181,7 +181,14 @@ def _make_transformers_client(config: VlmConfig) -> VlmClient:
|
|||||||
"for VL models."
|
"for VL models."
|
||||||
)
|
)
|
||||||
processor = AutoProcessor.from_pretrained(config.model_id)
|
processor = AutoProcessor.from_pretrained(config.model_id)
|
||||||
model = auto_cls.from_pretrained(config.model_id, torch_dtype="auto")
|
# device_map='auto' loads weights directly to GPU(s) and shards when
|
||||||
|
# needed; without it, transformers stages the full checkpoint in CPU
|
||||||
|
# memory first which OOMs the host on FP8/large models.
|
||||||
|
model = auto_cls.from_pretrained(
|
||||||
|
config.model_id,
|
||||||
|
torch_dtype="auto",
|
||||||
|
device_map="auto",
|
||||||
|
)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
|
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
|
||||||
|
|||||||
Reference in New Issue
Block a user