more changes

This commit is contained in:
Jade Choghari
2025-11-20 14:45:27 +01:00
parent 99b0722425
commit 70582ed226
9 changed files with 181 additions and 24 deletions
+3 -2
View File
@@ -1,8 +1,9 @@
lerobot-eval \ lerobot-eval \
--policy.path="/raid/jade/models/xvla-libero-new_migrated" \ --policy.path="/raid/jade/models/xvla-lib" \
--env.type=libero \ --env.type=libero \
--env.task=libero_spatial \ --env.task=libero_spatial \
--env.action_type=abs \ --env.control_mode=absolute \
--eval.batch_size=1 \ --eval.batch_size=1 \
--eval.n_episodes=1 \ --eval.n_episodes=1 \
--seed=142 --seed=142
+1 -1
View File
@@ -272,7 +272,7 @@ class LiberoEnv(EnvConfig):
LIBERO_KEY_PIXELS_EYE_IN_HAND: f"{OBS_IMAGES}.image2", LIBERO_KEY_PIXELS_EYE_IN_HAND: f"{OBS_IMAGES}.image2",
} }
) )
action_type: str = "rel" control_mode: str = "relative" # or "absolute"
def __post_init__(self): def __post_init__(self):
if self.obs_type == "pixels": if self.obs_type == "pixels":
+8 -1
View File
@@ -24,6 +24,8 @@ from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_m
from lerobot.processor import ProcessorStep from lerobot.processor import ProcessorStep
from lerobot.processor.env_processor import LiberoProcessorStep from lerobot.processor.env_processor import LiberoProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline from lerobot.processor.pipeline import PolicyProcessorPipeline
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.configs.policies import PreTrainedConfig
def make_env_config(env_type: str, **kwargs) -> EnvConfig: def make_env_config(env_type: str, **kwargs) -> EnvConfig:
@@ -39,6 +41,7 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
def make_env_pre_post_processors( def make_env_pre_post_processors(
env_cfg: EnvConfig, env_cfg: EnvConfig,
policy_cfg: PreTrainedConfig,
) -> tuple[ ) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
@@ -61,10 +64,14 @@ def make_env_pre_post_processors(
# Preprocessor and Postprocessor steps are Identity for most environments # Preprocessor and Postprocessor steps are Identity for most environments
preprocessor_steps: list[ProcessorStep] = [] preprocessor_steps: list[ProcessorStep] = []
postprocessor_steps: list[ProcessorStep] = [] postprocessor_steps: list[ProcessorStep] = []
if isinstance(policy_cfg, XVLAConfig):
from lerobot.policies.xvla.processor_xvla import make_xvla_libero_pre_post_processors
return make_xvla_libero_pre_post_processors()
# For LIBERO environments, add the LiberoProcessorStep to preprocessor # For LIBERO environments, add the LiberoProcessorStep to preprocessor
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type: if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
preprocessor_steps.append(LiberoProcessorStep()) preprocessor_steps.append(LiberoProcessorStep())
preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps) preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps)
postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps) postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps)
@@ -136,7 +143,7 @@ def make_env(
init_states=cfg.init_states, init_states=cfg.init_states,
gym_kwargs=cfg.gym_kwargs, gym_kwargs=cfg.gym_kwargs,
env_cls=env_cls, env_cls=env_cls,
action_type=cfg.action_type, control_mode=cfg.control_mode,
) )
elif "metaworld" in cfg.type: elif "metaworld" in cfg.type:
from lerobot.envs.metaworld import create_metaworld_envs from lerobot.envs.metaworld import create_metaworld_envs
+18 -3
View File
@@ -87,7 +87,7 @@ AGENT_POS_HIGH = 1000.0
ACTION_LOW = -1.0 ACTION_LOW = -1.0
ACTION_HIGH = 1.0 ACTION_HIGH = 1.0
TASK_SUITE_MAX_STEPS: dict[str, int] = { TASK_SUITE_MAX_STEPS: dict[str, int] = {
"libero_spatial": 280, # longest training demo has 193 steps "libero_spatial": 800, # longest training demo has 193 steps
"libero_object": 280, # longest training demo has 254 steps "libero_object": 280, # longest training demo has 254 steps
"libero_goal": 300, # longest training demo has 270 steps "libero_goal": 300, # longest training demo has 270 steps
"libero_10": 520, # longest training demo has 505 steps "libero_10": 520, # longest training demo has 505 steps
@@ -114,6 +114,7 @@ class LiberoEnv(gym.Env):
episode_index: int = 0, episode_index: int = 0,
camera_name_mapping: dict[str, str] | None = None, camera_name_mapping: dict[str, str] | None = None,
num_steps_wait: int = 10, num_steps_wait: int = 10,
control_mode: str = "relative",
): ):
super().__init__() super().__init__()
self.task_id = task_id self.task_id = task_id
@@ -148,7 +149,7 @@ class LiberoEnv(gym.Env):
self._env = self._make_envs_task(task_suite, self.task_id) self._env = self._make_envs_task(task_suite, self.task_id)
default_steps = 500 default_steps = 500
self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps) self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
self.control_mode = control_mode
images = {} images = {}
for cam in self.camera_name: for cam in self.camera_name:
images[self.camera_name_mapping[cam]] = spaces.Box( images[self.camera_name_mapping[cam]] = spaces.Box(
@@ -239,7 +240,7 @@ class LiberoEnv(gym.Env):
image = raw_obs[camera_name] image = raw_obs[camera_name]
images[self.camera_name_mapping[camera_name]] = image images[self.camera_name_mapping[camera_name]] = image
eef_pos = raw_obs.get("robot0_eef_pos") eef_pos = self._env.robots[0].controller.ee_pos #raw_obs.get("robot0_eef_pos")
eef_quat = raw_obs.get("robot0_eef_quat") eef_quat = raw_obs.get("robot0_eef_quat")
# rotation matrix from controller # rotation matrix from controller
@@ -296,6 +297,15 @@ class LiberoEnv(gym.Env):
# Increasing this value can improve determinism and reproducibility across resets. # Increasing this value can improve determinism and reproducibility across resets.
for _ in range(self.num_steps_wait): for _ in range(self.num_steps_wait):
raw_obs, _, _, _ = self._env.step(get_libero_dummy_action()) raw_obs, _, _, _ = self._env.step(get_libero_dummy_action())
if self.control_mode == "absolute":
for robot in self._env.robots:
robot.controller.use_delta = False
elif self.control_mode == "relative":
for robot in self._env.robots:
robot.controller.use_delta = True
else:
raise ValueError(f"Invalid control mode: {self.control_mode}")
observation = self._format_raw_obs(raw_obs) observation = self._format_raw_obs(raw_obs)
info = {"is_success": False} info = {"is_success": False}
return observation, info return observation, info
@@ -343,6 +353,7 @@ def _make_env_fns(
camera_names: list[str], camera_names: list[str],
init_states: bool, init_states: bool,
gym_kwargs: Mapping[str, Any], gym_kwargs: Mapping[str, Any],
control_mode: str,
) -> list[Callable[[], LiberoEnv]]: ) -> list[Callable[[], LiberoEnv]]:
"""Build n_envs factory callables for a single (suite, task_id).""" """Build n_envs factory callables for a single (suite, task_id)."""
@@ -355,6 +366,7 @@ def _make_env_fns(
camera_name=camera_names, camera_name=camera_names,
init_states=init_states, init_states=init_states,
episode_index=episode_index, episode_index=episode_index,
control_mode=control_mode,
**local_kwargs, **local_kwargs,
) )
@@ -374,6 +386,7 @@ def create_libero_envs(
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image", camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
init_states: bool = True, init_states: bool = True,
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None, env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
control_mode: str = "relative",
) -> dict[str, dict[int, Any]]: ) -> dict[str, dict[int, Any]]:
""" """
Create vectorized LIBERO environments with a consistent return shape. Create vectorized LIBERO environments with a consistent return shape.
@@ -409,6 +422,7 @@ def create_libero_envs(
suite = _get_suite(suite_name) suite = _get_suite(suite_name)
total = len(suite.tasks) total = len(suite.tasks)
selected = _select_task_ids(total, task_ids_filter) selected = _select_task_ids(total, task_ids_filter)
selected = [0]
if not selected: if not selected:
raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).") raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).")
@@ -421,6 +435,7 @@ def create_libero_envs(
camera_names=camera_names, camera_names=camera_names,
init_states=init_states, init_states=init_states,
gym_kwargs=gym_kwargs, gym_kwargs=gym_kwargs,
control_mode=control_mode,
) )
out[suite_name][tid] = env_cls(fns) out[suite_name][tid] = env_cls(fns)
print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}") print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")
-6
View File
@@ -1,6 +0,0 @@
from lerobot.policies.xvla.processor_xvla import (
XVLAAddDomainIdProcessorStep,
XVLAImageScaleProcessorStep,
XVLARotation6DToAxisAngleProcessorStep,
make_xvla_pre_post_processors,
)
+5 -5
View File
@@ -315,7 +315,6 @@ class XVLAPolicy(PreTrainedPolicy):
return total_loss, log_dict return total_loss, log_dict
def _get_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: def _get_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
print("get_action_chunk")
inputs = self._build_model_inputs(batch) inputs = self._build_model_inputs(batch)
actions = self.model.generate_actions(**inputs, steps=self.config.num_denoising_steps) actions = self.model.generate_actions(**inputs, steps=self.config.num_denoising_steps)
actions = self._trim_action_dim(actions) actions = self._trim_action_dim(actions)
@@ -361,7 +360,7 @@ class XVLAPolicy(PreTrainedPolicy):
""" """
import safetensors.torch import safetensors.torch
# --- Step 1: Load config --- # step 1: load config
if config is None: if config is None:
config = PreTrainedConfig.from_pretrained( config = PreTrainedConfig.from_pretrained(
pretrained_name_or_path=pretrained_name_or_path, pretrained_name_or_path=pretrained_name_or_path,
@@ -377,7 +376,7 @@ class XVLAPolicy(PreTrainedPolicy):
model_id = str(pretrained_name_or_path) model_id = str(pretrained_name_or_path)
instance = cls(config, **kwargs) instance = cls(config, **kwargs)
# --- Step 2: Locate model.safetensors --- # step 2: locate model.safetensors
if os.path.isdir(model_id): if os.path.isdir(model_id):
print("Loading weights from local directory") print("Loading weights from local directory")
model_file = os.path.join(model_id, "model.safetensors") model_file = os.path.join(model_id, "model.safetensors")
@@ -401,13 +400,14 @@ class XVLAPolicy(PreTrainedPolicy):
raise FileNotFoundError(f"model.safetensors not found on the Hub at {model_id}") from e raise FileNotFoundError(f"model.safetensors not found on the Hub at {model_id}") from e
print(f"Loading checkpoint from {model_file}") print(f"Loading checkpoint from {model_file}")
# step 3: load state dict
state_dict = safetensors.torch.load_file(model_file) state_dict = safetensors.torch.load_file(model_file)
encoder_key = "model.vlm.language_model.model.encoder.embed_tokens.weight" encoder_key = "model.vlm.language_model.model.encoder.embed_tokens.weight"
shared_key = "model.vlm.language_model.model.shared.weight" shared_key = "model.vlm.language_model.model.shared.weight"
if encoder_key in state_dict: if encoder_key in state_dict:
state_dict[shared_key] = state_dict[encoder_key] state_dict[shared_key] = state_dict[encoder_key]
# or deepcopy # or deepcopy
# step 5: load into instance # step 4: load into instance
missing, unexpected = instance.load_state_dict(state_dict, strict=True) missing, unexpected = instance.load_state_dict(state_dict, strict=True)
print("Loaded XVLA checkpoint") print("Loaded XVLA checkpoint")
if missing: if missing:
@@ -415,7 +415,7 @@ class XVLAPolicy(PreTrainedPolicy):
if unexpected: if unexpected:
print(f"Unexpected keys: {unexpected}") print(f"Unexpected keys: {unexpected}")
# step 6: finalize # step 5: finalize
instance.to(config.device) instance.to(config.device)
instance.eval() instance.eval()
return instance return instance
+145 -2
View File
@@ -28,6 +28,7 @@ from lerobot.processor import (
NormalizerProcessorStep, NormalizerProcessorStep,
PolicyAction, PolicyAction,
PolicyProcessorPipeline, PolicyProcessorPipeline,
ObservationProcessorStep,
ProcessorStep, ProcessorStep,
ProcessorStepRegistry, ProcessorStepRegistry,
RenameObservationsProcessorStep, RenameObservationsProcessorStep,
@@ -36,8 +37,8 @@ from lerobot.processor import (
) )
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.processor.core import EnvTransition, TransitionKey from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME, OBS_STATE, OBS_IMAGES
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
def make_xvla_pre_post_processors( def make_xvla_pre_post_processors(
config: XVLAConfig, config: XVLAConfig,
@@ -89,6 +90,127 @@ def make_xvla_pre_post_processors(
# Custom XVLA processor steps # Custom XVLA processor steps
@dataclass
class LiberoProcessorStep(ObservationProcessorStep):
"""
Processes LIBERO observations into the LeRobot format.
This step handles the specific observation structure from LIBERO environments,
which includes nested robot_state dictionaries and image observations.
**State Processing:**
- Processes the `robot_state` dictionary which contains nested end-effector,
gripper, and joint information.
- Extracts and concatenates:
- End-effector position (3D)
- End-effector quaternion converted to axis-angle (3D)
- Gripper joint positions (2D)
- Maps the concatenated state to `"observation.state"`.
**Image Processing:**
- Rotates images by 180 degrees by flipping both height and width dimensions.
- This accounts for the HuggingFaceVLA/libero camera orientation convention.
"""
def _process_observation(self, observation):
"""
Processes both image and robot_state observations from LIBERO.
"""
processed_obs = observation.copy()
for key in list(processed_obs.keys()):
if key.startswith(f"{OBS_IMAGES}."):
img = processed_obs[key]
if key == f"{OBS_IMAGES}.image":
# Flip both H and W
img = torch.flip(img, dims=[2, 3])
processed_obs[key] = img
# Process robot_state into a flat state vector
if "observation.robot_state" in processed_obs:
robot_state = processed_obs.pop("observation.robot_state")
# Extract components
eef_pos = robot_state["eef"]["pos"] # (B, 3,)
eef_mat = robot_state["eef"]["mat"] # (B, 3, 3)
eef_rot6d = self._mat_to_rotate6d(eef_mat) # (B, 6)
extra = torch.zeros((eef_pos.shape[0], 1), dtype=torch.float32, device=eef_pos.device)
proprio_state = torch.cat((eef_pos, eef_rot6d, extra), dim=-1) # (B, 10)
state = torch.cat((proprio_state, torch.zeros_like(proprio_state)), dim=-1) # (B, 20)
# ensure float32
state = state.float()
if state.dim() == 1:
state = state.unsqueeze(0)
processed_obs[OBS_STATE] = state
return processed_obs
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
Transforms feature keys from the LIBERO format to the LeRobot standard.
"""
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {}
# copy over non-STATE features
for ft, feats in features.items():
if ft != PipelineFeatureType.STATE:
new_features[ft] = feats.copy()
# rebuild STATE features
state_feats = {}
# add our new flattened state
state_feats["observation.state"] = PolicyFeature(
key="observation.state",
shape=(20,),
dtype="float32",
)
new_features[PipelineFeatureType.STATE] = state_feats
return new_features
def _mat_to_rotate6d(self, rot_mats: torch.Tensor) -> torch.Tensor:
"""
Convert batched rotation matrices (B, 3, 3) into 6D rotation representation (B, 6).
Args:
rot_mats (Tensor): Rotation matrices of shape (B, 3, 3)
Returns:
Tensor: 6D rotation representation, shape (B, 6)
Raises:
TypeError: if input is not a torch tensor
ValueError: if shape is not (B, 3, 3)
"""
if not isinstance(rot_mats, torch.Tensor):
raise TypeError(
f"mat_to_rot6d expects a torch.Tensor, got {type(rot_mats)}"
)
if rot_mats.ndim != 3 or rot_mats.shape[1:] != (3, 3):
raise ValueError(
f"mat_to_rot6d expects shape (B, 3, 3), got {tuple(rot_mats.shape)}"
)
rot_mats = rot_mats.to(torch.float32)
col1 = rot_mats[:, :3, 0] # (B, 3)
col2 = rot_mats[:, :3, 1] # (B, 3)
rot6d = torch.cat([col1, col2], dim=-1) # (B, 6)
return rot6d
def observation(self, observation):
return self._process_observation(observation)
@dataclass @dataclass
@@ -254,3 +376,24 @@ class XVLARotation6DToAxisAngleProcessorStep(ProcessorStep):
return { return {
"expected_action_dim": self.expected_action_dim, "expected_action_dim": self.expected_action_dim,
} }
def make_xvla_libero_pre_post_processors(
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Build the LeRobot processor pipelines for XVLA with LIBERO environment.
"""
pre_processor_steps: list[ProcessorStep] = []
post_processor_steps: list[ProcessorStep] = []
pre_processor_steps.extend([LiberoProcessorStep(), XVLAImageScaleProcessorStep(), XVLAAddDomainIdProcessorStep()])
post_processor_steps.extend([XVLARotation6DToAxisAngleProcessorStep()])
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=pre_processor_steps,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=post_processor_steps,
),
)
@@ -328,7 +328,6 @@ class _NormalizationMixin:
if norm_mode == NormalizationMode.IMAGENET: if norm_mode == NormalizationMode.IMAGENET:
mean = torch.tensor(IMAGENET_STATS["mean"], device=tensor.device, dtype=tensor.dtype) mean = torch.tensor(IMAGENET_STATS["mean"], device=tensor.device, dtype=tensor.dtype)
std = torch.tensor(IMAGENET_STATS["std"], device=tensor.device, dtype=tensor.dtype) std = torch.tensor(IMAGENET_STATS["std"], device=tensor.device, dtype=tensor.dtype)
# Expand mean/std to match tensor dims (e.g., BCHW or BNCHW) # Expand mean/std to match tensor dims (e.g., BCHW or BNCHW)
while mean.dim() < tensor.dim(): while mean.dim() < tensor.dim():
mean = mean.unsqueeze(0) mean = mean.unsqueeze(0)
+1 -3
View File
@@ -175,11 +175,9 @@ def rollout(
with torch.inference_mode(): with torch.inference_mode():
action = policy.select_action(observation) action = policy.select_action(observation)
action = postprocessor(action) action = postprocessor(action)
action_transition = {"action": action} action_transition = {"action": action}
action_transition = env_postprocessor(action_transition) action_transition = env_postprocessor(action_transition)
action = action_transition["action"] action = action_transition["action"]
# Convert to CPU / numpy. # Convert to CPU / numpy.
action_numpy: np.ndarray = action.to("cpu").numpy() action_numpy: np.ndarray = action.to("cpu").numpy()
assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)" assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)"
@@ -533,7 +531,7 @@ def eval_main(cfg: EvalPipelineConfig):
) )
# Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments) # Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env) env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env, policy_cfg=cfg.policy)
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(): with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy_all( info = eval_policy_all(