Files
lerobot/examples/tutorial/xvla/inference_pipe.py
T
Jade Choghari 5277a9909d more fixes
2025-11-17 14:03:15 +01:00

64 lines
2.1 KiB
Python

import numpy as np
import torch
# from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.envs.factory import make_env_config
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.xvla.utils import rotate6d_to_axis_angle
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
observation_height: int = 360
observation_width: int = 360
# create an observation dict
OBS = {
f"{OBS_IMAGES}.image1": torch.randn(1, 3, observation_height, observation_width),
f"{OBS_IMAGES}.image2": torch.randn(1, 3, observation_height, observation_width),
OBS_STATE: torch.randn(1, 9), # ONLY if OBS_STATE is already a string
"task": "put the object in the box",
}
def fake_rgb(H, W):
img = torch.randint(0, 255, (H, W, 3), dtype=torch.uint8).numpy()
return img
OBS[f"{OBS_IMAGES}.image1"] = fake_rgb(observation_height, observation_width)
OBS[f"{OBS_IMAGES}.image2"] = fake_rgb(observation_height, observation_width)
# observation = preprocessor(OBS)
from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained("2toINF/X-VLA-WidowX", num_views=2, trust_remote_code=True)
inputs = processor([OBS[f"{OBS_IMAGES}.image1"], OBS[f"{OBS_IMAGES}.image2"]], OBS["task"])
breakpoint()
cfg = PreTrainedConfig.from_pretrained("/raid/jade/models/xvla-libero-og_migrated")
cfg.pretrained_path = "/raid/jade/models/xvla-libero-og_migrated"
env_cfg = make_env_config("libero", task="libero_spatial")
policy = make_policy(
cfg=cfg,
env_cfg=env_cfg,
)
policy.eval()
preprocessor_overrides = {
"device_processor": {"device": str(cfg.device)},
}
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg,
pretrained_path=cfg.pretrained_path,
preprocessor_overrides=preprocessor_overrides,
)
observation = preprocessor(OBS)
action = policy.select_action(observation)
target_eef = action[:, :3].to("cpu").numpy()
target_axis = rotate6d_to_axis_angle(action[:, 3:9].to("cpu").numpy())
target_act = action[:, 9:10].to("cpu").numpy()
final_action = np.concatenate([target_eef, target_axis, target_act], axis=-1)