From 70582ed226001109adb4b35e24ca26bc4aa67d24 Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Thu, 20 Nov 2025 14:45:27 +0100 Subject: [PATCH] more changes --- eval.sh | 5 +- src/lerobot/envs/configs.py | 2 +- src/lerobot/envs/factory.py | 9 +- src/lerobot/envs/libero.py | 21 ++- src/lerobot/policies/xvla/__init__.py | 6 - src/lerobot/policies/xvla/modeling_xvla.py | 10 +- src/lerobot/policies/xvla/processor_xvla.py | 147 ++++++++++++++++++- src/lerobot/processor/normalize_processor.py | 1 - src/lerobot/scripts/lerobot_eval.py | 4 +- 9 files changed, 181 insertions(+), 24 deletions(-) delete mode 100644 src/lerobot/policies/xvla/__init__.py diff --git a/eval.sh b/eval.sh index c4ce178e4..318ed3fbc 100644 --- a/eval.sh +++ b/eval.sh @@ -1,8 +1,9 @@ lerobot-eval \ - --policy.path="/raid/jade/models/xvla-libero-new_migrated" \ + --policy.path="/raid/jade/models/xvla-lib" \ --env.type=libero \ --env.task=libero_spatial \ - --env.action_type=abs \ + --env.control_mode=absolute \ --eval.batch_size=1 \ --eval.n_episodes=1 \ --seed=142 + diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index a7fc800aa..14ca2e6a7 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -272,7 +272,7 @@ class LiberoEnv(EnvConfig): LIBERO_KEY_PIXELS_EYE_IN_HAND: f"{OBS_IMAGES}.image2", } ) - action_type: str = "rel" + control_mode: str = "relative" # or "absolute" def __post_init__(self): if self.obs_type == "pixels": diff --git a/src/lerobot/envs/factory.py b/src/lerobot/envs/factory.py index 397d513b0..668ba6406 100644 --- a/src/lerobot/envs/factory.py +++ b/src/lerobot/envs/factory.py @@ -24,6 +24,8 @@ from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_m from lerobot.processor import ProcessorStep from lerobot.processor.env_processor import LiberoProcessorStep from lerobot.processor.pipeline import PolicyProcessorPipeline +from lerobot.policies.xvla.configuration_xvla import XVLAConfig +from lerobot.configs.policies import PreTrainedConfig def make_env_config(env_type: str, **kwargs) -> EnvConfig: @@ -39,6 +41,7 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig: def make_env_pre_post_processors( env_cfg: EnvConfig, + policy_cfg: PreTrainedConfig, ) -> tuple[ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], @@ -61,10 +64,14 @@ def make_env_pre_post_processors( # Preprocessor and Postprocessor steps are Identity for most environments preprocessor_steps: list[ProcessorStep] = [] postprocessor_steps: list[ProcessorStep] = [] + if isinstance(policy_cfg, XVLAConfig): + from lerobot.policies.xvla.processor_xvla import make_xvla_libero_pre_post_processors + return make_xvla_libero_pre_post_processors() # For LIBERO environments, add the LiberoProcessorStep to preprocessor if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type: preprocessor_steps.append(LiberoProcessorStep()) + preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps) postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps) @@ -136,7 +143,7 @@ def make_env( init_states=cfg.init_states, gym_kwargs=cfg.gym_kwargs, env_cls=env_cls, - action_type=cfg.action_type, + control_mode=cfg.control_mode, ) elif "metaworld" in cfg.type: from lerobot.envs.metaworld import create_metaworld_envs diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py index 35bc58e07..0e182524c 100644 --- a/src/lerobot/envs/libero.py +++ b/src/lerobot/envs/libero.py @@ -87,7 +87,7 @@ AGENT_POS_HIGH = 1000.0 ACTION_LOW = -1.0 ACTION_HIGH = 1.0 TASK_SUITE_MAX_STEPS: dict[str, int] = { - "libero_spatial": 280, # longest training demo has 193 steps + "libero_spatial": 800, # longest training demo has 193 steps "libero_object": 280, # longest training demo has 254 steps "libero_goal": 300, # longest training demo has 270 steps "libero_10": 520, # longest training demo has 505 steps @@ -114,6 +114,7 @@ class LiberoEnv(gym.Env): episode_index: int = 0, camera_name_mapping: dict[str, str] | None = None, num_steps_wait: int = 10, + control_mode: str = "relative", ): super().__init__() self.task_id = task_id @@ -148,7 +149,7 @@ class LiberoEnv(gym.Env): self._env = self._make_envs_task(task_suite, self.task_id) default_steps = 500 self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps) - + self.control_mode = control_mode images = {} for cam in self.camera_name: images[self.camera_name_mapping[cam]] = spaces.Box( @@ -239,7 +240,7 @@ class LiberoEnv(gym.Env): image = raw_obs[camera_name] images[self.camera_name_mapping[camera_name]] = image - eef_pos = raw_obs.get("robot0_eef_pos") + eef_pos = self._env.robots[0].controller.ee_pos #raw_obs.get("robot0_eef_pos") eef_quat = raw_obs.get("robot0_eef_quat") # rotation matrix from controller @@ -296,6 +297,15 @@ class LiberoEnv(gym.Env): # Increasing this value can improve determinism and reproducibility across resets. for _ in range(self.num_steps_wait): raw_obs, _, _, _ = self._env.step(get_libero_dummy_action()) + + if self.control_mode == "absolute": + for robot in self._env.robots: + robot.controller.use_delta = False + elif self.control_mode == "relative": + for robot in self._env.robots: + robot.controller.use_delta = True + else: + raise ValueError(f"Invalid control mode: {self.control_mode}") observation = self._format_raw_obs(raw_obs) info = {"is_success": False} return observation, info @@ -343,6 +353,7 @@ def _make_env_fns( camera_names: list[str], init_states: bool, gym_kwargs: Mapping[str, Any], + control_mode: str, ) -> list[Callable[[], LiberoEnv]]: """Build n_envs factory callables for a single (suite, task_id).""" @@ -355,6 +366,7 @@ def _make_env_fns( camera_name=camera_names, init_states=init_states, episode_index=episode_index, + control_mode=control_mode, **local_kwargs, ) @@ -374,6 +386,7 @@ def create_libero_envs( camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image", init_states: bool = True, env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None, + control_mode: str = "relative", ) -> dict[str, dict[int, Any]]: """ Create vectorized LIBERO environments with a consistent return shape. @@ -409,6 +422,7 @@ def create_libero_envs( suite = _get_suite(suite_name) total = len(suite.tasks) selected = _select_task_ids(total, task_ids_filter) + selected = [0] if not selected: raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).") @@ -421,6 +435,7 @@ def create_libero_envs( camera_names=camera_names, init_states=init_states, gym_kwargs=gym_kwargs, + control_mode=control_mode, ) out[suite_name][tid] = env_cls(fns) print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}") diff --git a/src/lerobot/policies/xvla/__init__.py b/src/lerobot/policies/xvla/__init__.py deleted file mode 100644 index 822a4797e..000000000 --- a/src/lerobot/policies/xvla/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from lerobot.policies.xvla.processor_xvla import ( - XVLAAddDomainIdProcessorStep, - XVLAImageScaleProcessorStep, - XVLARotation6DToAxisAngleProcessorStep, - make_xvla_pre_post_processors, -) diff --git a/src/lerobot/policies/xvla/modeling_xvla.py b/src/lerobot/policies/xvla/modeling_xvla.py index 015d173fb..26454ab24 100644 --- a/src/lerobot/policies/xvla/modeling_xvla.py +++ b/src/lerobot/policies/xvla/modeling_xvla.py @@ -315,7 +315,6 @@ class XVLAPolicy(PreTrainedPolicy): return total_loss, log_dict def _get_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: - print("get_action_chunk") inputs = self._build_model_inputs(batch) actions = self.model.generate_actions(**inputs, steps=self.config.num_denoising_steps) actions = self._trim_action_dim(actions) @@ -361,7 +360,7 @@ class XVLAPolicy(PreTrainedPolicy): """ import safetensors.torch - # --- Step 1: Load config --- + # step 1: load config if config is None: config = PreTrainedConfig.from_pretrained( pretrained_name_or_path=pretrained_name_or_path, @@ -377,7 +376,7 @@ class XVLAPolicy(PreTrainedPolicy): model_id = str(pretrained_name_or_path) instance = cls(config, **kwargs) - # --- Step 2: Locate model.safetensors --- + # step 2: locate model.safetensors if os.path.isdir(model_id): print("Loading weights from local directory") model_file = os.path.join(model_id, "model.safetensors") @@ -401,13 +400,14 @@ class XVLAPolicy(PreTrainedPolicy): raise FileNotFoundError(f"model.safetensors not found on the Hub at {model_id}") from e print(f"Loading checkpoint from {model_file}") + # step 3: load state dict state_dict = safetensors.torch.load_file(model_file) encoder_key = "model.vlm.language_model.model.encoder.embed_tokens.weight" shared_key = "model.vlm.language_model.model.shared.weight" if encoder_key in state_dict: state_dict[shared_key] = state_dict[encoder_key] # or deepcopy - # step 5: load into instance + # step 4: load into instance missing, unexpected = instance.load_state_dict(state_dict, strict=True) print("Loaded XVLA checkpoint") if missing: @@ -415,7 +415,7 @@ class XVLAPolicy(PreTrainedPolicy): if unexpected: print(f"Unexpected keys: {unexpected}") - # step 6: finalize + # step 5: finalize instance.to(config.device) instance.eval() return instance diff --git a/src/lerobot/policies/xvla/processor_xvla.py b/src/lerobot/policies/xvla/processor_xvla.py index d8a6e7092..77dbc7cfa 100644 --- a/src/lerobot/policies/xvla/processor_xvla.py +++ b/src/lerobot/policies/xvla/processor_xvla.py @@ -28,6 +28,7 @@ from lerobot.processor import ( NormalizerProcessorStep, PolicyAction, PolicyProcessorPipeline, + ObservationProcessorStep, ProcessorStep, ProcessorStepRegistry, RenameObservationsProcessorStep, @@ -36,8 +37,8 @@ from lerobot.processor import ( ) from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action from lerobot.processor.core import EnvTransition, TransitionKey -from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME - +from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME, OBS_STATE, OBS_IMAGES +from lerobot.configs.types import PipelineFeatureType, PolicyFeature def make_xvla_pre_post_processors( config: XVLAConfig, @@ -89,6 +90,127 @@ def make_xvla_pre_post_processors( # Custom XVLA processor steps +@dataclass +class LiberoProcessorStep(ObservationProcessorStep): + """ + Processes LIBERO observations into the LeRobot format. + + This step handles the specific observation structure from LIBERO environments, + which includes nested robot_state dictionaries and image observations. + + **State Processing:** + - Processes the `robot_state` dictionary which contains nested end-effector, + gripper, and joint information. + - Extracts and concatenates: + - End-effector position (3D) + - End-effector quaternion converted to axis-angle (3D) + - Gripper joint positions (2D) + - Maps the concatenated state to `"observation.state"`. + + **Image Processing:** + - Rotates images by 180 degrees by flipping both height and width dimensions. + - This accounts for the HuggingFaceVLA/libero camera orientation convention. + """ + + def _process_observation(self, observation): + """ + Processes both image and robot_state observations from LIBERO. + """ + processed_obs = observation.copy() + for key in list(processed_obs.keys()): + if key.startswith(f"{OBS_IMAGES}."): + img = processed_obs[key] + + if key == f"{OBS_IMAGES}.image": + # Flip both H and W + img = torch.flip(img, dims=[2, 3]) + + processed_obs[key] = img + # Process robot_state into a flat state vector + if "observation.robot_state" in processed_obs: + robot_state = processed_obs.pop("observation.robot_state") + + # Extract components + eef_pos = robot_state["eef"]["pos"] # (B, 3,) + eef_mat = robot_state["eef"]["mat"] # (B, 3, 3) + eef_rot6d = self._mat_to_rotate6d(eef_mat) # (B, 6) + + extra = torch.zeros((eef_pos.shape[0], 1), dtype=torch.float32, device=eef_pos.device) + + proprio_state = torch.cat((eef_pos, eef_rot6d, extra), dim=-1) # (B, 10) + state = torch.cat((proprio_state, torch.zeros_like(proprio_state)), dim=-1) # (B, 20) + # ensure float32 + state = state.float() + if state.dim() == 1: + state = state.unsqueeze(0) + + processed_obs[OBS_STATE] = state + return processed_obs + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + Transforms feature keys from the LIBERO format to the LeRobot standard. + """ + new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {} + + # copy over non-STATE features + for ft, feats in features.items(): + if ft != PipelineFeatureType.STATE: + new_features[ft] = feats.copy() + + # rebuild STATE features + state_feats = {} + + # add our new flattened state + state_feats["observation.state"] = PolicyFeature( + key="observation.state", + shape=(20,), + dtype="float32", + ) + + new_features[PipelineFeatureType.STATE] = state_feats + + return new_features + + def _mat_to_rotate6d(self, rot_mats: torch.Tensor) -> torch.Tensor: + """ + Convert batched rotation matrices (B, 3, 3) into 6D rotation representation (B, 6). + + Args: + rot_mats (Tensor): Rotation matrices of shape (B, 3, 3) + + Returns: + Tensor: 6D rotation representation, shape (B, 6) + + Raises: + TypeError: if input is not a torch tensor + ValueError: if shape is not (B, 3, 3) + """ + + if not isinstance(rot_mats, torch.Tensor): + raise TypeError( + f"mat_to_rot6d expects a torch.Tensor, got {type(rot_mats)}" + ) + + if rot_mats.ndim != 3 or rot_mats.shape[1:] != (3, 3): + raise ValueError( + f"mat_to_rot6d expects shape (B, 3, 3), got {tuple(rot_mats.shape)}" + ) + + rot_mats = rot_mats.to(torch.float32) + + col1 = rot_mats[:, :3, 0] # (B, 3) + col2 = rot_mats[:, :3, 1] # (B, 3) + + rot6d = torch.cat([col1, col2], dim=-1) # (B, 6) + + return rot6d + + def observation(self, observation): + return self._process_observation(observation) + @dataclass @@ -254,3 +376,24 @@ class XVLARotation6DToAxisAngleProcessorStep(ProcessorStep): return { "expected_action_dim": self.expected_action_dim, } + +def make_xvla_libero_pre_post_processors( +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Build the LeRobot processor pipelines for XVLA with LIBERO environment. + """ + pre_processor_steps: list[ProcessorStep] = [] + post_processor_steps: list[ProcessorStep] = [] + pre_processor_steps.extend([LiberoProcessorStep(), XVLAImageScaleProcessorStep(), XVLAAddDomainIdProcessorStep()]) + post_processor_steps.extend([XVLARotation6DToAxisAngleProcessorStep()]) + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=pre_processor_steps, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=post_processor_steps, + ), + ) diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index 87d8ae267..d2e4e5405 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -328,7 +328,6 @@ class _NormalizationMixin: if norm_mode == NormalizationMode.IMAGENET: mean = torch.tensor(IMAGENET_STATS["mean"], device=tensor.device, dtype=tensor.dtype) std = torch.tensor(IMAGENET_STATS["std"], device=tensor.device, dtype=tensor.dtype) - # Expand mean/std to match tensor dims (e.g., BCHW or BNCHW) while mean.dim() < tensor.dim(): mean = mean.unsqueeze(0) diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index 4cf9c4095..fed13501f 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -175,11 +175,9 @@ def rollout( with torch.inference_mode(): action = policy.select_action(observation) action = postprocessor(action) - action_transition = {"action": action} action_transition = env_postprocessor(action_transition) action = action_transition["action"] - # Convert to CPU / numpy. action_numpy: np.ndarray = action.to("cpu").numpy() assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)" @@ -533,7 +531,7 @@ def eval_main(cfg: EvalPipelineConfig): ) # Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments) - env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env) + env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env, policy_cfg=cfg.policy) with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(): info = eval_policy_all(