mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
more changes
This commit is contained in:
@@ -1,8 +1,9 @@
|
|||||||
lerobot-eval \
|
lerobot-eval \
|
||||||
--policy.path="/raid/jade/models/xvla-libero-new_migrated" \
|
--policy.path="/raid/jade/models/xvla-lib" \
|
||||||
--env.type=libero \
|
--env.type=libero \
|
||||||
--env.task=libero_spatial \
|
--env.task=libero_spatial \
|
||||||
--env.action_type=abs \
|
--env.control_mode=absolute \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--seed=142
|
--seed=142
|
||||||
|
|
||||||
|
|||||||
@@ -272,7 +272,7 @@ class LiberoEnv(EnvConfig):
|
|||||||
LIBERO_KEY_PIXELS_EYE_IN_HAND: f"{OBS_IMAGES}.image2",
|
LIBERO_KEY_PIXELS_EYE_IN_HAND: f"{OBS_IMAGES}.image2",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
action_type: str = "rel"
|
control_mode: str = "relative" # or "absolute"
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.obs_type == "pixels":
|
if self.obs_type == "pixels":
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_m
|
|||||||
from lerobot.processor import ProcessorStep
|
from lerobot.processor import ProcessorStep
|
||||||
from lerobot.processor.env_processor import LiberoProcessorStep
|
from lerobot.processor.env_processor import LiberoProcessorStep
|
||||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||||
|
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
|
||||||
|
|
||||||
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||||
@@ -39,6 +41,7 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
|||||||
|
|
||||||
def make_env_pre_post_processors(
|
def make_env_pre_post_processors(
|
||||||
env_cfg: EnvConfig,
|
env_cfg: EnvConfig,
|
||||||
|
policy_cfg: PreTrainedConfig,
|
||||||
) -> tuple[
|
) -> tuple[
|
||||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||||
@@ -61,10 +64,14 @@ def make_env_pre_post_processors(
|
|||||||
# Preprocessor and Postprocessor steps are Identity for most environments
|
# Preprocessor and Postprocessor steps are Identity for most environments
|
||||||
preprocessor_steps: list[ProcessorStep] = []
|
preprocessor_steps: list[ProcessorStep] = []
|
||||||
postprocessor_steps: list[ProcessorStep] = []
|
postprocessor_steps: list[ProcessorStep] = []
|
||||||
|
if isinstance(policy_cfg, XVLAConfig):
|
||||||
|
from lerobot.policies.xvla.processor_xvla import make_xvla_libero_pre_post_processors
|
||||||
|
return make_xvla_libero_pre_post_processors()
|
||||||
|
|
||||||
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
|
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
|
||||||
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
|
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
|
||||||
preprocessor_steps.append(LiberoProcessorStep())
|
preprocessor_steps.append(LiberoProcessorStep())
|
||||||
|
|
||||||
|
|
||||||
preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps)
|
preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps)
|
||||||
postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps)
|
postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps)
|
||||||
@@ -136,7 +143,7 @@ def make_env(
|
|||||||
init_states=cfg.init_states,
|
init_states=cfg.init_states,
|
||||||
gym_kwargs=cfg.gym_kwargs,
|
gym_kwargs=cfg.gym_kwargs,
|
||||||
env_cls=env_cls,
|
env_cls=env_cls,
|
||||||
action_type=cfg.action_type,
|
control_mode=cfg.control_mode,
|
||||||
)
|
)
|
||||||
elif "metaworld" in cfg.type:
|
elif "metaworld" in cfg.type:
|
||||||
from lerobot.envs.metaworld import create_metaworld_envs
|
from lerobot.envs.metaworld import create_metaworld_envs
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ AGENT_POS_HIGH = 1000.0
|
|||||||
ACTION_LOW = -1.0
|
ACTION_LOW = -1.0
|
||||||
ACTION_HIGH = 1.0
|
ACTION_HIGH = 1.0
|
||||||
TASK_SUITE_MAX_STEPS: dict[str, int] = {
|
TASK_SUITE_MAX_STEPS: dict[str, int] = {
|
||||||
"libero_spatial": 280, # longest training demo has 193 steps
|
"libero_spatial": 800, # longest training demo has 193 steps
|
||||||
"libero_object": 280, # longest training demo has 254 steps
|
"libero_object": 280, # longest training demo has 254 steps
|
||||||
"libero_goal": 300, # longest training demo has 270 steps
|
"libero_goal": 300, # longest training demo has 270 steps
|
||||||
"libero_10": 520, # longest training demo has 505 steps
|
"libero_10": 520, # longest training demo has 505 steps
|
||||||
@@ -114,6 +114,7 @@ class LiberoEnv(gym.Env):
|
|||||||
episode_index: int = 0,
|
episode_index: int = 0,
|
||||||
camera_name_mapping: dict[str, str] | None = None,
|
camera_name_mapping: dict[str, str] | None = None,
|
||||||
num_steps_wait: int = 10,
|
num_steps_wait: int = 10,
|
||||||
|
control_mode: str = "relative",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.task_id = task_id
|
self.task_id = task_id
|
||||||
@@ -148,7 +149,7 @@ class LiberoEnv(gym.Env):
|
|||||||
self._env = self._make_envs_task(task_suite, self.task_id)
|
self._env = self._make_envs_task(task_suite, self.task_id)
|
||||||
default_steps = 500
|
default_steps = 500
|
||||||
self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
|
self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
|
||||||
|
self.control_mode = control_mode
|
||||||
images = {}
|
images = {}
|
||||||
for cam in self.camera_name:
|
for cam in self.camera_name:
|
||||||
images[self.camera_name_mapping[cam]] = spaces.Box(
|
images[self.camera_name_mapping[cam]] = spaces.Box(
|
||||||
@@ -239,7 +240,7 @@ class LiberoEnv(gym.Env):
|
|||||||
image = raw_obs[camera_name]
|
image = raw_obs[camera_name]
|
||||||
images[self.camera_name_mapping[camera_name]] = image
|
images[self.camera_name_mapping[camera_name]] = image
|
||||||
|
|
||||||
eef_pos = raw_obs.get("robot0_eef_pos")
|
eef_pos = self._env.robots[0].controller.ee_pos #raw_obs.get("robot0_eef_pos")
|
||||||
eef_quat = raw_obs.get("robot0_eef_quat")
|
eef_quat = raw_obs.get("robot0_eef_quat")
|
||||||
|
|
||||||
# rotation matrix from controller
|
# rotation matrix from controller
|
||||||
@@ -296,6 +297,15 @@ class LiberoEnv(gym.Env):
|
|||||||
# Increasing this value can improve determinism and reproducibility across resets.
|
# Increasing this value can improve determinism and reproducibility across resets.
|
||||||
for _ in range(self.num_steps_wait):
|
for _ in range(self.num_steps_wait):
|
||||||
raw_obs, _, _, _ = self._env.step(get_libero_dummy_action())
|
raw_obs, _, _, _ = self._env.step(get_libero_dummy_action())
|
||||||
|
|
||||||
|
if self.control_mode == "absolute":
|
||||||
|
for robot in self._env.robots:
|
||||||
|
robot.controller.use_delta = False
|
||||||
|
elif self.control_mode == "relative":
|
||||||
|
for robot in self._env.robots:
|
||||||
|
robot.controller.use_delta = True
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid control mode: {self.control_mode}")
|
||||||
observation = self._format_raw_obs(raw_obs)
|
observation = self._format_raw_obs(raw_obs)
|
||||||
info = {"is_success": False}
|
info = {"is_success": False}
|
||||||
return observation, info
|
return observation, info
|
||||||
@@ -343,6 +353,7 @@ def _make_env_fns(
|
|||||||
camera_names: list[str],
|
camera_names: list[str],
|
||||||
init_states: bool,
|
init_states: bool,
|
||||||
gym_kwargs: Mapping[str, Any],
|
gym_kwargs: Mapping[str, Any],
|
||||||
|
control_mode: str,
|
||||||
) -> list[Callable[[], LiberoEnv]]:
|
) -> list[Callable[[], LiberoEnv]]:
|
||||||
"""Build n_envs factory callables for a single (suite, task_id)."""
|
"""Build n_envs factory callables for a single (suite, task_id)."""
|
||||||
|
|
||||||
@@ -355,6 +366,7 @@ def _make_env_fns(
|
|||||||
camera_name=camera_names,
|
camera_name=camera_names,
|
||||||
init_states=init_states,
|
init_states=init_states,
|
||||||
episode_index=episode_index,
|
episode_index=episode_index,
|
||||||
|
control_mode=control_mode,
|
||||||
**local_kwargs,
|
**local_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -374,6 +386,7 @@ def create_libero_envs(
|
|||||||
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
|
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
|
||||||
init_states: bool = True,
|
init_states: bool = True,
|
||||||
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
|
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
|
||||||
|
control_mode: str = "relative",
|
||||||
) -> dict[str, dict[int, Any]]:
|
) -> dict[str, dict[int, Any]]:
|
||||||
"""
|
"""
|
||||||
Create vectorized LIBERO environments with a consistent return shape.
|
Create vectorized LIBERO environments with a consistent return shape.
|
||||||
@@ -409,6 +422,7 @@ def create_libero_envs(
|
|||||||
suite = _get_suite(suite_name)
|
suite = _get_suite(suite_name)
|
||||||
total = len(suite.tasks)
|
total = len(suite.tasks)
|
||||||
selected = _select_task_ids(total, task_ids_filter)
|
selected = _select_task_ids(total, task_ids_filter)
|
||||||
|
selected = [0]
|
||||||
if not selected:
|
if not selected:
|
||||||
raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).")
|
raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).")
|
||||||
|
|
||||||
@@ -421,6 +435,7 @@ def create_libero_envs(
|
|||||||
camera_names=camera_names,
|
camera_names=camera_names,
|
||||||
init_states=init_states,
|
init_states=init_states,
|
||||||
gym_kwargs=gym_kwargs,
|
gym_kwargs=gym_kwargs,
|
||||||
|
control_mode=control_mode,
|
||||||
)
|
)
|
||||||
out[suite_name][tid] = env_cls(fns)
|
out[suite_name][tid] = env_cls(fns)
|
||||||
print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")
|
print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")
|
||||||
|
|||||||
@@ -1,6 +0,0 @@
|
|||||||
from lerobot.policies.xvla.processor_xvla import (
|
|
||||||
XVLAAddDomainIdProcessorStep,
|
|
||||||
XVLAImageScaleProcessorStep,
|
|
||||||
XVLARotation6DToAxisAngleProcessorStep,
|
|
||||||
make_xvla_pre_post_processors,
|
|
||||||
)
|
|
||||||
@@ -315,7 +315,6 @@ class XVLAPolicy(PreTrainedPolicy):
|
|||||||
return total_loss, log_dict
|
return total_loss, log_dict
|
||||||
|
|
||||||
def _get_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
def _get_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
print("get_action_chunk")
|
|
||||||
inputs = self._build_model_inputs(batch)
|
inputs = self._build_model_inputs(batch)
|
||||||
actions = self.model.generate_actions(**inputs, steps=self.config.num_denoising_steps)
|
actions = self.model.generate_actions(**inputs, steps=self.config.num_denoising_steps)
|
||||||
actions = self._trim_action_dim(actions)
|
actions = self._trim_action_dim(actions)
|
||||||
@@ -361,7 +360,7 @@ class XVLAPolicy(PreTrainedPolicy):
|
|||||||
"""
|
"""
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
|
|
||||||
# --- Step 1: Load config ---
|
# step 1: load config
|
||||||
if config is None:
|
if config is None:
|
||||||
config = PreTrainedConfig.from_pretrained(
|
config = PreTrainedConfig.from_pretrained(
|
||||||
pretrained_name_or_path=pretrained_name_or_path,
|
pretrained_name_or_path=pretrained_name_or_path,
|
||||||
@@ -377,7 +376,7 @@ class XVLAPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
model_id = str(pretrained_name_or_path)
|
model_id = str(pretrained_name_or_path)
|
||||||
instance = cls(config, **kwargs)
|
instance = cls(config, **kwargs)
|
||||||
# --- Step 2: Locate model.safetensors ---
|
# step 2: locate model.safetensors
|
||||||
if os.path.isdir(model_id):
|
if os.path.isdir(model_id):
|
||||||
print("Loading weights from local directory")
|
print("Loading weights from local directory")
|
||||||
model_file = os.path.join(model_id, "model.safetensors")
|
model_file = os.path.join(model_id, "model.safetensors")
|
||||||
@@ -401,13 +400,14 @@ class XVLAPolicy(PreTrainedPolicy):
|
|||||||
raise FileNotFoundError(f"model.safetensors not found on the Hub at {model_id}") from e
|
raise FileNotFoundError(f"model.safetensors not found on the Hub at {model_id}") from e
|
||||||
|
|
||||||
print(f"Loading checkpoint from {model_file}")
|
print(f"Loading checkpoint from {model_file}")
|
||||||
|
# step 3: load state dict
|
||||||
state_dict = safetensors.torch.load_file(model_file)
|
state_dict = safetensors.torch.load_file(model_file)
|
||||||
encoder_key = "model.vlm.language_model.model.encoder.embed_tokens.weight"
|
encoder_key = "model.vlm.language_model.model.encoder.embed_tokens.weight"
|
||||||
shared_key = "model.vlm.language_model.model.shared.weight"
|
shared_key = "model.vlm.language_model.model.shared.weight"
|
||||||
if encoder_key in state_dict:
|
if encoder_key in state_dict:
|
||||||
state_dict[shared_key] = state_dict[encoder_key]
|
state_dict[shared_key] = state_dict[encoder_key]
|
||||||
# or deepcopy
|
# or deepcopy
|
||||||
# step 5: load into instance
|
# step 4: load into instance
|
||||||
missing, unexpected = instance.load_state_dict(state_dict, strict=True)
|
missing, unexpected = instance.load_state_dict(state_dict, strict=True)
|
||||||
print("Loaded XVLA checkpoint")
|
print("Loaded XVLA checkpoint")
|
||||||
if missing:
|
if missing:
|
||||||
@@ -415,7 +415,7 @@ class XVLAPolicy(PreTrainedPolicy):
|
|||||||
if unexpected:
|
if unexpected:
|
||||||
print(f"Unexpected keys: {unexpected}")
|
print(f"Unexpected keys: {unexpected}")
|
||||||
|
|
||||||
# step 6: finalize
|
# step 5: finalize
|
||||||
instance.to(config.device)
|
instance.to(config.device)
|
||||||
instance.eval()
|
instance.eval()
|
||||||
return instance
|
return instance
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from lerobot.processor import (
|
|||||||
NormalizerProcessorStep,
|
NormalizerProcessorStep,
|
||||||
PolicyAction,
|
PolicyAction,
|
||||||
PolicyProcessorPipeline,
|
PolicyProcessorPipeline,
|
||||||
|
ObservationProcessorStep,
|
||||||
ProcessorStep,
|
ProcessorStep,
|
||||||
ProcessorStepRegistry,
|
ProcessorStepRegistry,
|
||||||
RenameObservationsProcessorStep,
|
RenameObservationsProcessorStep,
|
||||||
@@ -36,8 +37,8 @@ from lerobot.processor import (
|
|||||||
)
|
)
|
||||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME, OBS_STATE, OBS_IMAGES
|
||||||
|
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||||
|
|
||||||
def make_xvla_pre_post_processors(
|
def make_xvla_pre_post_processors(
|
||||||
config: XVLAConfig,
|
config: XVLAConfig,
|
||||||
@@ -89,6 +90,127 @@ def make_xvla_pre_post_processors(
|
|||||||
|
|
||||||
|
|
||||||
# Custom XVLA processor steps
|
# Custom XVLA processor steps
|
||||||
|
@dataclass
|
||||||
|
class LiberoProcessorStep(ObservationProcessorStep):
|
||||||
|
"""
|
||||||
|
Processes LIBERO observations into the LeRobot format.
|
||||||
|
|
||||||
|
This step handles the specific observation structure from LIBERO environments,
|
||||||
|
which includes nested robot_state dictionaries and image observations.
|
||||||
|
|
||||||
|
**State Processing:**
|
||||||
|
- Processes the `robot_state` dictionary which contains nested end-effector,
|
||||||
|
gripper, and joint information.
|
||||||
|
- Extracts and concatenates:
|
||||||
|
- End-effector position (3D)
|
||||||
|
- End-effector quaternion converted to axis-angle (3D)
|
||||||
|
- Gripper joint positions (2D)
|
||||||
|
- Maps the concatenated state to `"observation.state"`.
|
||||||
|
|
||||||
|
**Image Processing:**
|
||||||
|
- Rotates images by 180 degrees by flipping both height and width dimensions.
|
||||||
|
- This accounts for the HuggingFaceVLA/libero camera orientation convention.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _process_observation(self, observation):
|
||||||
|
"""
|
||||||
|
Processes both image and robot_state observations from LIBERO.
|
||||||
|
"""
|
||||||
|
processed_obs = observation.copy()
|
||||||
|
for key in list(processed_obs.keys()):
|
||||||
|
if key.startswith(f"{OBS_IMAGES}."):
|
||||||
|
img = processed_obs[key]
|
||||||
|
|
||||||
|
if key == f"{OBS_IMAGES}.image":
|
||||||
|
# Flip both H and W
|
||||||
|
img = torch.flip(img, dims=[2, 3])
|
||||||
|
|
||||||
|
processed_obs[key] = img
|
||||||
|
# Process robot_state into a flat state vector
|
||||||
|
if "observation.robot_state" in processed_obs:
|
||||||
|
robot_state = processed_obs.pop("observation.robot_state")
|
||||||
|
|
||||||
|
# Extract components
|
||||||
|
eef_pos = robot_state["eef"]["pos"] # (B, 3,)
|
||||||
|
eef_mat = robot_state["eef"]["mat"] # (B, 3, 3)
|
||||||
|
eef_rot6d = self._mat_to_rotate6d(eef_mat) # (B, 6)
|
||||||
|
|
||||||
|
extra = torch.zeros((eef_pos.shape[0], 1), dtype=torch.float32, device=eef_pos.device)
|
||||||
|
|
||||||
|
proprio_state = torch.cat((eef_pos, eef_rot6d, extra), dim=-1) # (B, 10)
|
||||||
|
state = torch.cat((proprio_state, torch.zeros_like(proprio_state)), dim=-1) # (B, 20)
|
||||||
|
# ensure float32
|
||||||
|
state = state.float()
|
||||||
|
if state.dim() == 1:
|
||||||
|
state = state.unsqueeze(0)
|
||||||
|
|
||||||
|
processed_obs[OBS_STATE] = state
|
||||||
|
return processed_obs
|
||||||
|
|
||||||
|
def transform_features(
|
||||||
|
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||||
|
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||||
|
"""
|
||||||
|
Transforms feature keys from the LIBERO format to the LeRobot standard.
|
||||||
|
"""
|
||||||
|
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {}
|
||||||
|
|
||||||
|
# copy over non-STATE features
|
||||||
|
for ft, feats in features.items():
|
||||||
|
if ft != PipelineFeatureType.STATE:
|
||||||
|
new_features[ft] = feats.copy()
|
||||||
|
|
||||||
|
# rebuild STATE features
|
||||||
|
state_feats = {}
|
||||||
|
|
||||||
|
# add our new flattened state
|
||||||
|
state_feats["observation.state"] = PolicyFeature(
|
||||||
|
key="observation.state",
|
||||||
|
shape=(20,),
|
||||||
|
dtype="float32",
|
||||||
|
)
|
||||||
|
|
||||||
|
new_features[PipelineFeatureType.STATE] = state_feats
|
||||||
|
|
||||||
|
return new_features
|
||||||
|
|
||||||
|
def _mat_to_rotate6d(self, rot_mats: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Convert batched rotation matrices (B, 3, 3) into 6D rotation representation (B, 6).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rot_mats (Tensor): Rotation matrices of shape (B, 3, 3)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: 6D rotation representation, shape (B, 6)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if input is not a torch tensor
|
||||||
|
ValueError: if shape is not (B, 3, 3)
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not isinstance(rot_mats, torch.Tensor):
|
||||||
|
raise TypeError(
|
||||||
|
f"mat_to_rot6d expects a torch.Tensor, got {type(rot_mats)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if rot_mats.ndim != 3 or rot_mats.shape[1:] != (3, 3):
|
||||||
|
raise ValueError(
|
||||||
|
f"mat_to_rot6d expects shape (B, 3, 3), got {tuple(rot_mats.shape)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
rot_mats = rot_mats.to(torch.float32)
|
||||||
|
|
||||||
|
col1 = rot_mats[:, :3, 0] # (B, 3)
|
||||||
|
col2 = rot_mats[:, :3, 1] # (B, 3)
|
||||||
|
|
||||||
|
rot6d = torch.cat([col1, col2], dim=-1) # (B, 6)
|
||||||
|
|
||||||
|
return rot6d
|
||||||
|
|
||||||
|
def observation(self, observation):
|
||||||
|
return self._process_observation(observation)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -254,3 +376,24 @@ class XVLARotation6DToAxisAngleProcessorStep(ProcessorStep):
|
|||||||
return {
|
return {
|
||||||
"expected_action_dim": self.expected_action_dim,
|
"expected_action_dim": self.expected_action_dim,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def make_xvla_libero_pre_post_processors(
|
||||||
|
) -> tuple[
|
||||||
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||||
|
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||||
|
]:
|
||||||
|
"""
|
||||||
|
Build the LeRobot processor pipelines for XVLA with LIBERO environment.
|
||||||
|
"""
|
||||||
|
pre_processor_steps: list[ProcessorStep] = []
|
||||||
|
post_processor_steps: list[ProcessorStep] = []
|
||||||
|
pre_processor_steps.extend([LiberoProcessorStep(), XVLAImageScaleProcessorStep(), XVLAAddDomainIdProcessorStep()])
|
||||||
|
post_processor_steps.extend([XVLARotation6DToAxisAngleProcessorStep()])
|
||||||
|
return (
|
||||||
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||||
|
steps=pre_processor_steps,
|
||||||
|
),
|
||||||
|
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||||
|
steps=post_processor_steps,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|||||||
@@ -328,7 +328,6 @@ class _NormalizationMixin:
|
|||||||
if norm_mode == NormalizationMode.IMAGENET:
|
if norm_mode == NormalizationMode.IMAGENET:
|
||||||
mean = torch.tensor(IMAGENET_STATS["mean"], device=tensor.device, dtype=tensor.dtype)
|
mean = torch.tensor(IMAGENET_STATS["mean"], device=tensor.device, dtype=tensor.dtype)
|
||||||
std = torch.tensor(IMAGENET_STATS["std"], device=tensor.device, dtype=tensor.dtype)
|
std = torch.tensor(IMAGENET_STATS["std"], device=tensor.device, dtype=tensor.dtype)
|
||||||
|
|
||||||
# Expand mean/std to match tensor dims (e.g., BCHW or BNCHW)
|
# Expand mean/std to match tensor dims (e.g., BCHW or BNCHW)
|
||||||
while mean.dim() < tensor.dim():
|
while mean.dim() < tensor.dim():
|
||||||
mean = mean.unsqueeze(0)
|
mean = mean.unsqueeze(0)
|
||||||
|
|||||||
@@ -175,11 +175,9 @@ def rollout(
|
|||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
action = policy.select_action(observation)
|
action = policy.select_action(observation)
|
||||||
action = postprocessor(action)
|
action = postprocessor(action)
|
||||||
|
|
||||||
action_transition = {"action": action}
|
action_transition = {"action": action}
|
||||||
action_transition = env_postprocessor(action_transition)
|
action_transition = env_postprocessor(action_transition)
|
||||||
action = action_transition["action"]
|
action = action_transition["action"]
|
||||||
|
|
||||||
# Convert to CPU / numpy.
|
# Convert to CPU / numpy.
|
||||||
action_numpy: np.ndarray = action.to("cpu").numpy()
|
action_numpy: np.ndarray = action.to("cpu").numpy()
|
||||||
assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)"
|
assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)"
|
||||||
@@ -533,7 +531,7 @@ def eval_main(cfg: EvalPipelineConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
|
# Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
|
||||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
|
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env, policy_cfg=cfg.policy)
|
||||||
|
|
||||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||||
info = eval_policy_all(
|
info = eval_policy_all(
|
||||||
|
|||||||
Reference in New Issue
Block a user