update files
@@ -0,0 +1,4 @@
python src/lerobot/processor/migrate_policy_normalization.py \
    --pretrained-path /raid/jade/models/xvla-libero-og \
    --output-dir /raid/jade/models/xvla-libero-og-migrated \
    --branch main
@@ -0,0 +1,7 @@
lerobot-eval \
    --policy.path="/raid/jade/models/xvla-libero-og_migrated" \
    --env.type=libero \
    --env.task=libero_spatial \
    --eval.batch_size=1 \
    --eval.n_episodes=1
@@ -0,0 +1,28 @@
#!/usr/bin/env python3
import argparse

import safetensors.torch as st


def prefix_state_dict(input_path, output_path, prefix="model."):
    # Load original checkpoint
    state_dict = st.load_file(input_path)

    print(f"Loaded {len(state_dict)} tensors from {input_path}")

    # Add prefix to every key
    new_state_dict = {f"{prefix}{k}": v for k, v in state_dict.items()}

    print(f"Writing prefixed checkpoint with {len(new_state_dict)} keys...")
    st.save_file(new_state_dict, output_path)

    print(f"Saved to {output_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", type=str, required=True, help="Path to model.safetensors")
    parser.add_argument("--output", type=str, required=True, help="Output prefixed model.safetensors")
    parser.add_argument("--prefix", type=str, default="model.", help="Prefix to add to each key")
    args = parser.parse_args()

    prefix_state_dict(args.input, args.output, args.prefix)
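For quick verification, a minimal round-trip check of the helper above (a sketch, assuming `prefix_state_dict` is importable from the script; the dummy tensor name is hypothetical):

import os
import tempfile

import safetensors.torch as st
import torch

with tempfile.TemporaryDirectory() as tmp:
    src = os.path.join(tmp, "model.safetensors")
    dst = os.path.join(tmp, "model_prefixed.safetensors")
    st.save_file({"encoder.weight": torch.zeros(2, 2)}, src)
    prefix_state_dict(src, dst, prefix="model.")
    # every key should now carry the "model." prefix
    assert set(st.load_file(dst)) == {"model.encoder.weight"}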
@@ -29,7 +29,7 @@ from gymnasium import spaces
from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.transform_utils import quat2axisangle

+from lerobot.policies.xvla.utils import Mat_to_Rotate6D

def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
    """Normalize camera_name into a non-empty list of strings."""
@@ -81,14 +81,14 @@ def get_libero_dummy_action():
    return [0, 0, 0, 0, 0, 0, -1]


-OBS_STATE_DIM = 8
+OBS_STATE_DIM = 20
ACTION_DIM = 7
AGENT_POS_LOW = -1000.0
AGENT_POS_HIGH = 1000.0
ACTION_LOW = -1.0
ACTION_HIGH = 1.0
TASK_SUITE_MAX_STEPS: dict[str, int] = {
-    "libero_spatial": 280,  # longest training demo has 193 steps
+    "libero_spatial": 800,  # longest training demo has 193 steps
    "libero_object": 280,  # longest training demo has 254 steps
    "libero_goal": 300,  # longest training demo has 270 steps
    "libero_10": 520,  # longest training demo has 505 steps
@@ -221,6 +221,11 @@ class LiberoEnv(gym.Env):
                raw_obs["robot0_gripper_qpos"],
            )
        )
+        # add new obs for XVLA: jadechoghari
+        robo_ori = Mat_to_Rotate6D(self._env.robots[0].controller.ee_ori_mat)
+        robo_pos = self._env.robots[0].controller.ee_pos
+        proprio = np.concatenate([robo_pos, robo_ori, np.array([0.0])], axis=-1)
+        state = np.concatenate([proprio, np.zeros_like(proprio)], axis=-1)
+        agent_pos = state
        if self.obs_type == "pixels":
            return {"pixels": images.copy()}
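The arithmetic behind the new state: 3 EEF position dims, plus the 6D rotation from Mat_to_Rotate6D, plus a gripper placeholder give a 10-dim proprio vector, and the zero-padding doubles it to 20, matching the OBS_STATE_DIM change above. A standalone sanity check of those shapes:

import numpy as np

robo_pos = np.zeros(3)  # EEF position
robo_ori = np.zeros(6)  # 6D rotation (first two columns of the rotation matrix)
proprio = np.concatenate([robo_pos, robo_ori, np.array([0.0])], axis=-1)
state = np.concatenate([proprio, np.zeros_like(proprio)], axis=-1)
assert proprio.shape == (10,)
assert state.shape == (20,)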
@@ -428,8 +428,6 @@ def make_policy(
    else:
        # Make a fresh policy.
-        policy = policy_cls(**kwargs)
+        kwargs["pretrained_name_or_path"] = "/fsx/jade_choghari/.cache/huggingface/model/xvla-libero"
+        policy = policy_cls.from_pretrained(**kwargs)

    policy.to(cfg.device)
    assert isinstance(policy, torch.nn.Module)
@@ -28,6 +28,8 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import OBS_IMAGES

+from .configuration_florence2 import Florence2Config
+from .configuration_florence2 import Florence2VisionConfig
+from .configuration_florence2 import Florence2LanguageConfig


@PreTrainedConfig.register_subclass("xvla")
@@ -56,7 +58,7 @@ class XVLAConfig(PreTrainedConfig):
    )

    # Florence2 backbone and tokenizer configuration
-    florence_config: dict[str, Any] | Florence2Config = field(default_factory=dict)
+    florence_config: dict[str, Any] = field(default_factory=dict)
    tokenizer_name: str = "facebook/bart-large"
    tokenizer_max_length: int = 64
    tokenizer_padding_side: str = "right"
@@ -81,7 +83,7 @@ class XVLAConfig(PreTrainedConfig):
    domain_feature_key: str | None = None

    # Vision preprocessing
-    resize_imgs_with_padding: tuple[int, int] | None = (518, 518)
+    resize_imgs_with_padding: tuple[int, int] | None = None
    num_image_views: int | None = None
    empty_cameras: int = 0
@@ -107,8 +109,6 @@ class XVLAConfig(PreTrainedConfig):
            raise ValueError(
                f"`n_action_steps` ({self.n_action_steps}) must be <= `chunk_size` ({self.chunk_size})."
            )
-        if isinstance(self.florence_config, Florence2Config):
-            self.florence_config = self.florence_config.to_dict()
        if self.num_image_views is not None and self.num_image_views <= 0:
            raise ValueError("`num_image_views` must be > 0 when specified.")
        self._florence_config_obj: Florence2Config | None = None
@@ -118,56 +118,55 @@ class XVLAConfig(PreTrainedConfig):
        Build (and cache) the Florence2 transformer config that should back the VLM.
        """
        if self._florence_config_obj is None:
-            if isinstance(self.florence_config, Florence2Config):
-                self._florence_config_obj = self.florence_config
-            else:
-                # TODO: jadechoghari: provide default way, and do not hardcode
-                # Ensure vision_config and text_config are provided with defaults if not specified
-                config_dict = dict(self.florence_config)
-                if "vision_config" not in config_dict or config_dict["vision_config"] is None:
-                    # Provide default vision config
-                    config_dict["vision_config"] = {
-                        "model_type": "davit",
-                        "drop_path_rate": 0.1,
-                        "patch_size": [7, 3, 3, 3],
-                        "patch_stride": [4, 2, 2, 2],
-                        "patch_padding": [3, 1, 1, 1],
-                        "patch_prenorm": [False, True, True, True],
-                        "enable_checkpoint": False,
-                        "dim_embed": [256, 512, 1024, 2048],
-                        "num_heads": [8, 16, 32, 64],
-                        "num_groups": [8, 16, 32, 64],
-                        "depths": [1, 1, 9, 1],
-                        "window_size": 12,
-                        "projection_dim": 1024,
-                        "visual_temporal_embedding": {
-                            "type": "COSINE",
-                            "max_temporal_embeddings": 100
-                        },
-                        "image_pos_embed": {
-                            "type": "learned_abs_2d",
-                            "max_pos_embeddings": 50
-                        },
-                        "image_feature_source": ["spatial_avg_pool", "temporal_avg_pool"]
-                    }
-                if "text_config" not in config_dict or config_dict["text_config"] is None:
-                    # Provide default text config
-                    config_dict["text_config"] = {
-                        "vocab_size": 51289,
-                        "activation_dropout": 0.1,
-                        "activation_function": "gelu",
-                        "attention_dropout": 0.1,
-                        "d_model": 1024,
-                        "decoder_attention_heads": 16,
-                        "decoder_layers": 12,
-                        "encoder_attention_heads": 16,
-                        "encoder_layers": 12,
-                        "dropout": 0.1,
-                        "max_position_embeddings": 4096,
-                        "num_hidden_layers": 12,
-                        "num_beams": 3
-                    }
-                self._florence_config_obj = Florence2Config(**config_dict)
+            # TODO: jadechoghari: provide default way, and do not hardcode
+            # Require vision_config and text_config to be provided explicitly
+            config_dict = dict(self.florence_config)
+            if "vision_config" not in config_dict or config_dict["vision_config"] is None:
+                raise ValueError("vision_config is required")
+            # # Provide default vision config
+            # config_dict["vision_config"] = {
+            #     "model_type": "davit",
+            #     "drop_path_rate": 0.1,
+            #     "patch_size": [7, 3, 3, 3],
+            #     "patch_stride": [4, 2, 2, 2],
+            #     "patch_padding": [3, 1, 1, 1],
+            #     "patch_prenorm": [False, True, True, True],
+            #     "enable_checkpoint": False,
+            #     "dim_embed": [256, 512, 1024, 2048],
+            #     "num_heads": [8, 16, 32, 64],
+            #     "num_groups": [8, 16, 32, 64],
+            #     "depths": [1, 1, 9, 1],
+            #     "window_size": 12,
+            #     "projection_dim": 1024,
+            #     "visual_temporal_embedding": {
+            #         "type": "COSINE",
+            #         "max_temporal_embeddings": 100
+            #     },
+            #     "image_pos_embed": {
+            #         "type": "learned_abs_2d",
+            #         "max_pos_embeddings": 50
+            #     },
+            #     "image_feature_source": ["spatial_avg_pool", "temporal_avg_pool"]
+            # }
+            if "text_config" not in config_dict or config_dict["text_config"] is None:
+                raise ValueError("text_config is required")
+            # # Provide default text config
+            # config_dict["text_config"] = {
+            #     "vocab_size": 51289,
+            #     "activation_dropout": 0.1,
+            #     "activation_function": "gelu",
+            #     "attention_dropout": 0.1,
+            #     "d_model": 1024,
+            #     "decoder_attention_heads": 16,
+            #     "decoder_layers": 12,
+            #     "encoder_attention_heads": 16,
+            #     "encoder_layers": 12,
+            #     "dropout": 0.1,
+            #     "max_position_embeddings": 4096,
+            #     "num_hidden_layers": 12,
+            #     "num_beams": 3
+            # }
+            self._florence_config_obj = Florence2Config(**config_dict)
        return self._florence_config_obj

    def validate_features(self) -> None:
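Since get_florence_config no longer falls back to hardcoded defaults, callers must now pass both sub-configs explicitly. A sketch of what that could look like, reusing values from the commented-out defaults above (abbreviated; it assumes XVLAConfig can be constructed with florence_config alone, which may not hold if other fields are required):

florence_config = {
    "vision_config": {
        "model_type": "davit",
        "dim_embed": [256, 512, 1024, 2048],
        "num_heads": [8, 16, 32, 64],
        "projection_dim": 1024,
        # ... remaining DaViT fields as in the commented-out block above
    },
    "text_config": {
        "vocab_size": 51289,
        "d_model": 1024,
        "encoder_layers": 12,
        "decoder_layers": 12,
        # ... remaining BART-style fields as in the commented-out block above
    },
}
cfg = XVLAConfig(florence_config=florence_config)
florence = cfg.get_florence_config()  # builds and caches a Florence2Config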
@@ -0,0 +1,54 @@
import numpy as np
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

model = AutoModel.from_pretrained(
    "2toINF/X-VLA-WidowX",
    trust_remote_code=True
)

processor = AutoProcessor.from_pretrained("2toINF/X-VLA-WidowX", trust_remote_code=True)


# Build a list of random PIL images
def make_random_pil_images(num_images=3, H=480, W=640):
    images = []
    for _ in range(num_images):
        # Random RGB image
        arr = np.random.randint(0, 256, (H, W, 3), dtype=np.uint8)
        img = Image.fromarray(arr)
        images.append(img)
    return images


# Example:
images = make_random_pil_images()
language_instruction = "This is a random image"

# Multimodal preprocessing by the processor
inputs = processor(images, language_instruction)
if not {"input_ids", "image_input", "image_mask"}.issubset(inputs):
    raise ValueError("Processor did not return the expected keys.")

proprio = torch.randn(1, 20)
domain_id = torch.tensor([0], dtype=torch.long)

# Align to the model's device/dtype
device = model.device
dtype = next(model.parameters()).dtype


def to_model(t: torch.Tensor) -> torch.Tensor:
    if not isinstance(t, torch.Tensor):
        t = torch.as_tensor(t)
    # cast floats to the model dtype, keep integral/bool tensors as-is
    return t.to(device=device, dtype=dtype) if t.is_floating_point() else t.to(device=device)


inputs = {k: to_model(v) for k, v in inputs.items()}
inputs.update({
    "proprio": to_model(proprio),
    "domain_id": domain_id.to(device),
})

# Inference
action = model.generate_actions(**inputs, steps=10).squeeze(0).float().cpu().numpy()

breakpoint()
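A quick check of the casting rule in to_model above: floating tensors follow the model's device and dtype, while integer tensors only change device (standalone sketch assuming a float32 CPU model):

import torch

device, dtype = torch.device("cpu"), torch.float32

def to_model(t: torch.Tensor) -> torch.Tensor:
    if not isinstance(t, torch.Tensor):
        t = torch.as_tensor(t)
    return t.to(device=device, dtype=dtype) if t.is_floating_point() else t.to(device=device)

assert to_model(torch.randn(2, dtype=torch.float64)).dtype == torch.float32  # float is cast
assert to_model(torch.tensor([1, 2], dtype=torch.long)).dtype == torch.long  # int is preserved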
@@ -0,0 +1,59 @@
from lerobot.policies.factory import make_policy, make_pre_post_processors
# from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.envs.factory import make_env_config
from lerobot.policies.xvla.utils import Rotate6D_to_AxisAngle
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
import torch
import numpy as np

observation_height: int = 360
observation_width: int = 360
# Create an observation dict
OBS = {
    f"{OBS_IMAGES}.image1": torch.randn(1, 3, observation_height, observation_width),
    f"{OBS_IMAGES}.image2": torch.randn(1, 3, observation_height, observation_width),
    OBS_STATE: torch.randn(1, 9),  # only if OBS_STATE is already a string
    "task": "put the object in the box",
}

def fake_rgb(H, W):
    img = torch.randint(0, 255, (H, W, 3), dtype=torch.uint8).numpy()
    return img

OBS[f"{OBS_IMAGES}.image1"] = fake_rgb(observation_height, observation_width)
OBS[f"{OBS_IMAGES}.image2"] = fake_rgb(observation_height, observation_width)

# observation = preprocessor(OBS)
from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained("2toINF/X-VLA-WidowX", num_views=2, trust_remote_code=True)
inputs = processor([OBS[f"{OBS_IMAGES}.image1"], OBS[f"{OBS_IMAGES}.image2"]], OBS["task"])
breakpoint()

cfg = PreTrainedConfig.from_pretrained("/raid/jade/models/xvla-libero-og_migrated")
cfg.pretrained_path = "/raid/jade/models/xvla-libero-og_migrated"
env_cfg = make_env_config("libero", task="libero_spatial")
policy = make_policy(
    cfg=cfg,
    env_cfg=env_cfg,
)

policy.eval()

preprocessor_overrides = {
    "device_processor": {"device": str(cfg.device)},
}

preprocessor, postprocessor = make_pre_post_processors(
    policy_cfg=cfg,
    pretrained_path=cfg.pretrained_path,
    preprocessor_overrides=preprocessor_overrides,
)

observation = preprocessor(OBS)
action = policy.select_action(observation)

target_eef = action[:, :3].to("cpu").numpy()
target_axis = Rotate6D_to_AxisAngle(action[:, 3:9].to("cpu").numpy())
target_act = action[:, 9:10].to("cpu").numpy()
final_action = np.concatenate([target_eef, target_axis, target_act], axis=-1)
breakpoint()
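The slicing at the end decodes XVLA's action layout: 3 position dims, 6 rotation dims, 1 gripper dim. Converting the 6D block to axis-angle collapses the 10 model dims down to the 7 env dims that LIBERO expects (ACTION_DIM = 7). A shape-only sketch, with a zero placeholder where the real script calls Rotate6D_to_AxisAngle:

import numpy as np

action = np.zeros((1, 10))      # per-step model output slice
target_eef = action[:, :3]      # EEF position
target_axis = np.zeros((1, 3))  # stands in for Rotate6D_to_AxisAngle(action[:, 3:9])
target_act = action[:, 9:10]    # gripper
final_action = np.concatenate([target_eef, target_axis, target_act], axis=-1)
assert final_action.shape == (1, 7)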
@@ -94,10 +94,9 @@ class XVLAModel(nn.Module):
        batch_size, num_views = pixel_values.shape[:2]
        flat_mask = image_mask.view(-1).to(dtype=torch.bool)
        flat_images = pixel_values.flatten(0, 1)

        # TODO: jadechoghari: remove this resizing logic, and provide a way in training to do this
-        target_size = (224, 224)
-        flat_images = F.interpolate(flat_images, size=target_size, mode="bilinear", align_corners=False)
+        # target_size = (224, 224)
+        # flat_images = F.interpolate(flat_images, size=target_size, mode="bilinear", align_corners=False)

        num_valid = int(flat_mask.sum().item())
@@ -197,7 +196,6 @@ class XVLAPolicy(PreTrainedPolicy):
    def __init__(self, config: XVLAConfig):
        super().__init__(config)
        config.validate_features()

        florence_config = config.get_florence_config()
        proprio_dim = config.max_state_dim if config.use_proprio else 0
        self.model = XVLAModel(config=config, florence_config=florence_config, proprio_dim=proprio_dim)
@@ -381,7 +379,6 @@ class XVLAPolicy(PreTrainedPolicy):

        model_id = str(pretrained_name_or_path)
        instance = cls(config, **kwargs)

        # --- Step 2: Locate model.safetensors ---
        if os.path.isdir(model_id):
            print("Loading weights from local directory")
@@ -422,9 +419,13 @@ class XVLAPolicy(PreTrainedPolicy):
            k: v for k, v in new_state_dict.items()
            if k not in keys_to_skip
        }

+        # ---- Fix shared embeddings ----
+        encoder_key = "model.vlm.language_model.model.encoder.embed_tokens.weight"
+        shared_key = "model.vlm.language_model.model.shared.weight"
+        if encoder_key in state_dict:
+            state_dict[shared_key] = state_dict[encoder_key]
        # --- Step 5: Load into instance ---
-        missing, unexpected = instance.load_state_dict(new_state_dict, strict=False)
+        missing, unexpected = instance.load_state_dict(state_dict, strict=True)
        print("✅ Loaded XVLA checkpoint with modified keys.")
        if missing:
            print(f"Missing keys: {missing}")
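The fix above works around BART-style weight tying: Florence2's language model keeps a single `shared` embedding matrix that the encoder and decoder embed_tokens reference, so a checkpoint missing the `shared` key can be completed by copying the encoder embedding (a later hunk does the reverse copy). A minimal sketch of the idea on a plain dict, with hypothetical short key names and a tiny 4-token embedding:

import torch

state_dict = {"encoder.embed_tokens.weight": torch.arange(8.0).reshape(4, 2)}

# complete the tied key so that strict loading can succeed
if "shared.weight" not in state_dict:
    state_dict["shared.weight"] = state_dict["encoder.embed_tokens.weight"].clone()

assert torch.equal(state_dict["shared.weight"], state_dict["encoder.embed_tokens.weight"])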
@@ -0,0 +1,48 @@
import numpy as np
import robosuite.utils.transform_utils as T


def Rotate6D_to_AxisAngle(r6d):
    """
    r6d: np.ndarray, shape (N, 6)
    return: np.ndarray, shape (N, 3), axis-angle vectors
    """
    flag = 0
    if len(r6d.shape) == 1:
        r6d = r6d[None, ...]
        flag = 1

    a1 = r6d[:, 0:3]
    a2 = r6d[:, 3:6]

    # b1
    b1 = a1 / (np.linalg.norm(a1, axis=-1, keepdims=True) + 1e-6)

    # b2
    dot_prod = np.sum(b1 * a2, axis=-1, keepdims=True)
    b2_orth = a2 - dot_prod * b1
    b2 = b2_orth / (np.linalg.norm(b2_orth, axis=-1, keepdims=True) + 1e-6)

    # b3
    b3 = np.cross(b1, b2, axis=-1)

    R = np.stack([b1, b2, b3], axis=-1)  # shape: (N, 3, 3)

    axis_angle_list = []
    for i in range(R.shape[0]):
        quat = T.mat2quat(R[i])
        axis_angle = T.quat2axisangle(quat)
        axis_angle_list.append(axis_angle)

    axis_angle_array = np.stack(axis_angle_list, axis=0)  # shape: (N, 3)

    if flag == 1:
        axis_angle_array = axis_angle_array[0]

    return axis_angle_array


def Mat_to_Rotate6D(abs_action):
    """Encode a rotation matrix (or batch of them) as its first two columns (6D)."""
    if len(abs_action.shape) == 2:
        return np.concatenate([abs_action[:3, 0], abs_action[:3, 1]], axis=-1)
    elif len(abs_action.shape) == 3:
        return np.concatenate([abs_action[:, :3, 0], abs_action[:, :3, 1]], axis=-1)
    else:
        raise NotImplementedError
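To see why keeping only the first two columns is lossless, here is a numpy-only round-trip check: Mat_to_Rotate6D drops the third column, and the Gram-Schmidt step in Rotate6D_to_AxisAngle reconstructs the full orthonormal frame (standalone sketch; robosuite is not needed for the matrix part):

import numpy as np

rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.standard_normal((3, 3)))
R = Q * np.sign(np.linalg.det(Q))  # proper rotation, det(R) = +1

r6d = np.concatenate([R[:3, 0], R[:3, 1]], axis=-1)  # Mat_to_Rotate6D, 2-D case

# Gram-Schmidt reconstruction, mirroring Rotate6D_to_AxisAngle
b1 = r6d[0:3] / np.linalg.norm(r6d[0:3])
b2 = r6d[3:6] - (b1 @ r6d[3:6]) * b1
b2 = b2 / np.linalg.norm(b2)
b3 = np.cross(b1, b2)
R_reconstructed = np.stack([b1, b2, b3], axis=-1)

assert np.allclose(R_reconstructed, R, atol=1e-6)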
@@ -325,6 +325,7 @@ def load_state_dict_with_missing_key_handling(
    Returns:
        List of problematic missing keys that weren't in the whitelist.
    """
+    state_dict["model.vlm.language_model.model.encoder.embed_tokens.weight"] = state_dict["model.vlm.language_model.model.shared.weight"].clone()
    # Load the cleaned state dict with strict=False to capture missing/unexpected keys
    load_result = policy.load_state_dict(state_dict, strict=False)
@@ -321,7 +321,6 @@ class _NormalizationMixin:
        self.to(device=tensor.device, dtype=tensor.dtype)

        stats = self._tensor_stats[key]

        if norm_mode == NormalizationMode.MEAN_STD:
            mean = stats.get("mean", None)
            std = stats.get("std", None)
@@ -89,7 +89,7 @@ from lerobot.utils.utils import (
    init_logging,
    inside_slurm,
)

+from lerobot.policies.xvla.utils import Rotate6D_to_AxisAngle

def rollout(
    env: gym.vector.VectorEnv,
@@ -155,6 +155,17 @@ def rollout(
        disable=inside_slurm(),  # we don't want a progress bar under slurm, since it clutters the logs
        leave=False,
    )
+    from transformers import AutoModel, AutoProcessor
+    model = AutoModel.from_pretrained(
+        "2toINF/X-VLA-WidowX",
+        trust_remote_code=True,
+    )
+    model.to("cuda")
+    processor = AutoProcessor.from_pretrained("2toINF/X-VLA-WidowX", num_views=2, trust_remote_code=True)
+
+    from collections import deque
+    action_queue = deque(maxlen=30)
    check_env_attributes_and_types(env)
    while not np.all(done) and step < max_steps:
        # Numpy array to tensor and changing dictionary keys to LeRobot policy format.
@@ -165,13 +176,54 @@ def rollout(
        # Infer "task" from attributes of environments.
        # TODO: works with SyncVectorEnv but not AsyncVectorEnv
        observation = add_envs_task(env, observation)
+        inputs = processor(
+            [observation["observation.images.image"], observation["observation.images.image2"]],
+            observation["task"],
+            do_rescale=False,
+        )
        observation = preprocessor(observation)
+        observation["observation.images.image"] = inputs["image_input"][:, 0, ...].to("cuda")
+        observation["observation.images.image2"] = inputs["image_input"][:, 1, ...].to("cuda")
+        observation["observation.language.tokens"] = inputs["input_ids"].to("cuda")

+        # (Pdb) inputs.keys()
+        # dict_keys(['input_ids', 'image_input', 'image_mask', 'proprio', 'domain_id'])
+        # image_input should be torch.Size([1, 2, 3, 224, 224])
+        img0 = observation["observation.images.image"]  # [1, 3, 224, 224]
+        img1 = observation["observation.images.image2"]  # [1, 3, 224, 224]
+        img0 = img0.unsqueeze(1)  # [1, 1, 3, 224, 224]
+        img1 = img1.unsqueeze(1)  # [1, 1, 3, 224, 224]
+        obs = {}
+        obs["input_ids"] = observation["observation.language.tokens"].to("cuda")
+        obs["image_input"] = torch.cat([img0, img1], dim=1).to("cuda")
+        obs["domain_id"] = torch.tensor([3], dtype=torch.long).to("cuda")
+        obs["proprio"] = observation["observation.state"].to("cuda")
+        obs["image_mask"] = inputs["image_mask"].to("cuda")

        with torch.inference_mode():
-            action = policy.select_action(observation)
-            action = postprocessor(action)
+            action_1 = policy.select_action(observation).to("cpu").numpy()
+            if len(action_queue) == 0:
+                action = model.generate_actions(**obs, steps=10)  # shape (1, 30, 20)
+                actions_np = action.detach().cpu().numpy()
+                # add each timestep as (1, 20)
+                for t in range(actions_np.shape[1]):
+                    act_t = actions_np[:, t, :]
+                    action_queue.append(act_t)
+                action = action_queue.popleft()
+            else:
+                action = action_queue.popleft()
+            # action = postprocessor(action)
+            # breakpoint()
+            target_eef = action[:, :3]
+            target_axis = Rotate6D_to_AxisAngle(action[:, 3:9])
+            target_act = action[:, 9:10]
+            action_numpy = np.concatenate([target_eef, target_axis, target_act], axis=-1)

+            target_eef_1 = action_1[:, :3]
+            target_axis_1 = Rotate6D_to_AxisAngle(action_1[:, 3:9])
+            target_act_1 = action_1[:, 9:10]
+            action_numpy_1 = np.concatenate([target_eef_1, target_axis_1, target_act_1], axis=-1)
+            breakpoint()

        # Convert to CPU / numpy.
-        action_numpy: np.ndarray = action.to("cpu").numpy()
+        # action_numpy: np.ndarray = action.to("cpu").numpy()
        assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)"

        # Apply the next action.
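The deque logic above implements chunked, receding-horizon execution: the model is queried only when the queue is empty, and the predicted 30-step chunk is then consumed open-loop one timestep at a time. A standalone sketch of the pattern (fake_generate is a hypothetical stand-in for model.generate_actions):

from collections import deque

import numpy as np

H = 30
action_queue: deque = deque(maxlen=H)

def fake_generate() -> np.ndarray:
    return np.zeros((1, H, 20))  # (batch, horizon, action_dim)

def next_action() -> np.ndarray:
    if len(action_queue) == 0:
        chunk = fake_generate()
        for t in range(chunk.shape[1]):
            action_queue.append(chunk[:, t, :])  # each step is (1, 20)
    return action_queue.popleft()

for _ in range(3):
    assert next_action().shape == (1, 20)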
@@ -497,7 +549,6 @@ def eval_main(cfg: EvalPipelineConfig):
    envs = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)

    logging.info("Making policy.")

    policy = make_policy(
        cfg=cfg.policy,
        env_cfg=cfg.env,