more fixes

jade.choghari@huggingface.co
2025-11-18 14:24:59 +01:00
parent 9a115c303c
commit b4b5d057b1
2 changed files with 50 additions and 31 deletions
+36 -17
@@ -230,7 +230,14 @@ class LiberoProcessorStep(ObservationProcessorStep):
         Processes both image and robot_state observations from LIBERO.
         """
         processed_obs = observation.copy()
+        for key in list(processed_obs.keys()):
+            if key.startswith(f"{OBS_IMAGES}."):
+                img = processed_obs[key]
+                # Flip both H and W
+                img = torch.flip(img, dims=[2, 3])
+                processed_obs[key] = img
         # Process robot_state into a flat state vector
         if "observation.robot_state" in processed_obs:
             robot_state = processed_obs.pop("observation.robot_state")
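Not part of the diff, but for context: a minimal sketch of what the added flip does, assuming the image tensors coming out of preprocess_observation are already channel-first (B, C, H, W).

import torch

# a small (B, C, H, W) batch; distinct values make the flip easy to inspect
img = torch.arange(2 * 3 * 4 * 5, dtype=torch.float32).reshape(2, 3, 4, 5)

# dims=[2, 3] flips height and width together, i.e. rotates each image plane by 180 degrees
flipped = torch.flip(img, dims=[2, 3])

# the former bottom-right pixel is now at the top-left
assert flipped[0, 0, 0, 0] == img[0, 0, -1, -1]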
@@ -241,8 +248,7 @@ class LiberoProcessorStep(ObservationProcessorStep):
             gripper_qpos = robot_state["gripper"]["qpos"]  # (2,)
             # Convert quaternion to axis-angle
-            eef_axisangle = self._quat2axisangle(eef_quat.squeeze(0))  # (3,)
-            eef_axisangle = eef_axisangle[np.newaxis, :]  # (1, 3)
+            eef_axisangle = self._quat2axisangle(eef_quat)  # (B, 3)
             # Concatenate into a single state vector
             state = np.concatenate((eef_pos, eef_axisangle, gripper_qpos), axis=1)
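A quick shape check of the concatenation in this hunk (outside the diff), assuming the per-sample layout visible in the surrounding code:

import numpy as np

B = 5
eef_pos = np.zeros((B, 3))        # end-effector position
eef_axisangle = np.zeros((B, 3))  # output of _quat2axisangle
gripper_qpos = np.zeros((B, 2))   # gripper joint positions

state = np.concatenate((eef_pos, eef_axisangle, gripper_qpos), axis=1)
assert state.shape == (B, 8)  # 3 + 3 + 2, matching the test's state.shape[1] == 8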
@@ -269,27 +275,40 @@ class LiberoProcessorStep(ObservationProcessorStep):
     def _quat2axisangle(self, quat):
         """
         # Copied from robosuite.utils.transform_utils.quat2axisangle
-        Converts quaternion to axis-angle format.
+        Converts quaternion to axis-angle format (vectorized for batches).
         Returns a unit vector direction scaled by its angle in radians.
         Args:
-            quat (np.array): (x,y,z,w) vec4 float angles
+            quat (np.array): (B, 4) or (4,) array of quaternions in (x,y,z,w) format
         Returns:
-            np.array: (ax,ay,az) axis-angle exponential coordinates
+            np.array: (B, 3) or (3,) axis-angle exponential coordinates
         """
-        import math
+        # Handle both batched and single quaternion inputs
+        if quat.ndim == 1:
+            quat = quat[np.newaxis, :]  # (1, 4)
+            single_input = True
+        else:
+            single_input = False
-        # clip quaternion
-        if quat[3] > 1.0:
-            quat[3] = 1.0
-        elif quat[3] < -1.0:
-            quat[3] = -1.0
+        # clip quaternion w component to [-1, 1]
+        quat = quat.copy()
+        quat[:, 3] = np.clip(quat[:, 3], -1.0, 1.0)
-        den = np.sqrt(1.0 - quat[3] * quat[3])
-        if math.isclose(den, 0.0):
-            # This is (close to) a zero degree rotation, immediately return
-            return np.zeros(3)
+        # compute denominator - sqrt(1 - w^2)
+        den = np.sqrt(1.0 - quat[:, 3] ** 2)
-        return (quat[:3] * 2.0 * math.acos(quat[3])) / den
+        # for near-zero rotations, return zeros
+        result = np.zeros((quat.shape[0], 3))
+        # only compute for non-zero rotations
+        non_zero_mask = den > 1e-10
+        if np.any(non_zero_mask):
+            result[non_zero_mask] = (
+                quat[non_zero_mask, :3]
+                * (2.0 * np.arccos(quat[non_zero_mask, 3]) / den[non_zero_mask])[:, np.newaxis]
+            )
+        if single_input:
+            return result[0]
+        return result
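One way to sanity-check the vectorized rewrite (not part of the commit): mirror the batched math in a standalone function and compare it against the original per-sample formula on random unit quaternions. The helper names below are illustrative only.

import math
import numpy as np

def quat2axisangle_batched(quats):
    # Same math as the vectorized method above, for (B, 4) inputs only
    quats = quats.copy()
    quats[:, 3] = np.clip(quats[:, 3], -1.0, 1.0)
    den = np.sqrt(1.0 - quats[:, 3] ** 2)
    result = np.zeros((quats.shape[0], 3))
    mask = den > 1e-10
    result[mask] = quats[mask, :3] * (2.0 * np.arccos(quats[mask, 3]) / den[mask])[:, np.newaxis]
    return result

def quat2axisangle_single(quat):
    # Per-sample reference following the pre-diff scalar formula
    w = min(1.0, max(-1.0, quat[3]))
    den = math.sqrt(1.0 - w * w)
    if math.isclose(den, 0.0):
        return np.zeros(3)
    return quat[:3] * 2.0 * math.acos(w) / den

rng = np.random.default_rng(0)
quats = rng.normal(size=(5, 4))
quats /= np.linalg.norm(quats, axis=1, keepdims=True)  # normalize to unit quaternions

assert np.allclose(quat2axisangle_batched(quats), np.stack([quat2axisangle_single(q) for q in quats]))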
+14 -14
@@ -22,31 +22,31 @@ import torch
 seed = 42
 np.random.seed(seed)
+B = 5
 obs1 = {
     "pixels": {
-        "image": (np.random.rand(1, 256, 256, 3) * 255).astype(np.uint8),
-        "image2": (np.random.rand(1, 256, 256, 3) * 255).astype(np.uint8),
+        "image": (np.random.rand(B, 256, 256, 3) * 255).astype(np.uint8),
+        "image2": (np.random.rand(B, 256, 256, 3) * 255).astype(np.uint8),
     },
     "robot_state": {
         "eef": {
-            "pos": np.random.randn(1, 3),
-            "quat": np.random.randn(1, 4),
-            "mat": np.random.randn(1, 3, 3),
-            "axisangle": np.random.randn(1, 3),
+            "pos": np.random.randn(B, 3),
+            "quat": np.random.randn(B, 4),
+            "mat": np.random.randn(B, 3, 3),
+            "axisangle": np.random.randn(B, 3),
         },
         "gripper": {
-            "qpos": np.random.randn(1, 2),
-            "qvel": np.random.randn(1, 2),
+            "qpos": np.random.randn(B, 2),
+            "qvel": np.random.randn(B, 2),
         },
         "joints": {
-            "pos": np.random.randn(1, 7),
-            "vel": np.random.randn(1, 7),
+            "pos": np.random.randn(B, 7),
+            "vel": np.random.randn(B, 7),
         }
     }
 }
 observation = preprocess_observation(obs1)
 libero_preprocessor = PolicyProcessorPipeline(
     steps=[
         LiberoProcessorStep(),
@@ -58,7 +58,7 @@ state = processed_obs["observation.state"]
 assert isinstance(state, torch.Tensor)
 assert state.dtype == torch.float32
-assert state.shape[0] == 1
+assert state.shape[0] == B
 assert state.shape[1] == 8
 assert "observation.images.image" in processed_obs
@@ -67,5 +67,5 @@ assert "observation.images.image2" in processed_obs
 assert isinstance(processed_obs["observation.images.image"], torch.Tensor)
 assert isinstance(processed_obs["observation.images.image2"], torch.Tensor)
-assert processed_obs["observation.images.image"].shape == (1, 3, 256, 256)
-assert processed_obs["observation.images.image2"].shape == (1, 3, 256, 256)
+assert processed_obs["observation.images.image"].shape == (B, 3, 256, 256)
+assert processed_obs["observation.images.image2"].shape == (B, 3, 256, 256)
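As a reading aid (not in the diff): the shapes the updated test expects once the pipeline has run on the B = 5 batch, limited to the keys asserted above.

expected_shapes = {
    "observation.state": (5, 8),                  # eef pos (3) + axis-angle (3) + gripper qpos (2)
    "observation.images.image": (5, 3, 256, 256),
    "observation.images.image2": (5, 3, 256, 256),
}
for key, shape in expected_shapes.items():
    assert tuple(processed_obs[key].shape) == shape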