Batch support for LiberoProcessorStep.

  * Image observations (keys under OBS_IMAGES) are flipped along their last
    two axes (H and W).
  * _quat2axisangle converts quaternions of shape (..., 4) to axis-angle
    vectors of shape (..., 3) without modifying its input.
  * The processor test runs with a batch of B = 5 observations.

diff --git a/src/lerobot/processor/observation_processor.py b/src/lerobot/processor/observation_processor.py
index ed776f37c..982bf9bf8 100644
--- a/src/lerobot/processor/observation_processor.py
+++ b/src/lerobot/processor/observation_processor.py
@@ -230,7 +230,14 @@ class LiberoProcessorStep(ObservationProcessorStep):
         Processes both image and robot_state observations from LIBERO.
         """
         processed_obs = observation.copy()
+        for key in list(processed_obs.keys()):
+            if key.startswith(f"{OBS_IMAGES}."):
+                img = processed_obs[key]
 
+                # Flip both H and W (the last two axes, whatever the leading dims are)
+                img = torch.flip(img, dims=[-2, -1])
+
+                processed_obs[key] = img
         # Process robot_state into a flat state vector
         if "observation.robot_state" in processed_obs:
             robot_state = processed_obs.pop("observation.robot_state")
@@ -241,8 +248,7 @@ class LiberoProcessorStep(ObservationProcessorStep):
             gripper_qpos = robot_state["gripper"]["qpos"]  # (2,)
 
             # Convert quaternion to axis-angle
-            eef_axisangle = self._quat2axisangle(eef_quat.squeeze(0))  # (3,)
-            eef_axisangle = eef_axisangle[np.newaxis, :]  # (1, 3)
+            eef_axisangle = self._quat2axisangle(eef_quat)  # (B, 3)
 
             # Concatenate into a single state vector
             state = np.concatenate((eef_pos, eef_axisangle, gripper_qpos), axis=1)
@@ -269,27 +275,31 @@ class LiberoProcessorStep(ObservationProcessorStep):
 
     def _quat2axisangle(self, quat):
         """
-        # Copied from robosuite.utils.transform_utils.quat2axisangle
-        Converts quaternion to axis-angle format.
+        Converts quaternion to axis-angle format (vectorized for batches).
+        Adapted from robosuite.utils.transform_utils.quat2axisangle.
         Returns a unit vector direction scaled by its angle in radians.
 
         Args:
-            quat (np.array): (x,y,z,w) vec4 float angles
+            quat (np.array): (..., 4) array of quaternions in (x,y,z,w) format
 
         Returns:
-            np.array: (ax,ay,az) axis-angle exponential coordinates
+            np.array: (..., 3) axis-angle exponential coordinates (float64)
+
+        Raises:
+            ValueError: if the last axis of quat does not have length 4
         """
-        import math
+        # Work on a float copy so the caller's array is never modified
+        quat = np.array(quat, dtype=np.float64)
+        if quat.ndim == 0 or quat.shape[-1] != 4:
+            raise ValueError(f"Expected quaternions of shape (..., 4), got {quat.shape}")
 
-        # clip quaternion
-        if quat[3] > 1.0:
-            quat[3] = 1.0
-        elif quat[3] < -1.0:
-            quat[3] = -1.0
+        # clip quaternion w component to [-1, 1]; keep the last axis for broadcasting
+        w = np.clip(quat[..., 3:4], -1.0, 1.0)
 
-        den = np.sqrt(1.0 - quat[3] * quat[3])
-        if math.isclose(den, 0.0):
-            # This is (close to) a zero degree rotation, immediately return
-            return np.zeros(3)
+        # compute denominator - sqrt(1 - w^2)
+        den = np.sqrt(1.0 - w**2)
 
-        return (quat[:3] * 2.0 * math.acos(quat[3])) / den
+        # for near-zero rotations, return zeros (the divisor is replaced by 1 there)
+        non_zero_mask = den > 1e-10
+        safe_den = np.where(non_zero_mask, den, 1.0)
+        return np.where(non_zero_mask, quat[..., :3] * (2.0 * np.arccos(w) / safe_den), 0.0)
diff --git a/tests/processor/test_libero_processor.py b/tests/processor/test_libero_processor.py
index 7717490b1..fa28e850d 100644
--- a/tests/processor/test_libero_processor.py
+++ b/tests/processor/test_libero_processor.py
@@ -22,31 +22,31 @@ import torch
 seed = 42
 np.random.seed(seed)
 
+B = 5
 obs1 = {
     "pixels": {
-        "image": (np.random.rand(1, 256, 256, 3) * 255).astype(np.uint8),
-        "image2": (np.random.rand(1, 256, 256, 3) * 255).astype(np.uint8),
+        "image": (np.random.rand(B, 256, 256, 3) * 255).astype(np.uint8),
+        "image2": (np.random.rand(B, 256, 256, 3) * 255).astype(np.uint8),
     },
     "robot_state": {
         "eef": {
-            "pos": np.random.randn(1, 3),
-            "quat": np.random.randn(1, 4),
-            "mat": np.random.randn(1, 3, 3),
-            "axisangle": np.random.randn(1, 3),
+            "pos": np.random.randn(B, 3),
+            "quat": np.random.randn(B, 4),
+            "mat": np.random.randn(B, 3, 3),
+            "axisangle": np.random.randn(B, 3),
         },
         "gripper": {
-            "qpos": np.random.randn(1, 2),
-            "qvel": np.random.randn(1, 2),
+            "qpos": np.random.randn(B, 2),
+            "qvel": np.random.randn(B, 2),
         },
         "joints": {
-            "pos": np.random.randn(1, 7),
-            "vel": np.random.randn(1, 7),
+            "pos": np.random.randn(B, 7),
+            "vel": np.random.randn(B, 7),
         }
     }
 }
 
 observation = preprocess_observation(obs1)
-
 libero_preprocessor = PolicyProcessorPipeline(
     steps=[
         LiberoProcessorStep(),
@@ -58,7 +58,7 @@ state = processed_obs["observation.state"]
 assert isinstance(state, torch.Tensor)
 assert state.dtype == torch.float32
 
-assert state.shape[0] == 1
+assert state.shape[0] == B
 assert state.shape[1] == 8
 
 assert "observation.images.image" in processed_obs
@@ -67,5 +67,5 @@ assert "observation.images.image2" in processed_obs
 assert isinstance(processed_obs["observation.images.image"], torch.Tensor)
 assert isinstance(processed_obs["observation.images.image2"], torch.Tensor)
 
-assert processed_obs["observation.images.image"].shape == (1, 3, 256, 256)
-assert processed_obs["observation.images.image2"].shape == (1, 3, 256, 256)
+assert processed_obs["observation.images.image"].shape == (B, 3, 256, 256)
+assert processed_obs["observation.images.image2"].shape == (B, 3, 256, 256)