more fixes

This commit is contained in:
Jade Choghari
2025-11-17 14:03:15 +01:00
parent fb6f59e074
commit 5277a9909d
16 changed files with 215 additions and 176 deletions
-1
View File
@@ -6,4 +6,3 @@ lerobot-eval \
--eval.batch_size=1 \ --eval.batch_size=1 \
--eval.n_episodes=1 \ --eval.n_episodes=1 \
--seed=142 --seed=142
+12 -9
View File
@@ -1,15 +1,13 @@
from transformers import AutoModel, AutoProcessor
import json_numpy
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from transformers import AutoModel, AutoProcessor
model = AutoModel.from_pretrained( model = AutoModel.from_pretrained("2toINF/X-VLA-WidowX", trust_remote_code=True)
"2toINF/X-VLA-WidowX",
trust_remote_code=True
)
processor = AutoProcessor.from_pretrained("2toINF/X-VLA-WidowX", trust_remote_code=True) processor = AutoProcessor.from_pretrained("2toINF/X-VLA-WidowX", trust_remote_code=True)
# append 3 random image to a list # append 3 random image to a list
def make_random_pil_images(num_images=3, H=480, W=640): def make_random_pil_images(num_images=3, H=480, W=640):
images = [] images = []
@@ -20,6 +18,7 @@ def make_random_pil_images(num_images=3, H=480, W=640):
images.append(img) images.append(img)
return images return images
# Example: # Example:
images = make_random_pil_images() images = make_random_pil_images()
language_instruction = "This is a random image" language_instruction = "This is a random image"
@@ -29,23 +28,27 @@ if not {"input_ids", "image_input", "image_mask"}.issubset(inputs):
raise ValueError("Processor did not return the expected keys.") raise ValueError("Processor did not return the expected keys.")
proprio = torch.randn(1, 20) proprio = torch.randn(1, 20)
domain_id = torch.tensor([int(0)], dtype=torch.long) domain_id = torch.tensor([0], dtype=torch.long)
# Align to model's device/dtype # Align to model's device/dtype
device = model.device device = model.device
dtype = next(model.parameters()).dtype dtype = next(model.parameters()).dtype
def to_model(t: torch.Tensor) -> torch.Tensor: def to_model(t: torch.Tensor) -> torch.Tensor:
if not isinstance(t, torch.Tensor): if not isinstance(t, torch.Tensor):
t = torch.as_tensor(t) t = torch.as_tensor(t)
# cast floats to model dtype, keep integral/bool as-is # cast floats to model dtype, keep integral/bool as-is
return t.to(device=device, dtype=dtype) if t.is_floating_point() else t.to(device=device) return t.to(device=device, dtype=dtype) if t.is_floating_point() else t.to(device=device)
inputs = {k: to_model(v) for k, v in inputs.items()} inputs = {k: to_model(v) for k, v in inputs.items()}
inputs.update({ inputs.update(
{
"proprio": to_model(proprio), "proprio": to_model(proprio),
"domain_id": domain_id.to(device), "domain_id": domain_id.to(device),
}) }
)
# Inference # Inference
+10 -5
View File
@@ -1,11 +1,12 @@
from lerobot.policies.factory import make_policy, make_pre_post_processors import numpy as np
import torch
# from lerobot.policies.xvla.configuration_xvla import XVLAConfig # from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
from lerobot.envs.factory import make_env_config from lerobot.envs.factory import make_env_config
from lerobot.policies.xvla.utils import Rotate6D_to_AxisAngle from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.xvla.utils import rotate6d_to_axis_angle
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
import torch
import numpy as np
observation_height: int = 360 observation_height: int = 360
observation_width: int = 360 observation_width: int = 360
@@ -16,15 +17,19 @@ OBS = {
OBS_STATE: torch.randn(1, 9), # ONLY if OBS_STATE is already a string OBS_STATE: torch.randn(1, 9), # ONLY if OBS_STATE is already a string
"task": "put the object in the box", "task": "put the object in the box",
} }
def fake_rgb(H, W): def fake_rgb(H, W):
img = torch.randint(0, 255, (H, W, 3), dtype=torch.uint8).numpy() img = torch.randint(0, 255, (H, W, 3), dtype=torch.uint8).numpy()
return img return img
OBS[f"{OBS_IMAGES}.image1"] = fake_rgb(observation_height, observation_width) OBS[f"{OBS_IMAGES}.image1"] = fake_rgb(observation_height, observation_width)
OBS[f"{OBS_IMAGES}.image2"] = fake_rgb(observation_height, observation_width) OBS[f"{OBS_IMAGES}.image2"] = fake_rgb(observation_height, observation_width)
# observation = preprocessor(OBS) # observation = preprocessor(OBS)
from transformers import AutoProcessor from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained("2toINF/X-VLA-WidowX", num_views=2, trust_remote_code=True) processor = AutoProcessor.from_pretrained("2toINF/X-VLA-WidowX", num_views=2, trust_remote_code=True)
inputs = processor([OBS[f"{OBS_IMAGES}.image1"], OBS[f"{OBS_IMAGES}.image2"]], OBS["task"]) inputs = processor([OBS[f"{OBS_IMAGES}.image1"], OBS[f"{OBS_IMAGES}.image2"]], OBS["task"])
breakpoint() breakpoint()
@@ -53,6 +58,6 @@ observation = preprocessor(OBS)
action = policy.select_action(observation) action = policy.select_action(observation)
target_eef = action[:, :3].to("cpu").numpy() target_eef = action[:, :3].to("cpu").numpy()
target_axis = Rotate6D_to_AxisAngle(action[:, 3:9].to("cpu").numpy()) target_axis = rotate6d_to_axis_angle(action[:, 3:9].to("cpu").numpy())
target_act = action[:, 9:10].to("cpu").numpy() target_act = action[:, 9:10].to("cpu").numpy()
final_action = np.concatenate([target_eef, target_axis, target_act], axis=-1) final_action = np.concatenate([target_eef, target_axis, target_act], axis=-1)
+37 -18
View File
@@ -1,6 +1,8 @@
import os
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.factory import make_policy, make_policy_config from lerobot.policies.factory import make_policy, make_policy_config
import os
cfg = make_policy_config("xvla") cfg = make_policy_config("xvla")
dataset_id = "lerobot/svla_so101_pickplace" dataset_id = "lerobot/svla_so101_pickplace"
@@ -16,7 +18,9 @@ for name, param in policy.state_dict().items():
import safetensors.torch import safetensors.torch
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
cache_dir = snapshot_download(repo_id="2toINF/X-VLA-Libero", repo_type="model", cache_dir="/fsx/jade_choghari/.cache/huggingface/model") cache_dir = snapshot_download(
repo_id="2toINF/X-VLA-Libero", repo_type="model", cache_dir="/fsx/jade_choghari/.cache/huggingface/model"
)
state_dict = safetensors.torch.load_file(os.path.join(cache_dir, "model.safetensors")) state_dict = safetensors.torch.load_file(os.path.join(cache_dir, "model.safetensors"))
# policy.load_state_dict(state_dict) # policy.load_state_dict(state_dict)
# 3. Add "model." prefix to every key # 3. Add "model." prefix to every key
@@ -36,16 +40,18 @@ print()
print("unexpected keys:", unexpected) print("unexpected keys:", unexpected)
import random
import numpy as np
import torch
from xvla.models.modeling_xvla import XVLA
from lerobot.policies.factory import make_policy, make_pre_post_processors
# from lerobot.policies.xvla.configuration_xvla import XVLAConfig # from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
from lerobot.envs.factory import make_env_config from lerobot.envs.factory import make_env_config
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
from xvla.models.modeling_xvla import XVLA
import torch
import numpy as np
import random
torch.manual_seed(42) torch.manual_seed(42)
random.seed(42) random.seed(42)
np.random.seed(42) np.random.seed(42)
@@ -61,6 +67,8 @@ OBS = {
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
def fake_rgb(H, W): def fake_rgb(H, W):
arr = np.random.randint(0, 255, (H, W, 3), dtype=np.uint8) arr = np.random.randint(0, 255, (H, W, 3), dtype=np.uint8)
t = torch.from_numpy(arr).permute(2, 0, 1) # CHW t = torch.from_numpy(arr).permute(2, 0, 1) # CHW
@@ -101,11 +109,13 @@ from xvla.models.processing_xvla import XVLAProcessor
processor = XVLAProcessor.from_pretrained("/raid/jade/models/xvla-libero", num_views=2) processor = XVLAProcessor.from_pretrained("/raid/jade/models/xvla-libero", num_views=2)
inputs_1 = processor([OBS[f"{OBS_IMAGES}.image"], OBS[f"{OBS_IMAGES}.image2"]], OBS["task"]) inputs_1 = processor([OBS[f"{OBS_IMAGES}.image"], OBS[f"{OBS_IMAGES}.image2"]], OBS["task"])
domain_id = torch.tensor([int(3)], dtype=torch.long) domain_id = torch.tensor([3], dtype=torch.long)
inputs.update({ inputs.update(
{
"proprio": OBS[OBS_STATE].to("cuda"), "proprio": OBS[OBS_STATE].to("cuda"),
"domain_id": domain_id.to("cuda"), "domain_id": domain_id.to("cuda"),
}) }
)
for k in inputs.keys() & inputs_1.keys(): # intersection of keys for k in inputs.keys() & inputs_1.keys(): # intersection of keys
@@ -145,16 +155,19 @@ print("max diff:", np.max(np.abs(action - action_1)))
print("mean diff:", np.mean(np.abs(action - action_1))) print("mean diff:", np.mean(np.abs(action - action_1)))
from xvla.models.processor_xvla import XVLAProcessor
from xvla.models.modeling_xvla import XVLA
from xvla.models.configuration_xvla import XVLAConfig
import torch
import random import random
import numpy as np import numpy as np
import torch
from PIL import Image from PIL import Image
from lerobot.policies.factory import make_policy from xvla.models.configuration_xvla import XVLAConfig
from xvla.models.modeling_xvla import XVLA
from xvla.models.processor_xvla import XVLAProcessor
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
from lerobot.envs.factory import make_env_config from lerobot.envs.factory import make_env_config
from lerobot.policies.factory import make_policy
cfg = XVLAConfig.from_pretrained("/raid/jade/models/xvla-libero") cfg = XVLAConfig.from_pretrained("/raid/jade/models/xvla-libero")
model = XVLA.from_pretrained("/raid/jade/models/xvla-libero") model = XVLA.from_pretrained("/raid/jade/models/xvla-libero")
model.eval() model.eval()
@@ -166,6 +179,7 @@ torch.manual_seed(42)
random.seed(42) random.seed(42)
np.random.seed(42) np.random.seed(42)
def make_random_pil_images(num_images=3, H=480, W=640): def make_random_pil_images(num_images=3, H=480, W=640):
images = [] images = []
for _ in range(num_images): for _ in range(num_images):
@@ -175,6 +189,7 @@ def make_random_pil_images(num_images=3, H=480, W=640):
images.append(img) images.append(img)
return images return images
# Example: # Example:
images = make_random_pil_images() images = make_random_pil_images()
language_instruction = "This is a random image" language_instruction = "This is a random image"
@@ -184,23 +199,27 @@ if not {"input_ids", "image_input", "image_mask"}.issubset(inputs):
raise ValueError("Processor did not return the expected keys.") raise ValueError("Processor did not return the expected keys.")
proprio = torch.randn(1, 20) proprio = torch.randn(1, 20)
domain_id = torch.tensor([int(0)], dtype=torch.long) domain_id = torch.tensor([0], dtype=torch.long)
# Align to model's device/dtype # Align to model's device/dtype
device = model.device device = model.device
dtype = next(model.parameters()).dtype dtype = next(model.parameters()).dtype
def to_model(t: torch.Tensor) -> torch.Tensor: def to_model(t: torch.Tensor) -> torch.Tensor:
if not isinstance(t, torch.Tensor): if not isinstance(t, torch.Tensor):
t = torch.as_tensor(t) t = torch.as_tensor(t)
# cast floats to model dtype, keep integral/bool as-is # cast floats to model dtype, keep integral/bool as-is
return t.to(device=device, dtype=dtype) if t.is_floating_point() else t.to(device=device) return t.to(device=device, dtype=dtype) if t.is_floating_point() else t.to(device=device)
inputs = {k: to_model(v) for k, v in inputs.items()} inputs = {k: to_model(v) for k, v in inputs.items()}
inputs.update({ inputs.update(
{
"proprio": to_model(proprio), "proprio": to_model(proprio),
"domain_id": domain_id.to(device), "domain_id": domain_id.to(device),
}) }
)
# Inference # Inference
action = model.generate_actions(**inputs, steps=10).squeeze(0).float().cpu().numpy() action = model.generate_actions(**inputs, steps=10).squeeze(0).float().cpu().numpy()
+8 -4
View File
@@ -29,7 +29,9 @@ from gymnasium import spaces
from libero.libero import benchmark, get_libero_path from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.transform_utils import quat2axisangle from robosuite.utils.transform_utils import quat2axisangle
from lerobot.policies.xvla.utils import Mat_to_Rotate6D
from lerobot.policies.xvla.utils import mat_to_rotate6d
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]: def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
"""Normalize camera_name into a non-empty list of strings.""" """Normalize camera_name into a non-empty list of strings."""
@@ -227,13 +229,15 @@ class LiberoEnv(gym.Env):
# TODO: jadechoghari, this is an ugly quick workaround for XVLA states. # TODO: jadechoghari, this is an ugly quick workaround for XVLA states.
# we will open a new PR to handle this in a preprocessor. # we will open a new PR to handle this in a preprocessor.
elif self.action_type == "abs": elif self.action_type == "abs":
robo_ori = Mat_to_Rotate6D(self._env.robots[0].controller.ee_ori_mat) robo_ori = mat_to_rotate6d(self._env.robots[0].controller.ee_ori_mat)
robo_pos = self._env.robots[0].controller.ee_pos robo_pos = self._env.robots[0].controller.ee_pos
proprio = np.concatenate([robo_pos, robo_ori, np.array([0.0])], axis=-1) proprio = np.concatenate([robo_pos, robo_ori, np.array([0.0])], axis=-1)
state = np.concatenate([proprio, np.zeros_like(proprio)], axis=-1) state = np.concatenate([proprio, np.zeros_like(proprio)], axis=-1)
else: else:
raise NotImplementedError(f"The action type '{self.action_type}' is not supported in LiberoEnv. " raise NotImplementedError(
"Please switch to an action type (e.g. 'rel', 'abs').") f"The action type '{self.action_type}' is not supported in LiberoEnv. "
"Please switch to an action type (e.g. 'rel', 'abs')."
)
agent_pos = state agent_pos = state
if self.obs_type == "pixels": if self.obs_type == "pixels":
return {"pixels": images.copy()} return {"pixels": images.copy()}
+2 -2
View File
@@ -1,6 +1,6 @@
from lerobot.policies.xvla.processor_xvla import ( from lerobot.policies.xvla.processor_xvla import (
make_xvla_pre_post_processors,
XVLAImageScaleProcessorStep,
XVLAAddDomainIdProcessorStep, XVLAAddDomainIdProcessorStep,
XVLAImageScaleProcessorStep,
XVLARotation6DToAxisAngleProcessorStep, XVLARotation6DToAxisAngleProcessorStep,
make_xvla_pre_post_processors,
) )
@@ -187,7 +187,7 @@ class Florence2LanguageConfig(PretrainedConfig):
>>> configuration = Florence2LanguageConfig() >>> configuration = Florence2LanguageConfig()
>>> # Initializing a model (with random weights) >>> # Initializing a model (with random weights)
>>> model = Florence2LangaugeModel(configuration) >>> model = Florence2LanguageModel(configuration)
>>> # Accessing the model configuration >>> # Accessing the model configuration
>>> configuration = model.config >>> configuration = model.config
@@ -28,8 +28,6 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import OBS_IMAGES from lerobot.utils.constants import OBS_IMAGES
from .configuration_florence2 import Florence2Config from .configuration_florence2 import Florence2Config
from .configuration_florence2 import Florence2VisionConfig
from .configuration_florence2 import Florence2LanguageConfig
@PreTrainedConfig.register_subclass("xvla") @PreTrainedConfig.register_subclass("xvla")
@@ -496,8 +496,8 @@ class DaViT(nn.Module):
drop_path_rate (float): Stochastic depth rate. Default: 0.1. drop_path_rate (float): Stochastic depth rate. Default: 0.1.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
enable_checkpoint (bool): If True, enable checkpointing. Default: False. enable_checkpoint (bool): If True, enable checkpointing. Default: False.
conv_at_attn (bool): If True, performe depthwise convolution before attention layer. Default: True. conv_at_attn (bool): If True, perform depthwise convolution before attention layer. Default: True.
conv_at_ffn (bool): If True, performe depthwise convolution before ffn layer. Default: True. conv_at_ffn (bool): If True, perform depthwise convolution before ffn layer. Default: True.
""" """
def __init__( def __init__(
@@ -892,7 +892,7 @@ class Florence2FlashAttention2(Florence2Attention):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+7 -6
View File
@@ -18,6 +18,7 @@
from __future__ import annotations from __future__ import annotations
import os
from collections import deque from collections import deque
import torch import torch
@@ -34,7 +35,6 @@ from .configuration_xvla import XVLAConfig
from .modeling_florence2 import Florence2ForConditionalGeneration from .modeling_florence2 import Florence2ForConditionalGeneration
from .transformer import SoftPromptedTransformer from .transformer import SoftPromptedTransformer
import os
class XVLAModel(nn.Module): class XVLAModel(nn.Module):
""" """
@@ -98,7 +98,6 @@ class XVLAModel(nn.Module):
# target_size = (224, 224) # target_size = (224, 224)
# flat_images = F.interpolate(flat_images, size=target_size, mode="bilinear", align_corners=False) # flat_images = F.interpolate(flat_images, size=target_size, mode="bilinear", align_corners=False)
num_valid = int(flat_mask.sum().item()) num_valid = int(flat_mask.sum().item())
if num_valid == 0: if num_valid == 0:
raise ValueError("At least one image view must be valid per batch.") raise ValueError("At least one image view must be valid per batch.")
@@ -347,7 +346,7 @@ class XVLAPolicy(PreTrainedPolicy):
cls, cls,
pretrained_name_or_path: str | Path, pretrained_name_or_path: str | Path,
*, *,
config: "PreTrainedConfig" | None = None, config: PreTrainedConfig | None = None,
force_download: bool = False, force_download: bool = False,
resume_download: bool | None = None, resume_download: bool | None = None,
proxies: dict | None = None, proxies: dict | None = None,
@@ -364,6 +363,7 @@ class XVLAPolicy(PreTrainedPolicy):
- skip list for layers that should remain randomly initialized - skip list for layers that should remain randomly initialized
""" """
import safetensors.torch import safetensors.torch
# --- Step 1: Load config --- # --- Step 1: Load config ---
if config is None: if config is None:
config = PreTrainedConfig.from_pretrained( config = PreTrainedConfig.from_pretrained(
@@ -386,6 +386,9 @@ class XVLAPolicy(PreTrainedPolicy):
model_file = os.path.join(model_id, "model.safetensors") model_file = os.path.join(model_id, "model.safetensors")
else: else:
try: try:
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError
model_file = hf_hub_download( model_file = hf_hub_download(
repo_id=model_id, repo_id=model_id,
filename="model.safetensors", filename="model.safetensors",
@@ -398,9 +401,7 @@ class XVLAPolicy(PreTrainedPolicy):
local_files_only=local_files_only, local_files_only=local_files_only,
) )
except HfHubHTTPError as e: except HfHubHTTPError as e:
raise FileNotFoundError( raise FileNotFoundError(f"model.safetensors not found on the Hub at {model_id}") from e
f"model.safetensors not found on the Hub at {model_id}"
) from e
print(f"Loading checkpoint from {model_file}") print(f"Loading checkpoint from {model_file}")
state_dict = safetensors.torch.load_file(model_file) state_dict = safetensors.torch.load_file(model_file)
+3 -2
View File
@@ -21,7 +21,7 @@ import numpy as np
import torch import torch
from lerobot.policies.xvla.configuration_xvla import XVLAConfig from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.policies.xvla.utils import Rotate6D_to_AxisAngle from lerobot.policies.xvla.utils import rotate6d_to_axis_angle
from lerobot.processor import ( from lerobot.processor import (
AddBatchDimensionProcessorStep, AddBatchDimensionProcessorStep,
DeviceProcessorStep, DeviceProcessorStep,
@@ -38,6 +38,7 @@ from lerobot.processor.converters import policy_action_to_transition, transition
from lerobot.processor.core import EnvTransition, TransitionKey from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
def make_xvla_pre_post_processors( def make_xvla_pre_post_processors(
config: XVLAConfig, config: XVLAConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
@@ -231,7 +232,7 @@ class XVLARotation6DToAxisAngleProcessorStep(ProcessorStep):
target_act = action_np[:, 9:10] # (B, 1) target_act = action_np[:, 9:10] # (B, 1)
# Convert 6D rotation to axis-angle # Convert 6D rotation to axis-angle
target_axis = Rotate6D_to_AxisAngle(rotation_6d) # (B, 3) target_axis = rotate6d_to_axis_angle(rotation_6d) # (B, 3)
# Concatenate: [eef (3), axis_angle (3), gripper (1)] = 7D # Concatenate: [eef (3), axis_angle (3), gripper (1)] = 7D
action_np = np.concatenate([target_eef, target_axis, target_act], axis=-1) action_np = np.concatenate([target_eef, target_axis, target_act], axis=-1)
+5 -2
View File
@@ -1,6 +1,8 @@
import numpy as np import numpy as np
import robosuite.utils.transform_utils as T import robosuite.utils.transform_utils as T
def Rotate6D_to_AxisAngle(r6d):
def rotate6d_to_axis_angle(r6d):
""" """
r6d: np.ndarray, shape (N, 6) r6d: np.ndarray, shape (N, 6)
return: np.ndarray, shape (N, 3), axis-angle vectors return: np.ndarray, shape (N, 3), axis-angle vectors
@@ -39,7 +41,8 @@ def Rotate6D_to_AxisAngle(r6d):
return axis_angle_array return axis_angle_array
def Mat_to_Rotate6D(abs_action):
def mat_to_rotate6d(abs_action):
if len(abs_action.shape) == 2: if len(abs_action.shape) == 2:
return np.concatenate([abs_action[:3, 0], abs_action[:3, 1]], axis=-1) return np.concatenate([abs_action[:3, 0], abs_action[:3, 1]], axis=-1)
elif len(abs_action.shape) == 3: elif len(abs_action.shape) == 3:
@@ -325,7 +325,9 @@ def load_state_dict_with_missing_key_handling(
Returns: Returns:
List of problematic missing keys that weren't in the whitelist. List of problematic missing keys that weren't in the whitelist.
""" """
state_dict['model.vlm.language_model.model.encoder.embed_tokens.weight'] = state_dict['model.vlm.language_model.model.shared.weight'].clone() state_dict["model.vlm.language_model.model.encoder.embed_tokens.weight"] = state_dict[
"model.vlm.language_model.model.shared.weight"
].clone()
# Load the cleaned state dict with strict=False to capture missing/unexpected keys # Load the cleaned state dict with strict=False to capture missing/unexpected keys
load_result = policy.load_state_dict(state_dict, strict=False) load_result = policy.load_state_dict(state_dict, strict=False)
+7 -2
View File
@@ -25,9 +25,10 @@ import torch
from torch import Tensor from torch import Tensor
from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature
from lerobot.datasets.factory import IMAGENET_STATS
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import ACTION from lerobot.utils.constants import ACTION
from lerobot.datasets.factory import IMAGENET_STATS
from .converters import from_tensor_to_numpy, to_tensor from .converters import from_tensor_to_numpy, to_tensor
from .core import EnvTransition, PolicyAction, TransitionKey from .core import EnvTransition, PolicyAction, TransitionKey
from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry
@@ -303,7 +304,11 @@ class _NormalizationMixin:
ValueError: If an unsupported normalization mode is encountered. ValueError: If an unsupported normalization mode is encountered.
""" """
norm_mode = self.norm_map.get(feature_type, NormalizationMode.IDENTITY) norm_mode = self.norm_map.get(feature_type, NormalizationMode.IDENTITY)
if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats and norm_mode != NormalizationMode.IMAGENET: if (
norm_mode == NormalizationMode.IDENTITY
or key not in self._tensor_stats
and norm_mode != NormalizationMode.IMAGENET
):
return tensor return tensor
if norm_mode not in ( if norm_mode not in (
NormalizationMode.MEAN_STD, NormalizationMode.MEAN_STD,
+1
View File
@@ -55,6 +55,7 @@ from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, Transitio
TInput = TypeVar("TInput") TInput = TypeVar("TInput")
TOutput = TypeVar("TOutput") TOutput = TypeVar("TOutput")
class ProcessorStepRegistry: class ProcessorStepRegistry:
"""A registry for ProcessorStep classes to allow instantiation from a string name. """A registry for ProcessorStep classes to allow instantiation from a string name.
+6 -8
View File
@@ -45,6 +45,7 @@ Note that in both examples, the repo/folder should contain at least `config.json
You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py
""" """
import concurrent.futures as cf import concurrent.futures as cf
import json import json
import logging import logging
@@ -89,6 +90,7 @@ from lerobot.utils.utils import (
inside_slurm, inside_slurm,
) )
def rollout( def rollout(
env: gym.vector.VectorEnv, env: gym.vector.VectorEnv,
policy: PreTrainedPolicy, policy: PreTrainedPolicy,
@@ -153,7 +155,6 @@ def rollout(
disable=inside_slurm(), # we dont want progress bar when we use slurm, since it clutters the logs disable=inside_slurm(), # we dont want progress bar when we use slurm, since it clutters the logs
leave=False, leave=False,
) )
check_env_attributes_and_types(env) check_env_attributes_and_types(env)
while not np.all(done) and step < max_steps: while not np.all(done) and step < max_steps:
# Numpy array to tensor and changing dictionary keys to LeRobot policy format. # Numpy array to tensor and changing dictionary keys to LeRobot policy format.
@@ -164,17 +165,13 @@ def rollout(
# Infer "task" from attributes of environments. # Infer "task" from attributes of environments.
# TODO: works with SyncVectorEnv but not AsyncVectorEnv # TODO: works with SyncVectorEnv but not AsyncVectorEnv
observation = add_envs_task(env, observation) observation = add_envs_task(env, observation)
# Preprocess observation (includes image scaling and domain_id addition)
observation = preprocessor(observation) observation = preprocessor(observation)
# Policy inference
with torch.inference_mode(): with torch.inference_mode():
action = policy.select_action(observation) action = policy.select_action(observation)
# Postprocess action (includes rotation conversion and device transfer to CPU)
action = postprocessor(action) action = postprocessor(action)
# Convert to numpy
action_numpy: np.ndarray = action.numpy() # Convert to CPU / numpy.
action_numpy: np.ndarray = action.to("cpu").numpy()
assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)" assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)"
# Apply the next action. # Apply the next action.
@@ -500,6 +497,7 @@ def eval_main(cfg: EvalPipelineConfig):
envs = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) envs = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
logging.info("Making policy.") logging.info("Making policy.")
policy = make_policy( policy = make_policy(
cfg=cfg.policy, cfg=cfg.policy,
env_cfg=cfg.env, env_cfg=cfg.env,