mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
more fixes
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.envs.factory import make_env_config
|
||||
from lerobot.policies.xvla.utils import Rotate6D_to_AxisAngle
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.xvla.utils import rotate6d_to_axis_angle
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
observation_height: int = 360
|
||||
observation_width: int = 360
|
||||
@@ -13,18 +14,22 @@ observation_width: int = 360
|
||||
OBS = {
|
||||
f"{OBS_IMAGES}.image1": torch.randn(1, 3, observation_height, observation_width),
|
||||
f"{OBS_IMAGES}.image2": torch.randn(1, 3, observation_height, observation_width),
|
||||
OBS_STATE: torch.randn(1, 9), # ONLY if OBS_STATE is already a string
|
||||
OBS_STATE: torch.randn(1, 9), # ONLY if OBS_STATE is already a string
|
||||
"task": "put the object in the box",
|
||||
}
|
||||
|
||||
|
||||
def fake_rgb(H, W):
|
||||
img = torch.randint(0, 255, (H, W, 3), dtype=torch.uint8).numpy()
|
||||
return img
|
||||
|
||||
|
||||
OBS[f"{OBS_IMAGES}.image1"] = fake_rgb(observation_height, observation_width)
|
||||
OBS[f"{OBS_IMAGES}.image2"] = fake_rgb(observation_height, observation_width)
|
||||
|
||||
# observation = preprocessor(OBS)
|
||||
from transformers import AutoProcessor
|
||||
|
||||
processor = AutoProcessor.from_pretrained("2toINF/X-VLA-WidowX", num_views=2, trust_remote_code=True)
|
||||
inputs = processor([OBS[f"{OBS_IMAGES}.image1"], OBS[f"{OBS_IMAGES}.image2"]], OBS["task"])
|
||||
breakpoint()
|
||||
@@ -40,19 +45,19 @@ policy = make_policy(
|
||||
policy.eval()
|
||||
|
||||
preprocessor_overrides = {
|
||||
"device_processor": {"device": str(cfg.device)},
|
||||
"device_processor": {"device": str(cfg.device)},
|
||||
}
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg,
|
||||
pretrained_path=cfg.pretrained_path,
|
||||
preprocessor_overrides=preprocessor_overrides,
|
||||
policy_cfg=cfg,
|
||||
pretrained_path=cfg.pretrained_path,
|
||||
preprocessor_overrides=preprocessor_overrides,
|
||||
)
|
||||
|
||||
observation = preprocessor(OBS)
|
||||
action = policy.select_action(observation)
|
||||
|
||||
target_eef = action[:, :3].to("cpu").numpy()
|
||||
target_axis = Rotate6D_to_AxisAngle(action[:, 3:9].to("cpu").numpy())
|
||||
target_axis = rotate6d_to_axis_angle(action[:, 3:9].to("cpu").numpy())
|
||||
target_act = action[:, 9:10].to("cpu").numpy()
|
||||
final_action = np.concatenate([target_eef, target_axis, target_act], axis=-1)
|
||||
|
||||
Reference in New Issue
Block a user