mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
64 lines
2.1 KiB
Python
64 lines
2.1 KiB
Python
import numpy as np
|
|
import torch
|
|
|
|
# from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
|
from lerobot.configs.policies import PreTrainedConfig
|
|
from lerobot.envs.factory import make_env_config
|
|
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
|
from lerobot.policies.xvla.utils import rotate6d_to_axis_angle
|
|
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
|
|
|
# Pixel size of the (square) camera frames used throughout this script.
observation_height: int = 360
observation_width: int = 360

# Create an observation dict: the two camera views first, then state and task.
OBS = {
    f"{OBS_IMAGES}.image{view}": torch.randn(1, 3, observation_height, observation_width)
    for view in (1, 2)
}
OBS[OBS_STATE] = torch.randn(1, 9)  # ONLY if OBS_STATE is already a string
OBS["task"] = "put the object in the box"
|
|
|
|
|
|
def fake_rgb(H, W):
    """Return a random RGB image as an ``(H, W, 3)`` uint8 NumPy array.

    Args:
        H: Image height in pixels.
        W: Image width in pixels.

    Returns:
        A ``numpy.ndarray`` of shape ``(H, W, 3)`` and dtype ``uint8`` whose
        values are drawn uniformly from the full ``[0, 255]`` range.
    """
    # torch.randint's `high` bound is exclusive, so 256 is required for the
    # value 255 to be reachable.
    return torch.randint(0, 256, (H, W, 3), dtype=torch.uint8).numpy()
|
|
|
|
|
|
# Replace both camera entries of OBS with freshly generated fake RGB frames.
for _view in ("image1", "image2"):
    OBS[f"{OBS_IMAGES}.{_view}"] = fake_rgb(observation_height, observation_width)
|
|
|
|
from transformers import AutoProcessor

# Run the processor published with the "2toINF/X-VLA-WidowX" checkpoint on the
# two camera frames and the task string.
# NOTE(security): trust_remote_code=True executes Python code shipped in that
# Hub repository; only use it with repositories you trust.
processor = AutoProcessor.from_pretrained("2toINF/X-VLA-WidowX", num_views=2, trust_remote_code=True)
inputs = processor([OBS[f"{OBS_IMAGES}.image1"], OBS[f"{OBS_IMAGES}.image2"]], OBS["task"])
# A leftover `breakpoint()` used to follow this line and stopped the script
# before any policy code ran. `inputs` stays at module level so it can still
# be inspected interactively (e.g. with `python -i`).
|
|
|
|
# Local checkpoint directory, used both to load the config and as
# `cfg.pretrained_path`.
_CHECKPOINT_DIR = "/raid/jade/models/xvla-libero-og_migrated"

cfg = PreTrainedConfig.from_pretrained(_CHECKPOINT_DIR)
cfg.pretrained_path = _CHECKPOINT_DIR
env_cfg = make_env_config("libero", task="libero_spatial")

# Build the policy from the config and environment config, then switch it to
# evaluation mode.
policy = make_policy(cfg=cfg, env_cfg=env_cfg)
policy.eval()

# Override the `device` setting of the "device_processor" entry with the
# config's device, passed as a string.
preprocessor_overrides = {"device_processor": {"device": str(cfg.device)}}
preprocessor, postprocessor = make_pre_post_processors(
    policy_cfg=cfg,
    pretrained_path=cfg.pretrained_path,
    preprocessor_overrides=preprocessor_overrides,
)
|
|
|
|
def _to_numpy(tensor):
    """Move *tensor* to the CPU and return it as a NumPy array."""
    return tensor.to("cpu").numpy()


# Preprocess the raw observation dict and ask the policy for an action.
observation = preprocessor(OBS)
action = policy.select_action(observation)

# Split the action by column: 0:3 -> target_eef, 3:9 -> passed through
# rotate6d_to_axis_angle, 9:10 -> target_act; then join along the last axis.
target_eef = _to_numpy(action[:, :3])
target_axis = rotate6d_to_axis_angle(_to_numpy(action[:, 3:9]))
target_act = _to_numpy(action[:, 9:10])
final_action = np.concatenate([target_eef, target_axis, target_act], axis=-1)