mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-19 01:07:18 +00:00
more refactoring
This commit is contained in:
@@ -337,7 +337,7 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
elif isinstance(policy_cfg, XVLAConfig):
|
||||
from lerobot.policies.xvla.processing_xvla import make_xvla_pre_post_processors
|
||||
from lerobot.policies.xvla.processor_xvla import make_xvla_pre_post_processors
|
||||
|
||||
processors = make_xvla_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
|
||||
@@ -45,17 +45,6 @@ Note that in both examples, the repo/folder should contain at least `config.json
|
||||
|
||||
You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
|
||||
# ABSOLUTE PATH TO YOUR PROJECT ROOT
|
||||
PROJECT_ROOT = "/home/jade_choghari/robot/lerobot"
|
||||
|
||||
# Add root to sys.path BEFORE any imports
|
||||
if PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, PROJECT_ROOT)
|
||||
from xvla.models.modeling_xvla import XVLA
|
||||
from xvla.models.processing_xvla import XVLAProcessor
|
||||
import concurrent.futures as cf
|
||||
import json
|
||||
import logging
|
||||
@@ -166,13 +155,6 @@ def rollout(
|
||||
leave=False,
|
||||
)
|
||||
|
||||
model = XVLA.from_pretrained("/raid/jade/models/xvla-libero")
|
||||
model.eval()
|
||||
model.to("cuda")
|
||||
processor = XVLAProcessor.from_pretrained("/raid/jade/models/xvla-libero", num_views=2)
|
||||
|
||||
from collections import deque
|
||||
action_queue = deque(maxlen=30)
|
||||
check_env_attributes_and_types(env)
|
||||
while not np.all(done) and step < max_steps:
|
||||
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
|
||||
@@ -188,50 +170,15 @@ def rollout(
|
||||
# inputs = processor([observation[f"observation.images.image"], observation[f"observation.images.image2"]], observation["task"])
|
||||
observation = preprocessor(observation)
|
||||
observation["domain_id"] = torch.tensor([int(3)], dtype=torch.long).to("cuda")
|
||||
# inputs_1 = policy._build_model_inputs(observation)
|
||||
# for k in inputs.keys() & inputs_1.keys(): # intersection of keys
|
||||
# a = inputs[k].to("cuda")
|
||||
# b = inputs_1[k].to("cuda")
|
||||
|
||||
# print(f"\n🔎 Key: {k}")
|
||||
|
||||
# # Check shape
|
||||
# print(" shape:", a.shape, b.shape)
|
||||
|
||||
# # Check if close
|
||||
# if torch.allclose(a, b, atol=1e-5, rtol=1e-5):
|
||||
# print(" ✔️ tensors are equal (allclose)")
|
||||
# else:
|
||||
# diff = torch.abs(a - b)
|
||||
# print(" ❌ tensors differ")
|
||||
# print(" max diff:", diff.max().item())
|
||||
# print(" mean diff:", diff.mean().item())
|
||||
# breakpoint()
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation).to("cpu").numpy()
|
||||
# if len(action_queue) == 0:
|
||||
# action = model.generate_actions(**inputs_1, steps=10) # shape (1, 30, 20)
|
||||
# actions_np = action.detach().cpu().numpy()
|
||||
# # add each timestep as (1, 20)
|
||||
# for t in range(actions_np.shape[1]):
|
||||
# act_t = actions_np[:, t, :]
|
||||
# action_queue.append(act_t)
|
||||
# action = action_queue.popleft()
|
||||
# else:
|
||||
# action = action_queue.popleft()
|
||||
# action = postprocessor(action)
|
||||
# breakpoint()
|
||||
# .to("cpu").numpy()
|
||||
target_eef = action[:, :3]
|
||||
target_axis = Rotate6D_to_AxisAngle(action[:, 3:9])
|
||||
target_act = action[:, 9:10]
|
||||
action_numpy = np.concatenate([target_eef, target_axis, target_act], axis=-1)
|
||||
|
||||
# target_eef_1 = action_1[:, :3]
|
||||
# target_axis_1 = Rotate6D_to_AxisAngle(action_1[:, 3:9])
|
||||
# target_act_1 = action_1[:, 9:10]
|
||||
# action_numpy_1 = np.concatenate([target_eef_1, target_axis_1, target_act_1], axis=-1)
|
||||
|
||||
# Convert to CPU / numpy.
|
||||
# action_numpy: np.ndarray = action.to("cpu").numpy()
|
||||
assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)"
|
||||
|
||||
Reference in New Issue
Block a user