mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 00:59:46 +00:00
more fixes
This commit is contained in:
@@ -1,15 +1,13 @@
|
||||
from transformers import AutoModel, AutoProcessor
|
||||
import json_numpy
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoModel, AutoProcessor
|
||||
|
||||
model = AutoModel.from_pretrained(
|
||||
"2toINF/X-VLA-WidowX",
|
||||
trust_remote_code=True
|
||||
)
|
||||
model = AutoModel.from_pretrained("2toINF/X-VLA-WidowX", trust_remote_code=True)
|
||||
|
||||
processor = AutoProcessor.from_pretrained("2toINF/X-VLA-WidowX", trust_remote_code=True)
|
||||
|
||||
|
||||
# append 3 random image to a list
|
||||
def make_random_pil_images(num_images=3, H=480, W=640):
|
||||
images = []
|
||||
@@ -20,6 +18,7 @@ def make_random_pil_images(num_images=3, H=480, W=640):
|
||||
images.append(img)
|
||||
return images
|
||||
|
||||
|
||||
# Example:
|
||||
images = make_random_pil_images()
|
||||
language_instruction = "This is a random image"
|
||||
@@ -29,23 +28,27 @@ if not {"input_ids", "image_input", "image_mask"}.issubset(inputs):
|
||||
raise ValueError("Processor did not return the expected keys.")
|
||||
|
||||
proprio = torch.randn(1, 20)
|
||||
domain_id = torch.tensor([int(0)], dtype=torch.long)
|
||||
domain_id = torch.tensor([0], dtype=torch.long)
|
||||
|
||||
# Align to model's device/dtype
|
||||
device = model.device
|
||||
dtype = next(model.parameters()).dtype
|
||||
|
||||
|
||||
def to_model(t: torch.Tensor) -> torch.Tensor:
|
||||
if not isinstance(t, torch.Tensor):
|
||||
t = torch.as_tensor(t)
|
||||
# cast floats to model dtype, keep integral/bool as-is
|
||||
return t.to(device=device, dtype=dtype) if t.is_floating_point() else t.to(device=device)
|
||||
|
||||
|
||||
inputs = {k: to_model(v) for k, v in inputs.items()}
|
||||
inputs.update({
|
||||
"proprio": to_model(proprio),
|
||||
"domain_id": domain_id.to(device),
|
||||
})
|
||||
inputs.update(
|
||||
{
|
||||
"proprio": to_model(proprio),
|
||||
"domain_id": domain_id.to(device),
|
||||
}
|
||||
)
|
||||
|
||||
# Inference
|
||||
|
||||
|
||||
Reference in New Issue
Block a user