mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
more fixes
This commit is contained in:
@@ -230,7 +230,14 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
Processes both image and robot_state observations from LIBERO.
|
||||
"""
|
||||
processed_obs = observation.copy()
|
||||
for key in list(processed_obs.keys()):
|
||||
if key.startswith(f"{OBS_IMAGES}."):
|
||||
img = processed_obs[key]
|
||||
|
||||
# Flip both H and W
|
||||
img = torch.flip(img, dims=[2, 3])
|
||||
|
||||
processed_obs[key] = img
|
||||
# Process robot_state into a flat state vector
|
||||
if "observation.robot_state" in processed_obs:
|
||||
robot_state = processed_obs.pop("observation.robot_state")
|
||||
@@ -241,8 +248,7 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
gripper_qpos = robot_state["gripper"]["qpos"] # (2,)
|
||||
|
||||
# Convert quaternion to axis-angle
|
||||
eef_axisangle = self._quat2axisangle(eef_quat.squeeze(0)) # (3,)
|
||||
eef_axisangle = eef_axisangle[np.newaxis, :] # (1, 3)
|
||||
eef_axisangle = self._quat2axisangle(eef_quat) # (B, 3)
|
||||
|
||||
# Concatenate into a single state vector
|
||||
state = np.concatenate((eef_pos, eef_axisangle, gripper_qpos), axis=1)
|
||||
@@ -269,27 +275,40 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
|
||||
def _quat2axisangle(self, quat):
|
||||
"""
|
||||
# Copied from robosuite.utils.transform_utils.quat2axisangle
|
||||
Converts quaternion to axis-angle format.
|
||||
Converts quaternion to axis-angle format (vectorized for batches).
|
||||
Returns a unit vector direction scaled by its angle in radians.
|
||||
|
||||
Args:
|
||||
quat (np.array): (x,y,z,w) vec4 float angles
|
||||
quat (np.array): (B, 4) or (4,) array of quaternions in (x,y,z,w) format
|
||||
|
||||
Returns:
|
||||
np.array: (ax,ay,az) axis-angle exponential coordinates
|
||||
np.array: (B, 3) or (3,) axis-angle exponential coordinates
|
||||
"""
|
||||
import math
|
||||
# Handle both batched and single quaternion inputs
|
||||
if quat.ndim == 1:
|
||||
quat = quat[np.newaxis, :] # (1, 4)
|
||||
single_input = True
|
||||
else:
|
||||
single_input = False
|
||||
|
||||
# clip quaternion
|
||||
if quat[3] > 1.0:
|
||||
quat[3] = 1.0
|
||||
elif quat[3] < -1.0:
|
||||
quat[3] = -1.0
|
||||
# clip quaternion w component to [-1, 1]
|
||||
quat = quat.copy()
|
||||
quat[:, 3] = np.clip(quat[:, 3], -1.0, 1.0)
|
||||
|
||||
den = np.sqrt(1.0 - quat[3] * quat[3])
|
||||
if math.isclose(den, 0.0):
|
||||
# This is (close to) a zero degree rotation, immediately return
|
||||
return np.zeros(3)
|
||||
# compute denominator - sqrt(1 - w^2)
|
||||
den = np.sqrt(1.0 - quat[:, 3] ** 2)
|
||||
|
||||
return (quat[:3] * 2.0 * math.acos(quat[3])) / den
|
||||
# for near-zero rotations, return zeros
|
||||
result = np.zeros((quat.shape[0], 3))
|
||||
|
||||
# only compute for non-zero rotations
|
||||
non_zero_mask = den > 1e-10
|
||||
if np.any(non_zero_mask):
|
||||
result[non_zero_mask] = (
|
||||
quat[non_zero_mask, :3]
|
||||
* (2.0 * np.arccos(quat[non_zero_mask, 3]) / den[non_zero_mask])[:, np.newaxis]
|
||||
)
|
||||
|
||||
if single_input:
|
||||
return result[0]
|
||||
return result
|
||||
|
||||
@@ -22,31 +22,31 @@ import torch
|
||||
seed = 42
|
||||
np.random.seed(seed)
|
||||
|
||||
B = 5
|
||||
obs1 = {
|
||||
"pixels": {
|
||||
"image": (np.random.rand(1, 256, 256, 3) * 255).astype(np.uint8),
|
||||
"image2": (np.random.rand(1, 256, 256, 3) * 255).astype(np.uint8),
|
||||
"image": (np.random.rand(B, 256, 256, 3) * 255).astype(np.uint8),
|
||||
"image2": (np.random.rand(B, 256, 256, 3) * 255).astype(np.uint8),
|
||||
},
|
||||
"robot_state": {
|
||||
"eef": {
|
||||
"pos": np.random.randn(1, 3),
|
||||
"quat": np.random.randn(1, 4),
|
||||
"mat": np.random.randn(1, 3, 3),
|
||||
"axisangle": np.random.randn(1, 3),
|
||||
"pos": np.random.randn(B, 3),
|
||||
"quat": np.random.randn(B, 4),
|
||||
"mat": np.random.randn(B, 3, 3),
|
||||
"axisangle": np.random.randn(B, 3),
|
||||
},
|
||||
"gripper": {
|
||||
"qpos": np.random.randn(1, 2),
|
||||
"qvel": np.random.randn(1, 2),
|
||||
"qpos": np.random.randn(B, 2),
|
||||
"qvel": np.random.randn(B, 2),
|
||||
},
|
||||
"joints": {
|
||||
"pos": np.random.randn(1, 7),
|
||||
"vel": np.random.randn(1, 7),
|
||||
"pos": np.random.randn(B, 7),
|
||||
"vel": np.random.randn(B, 7),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
observation = preprocess_observation(obs1)
|
||||
|
||||
libero_preprocessor = PolicyProcessorPipeline(
|
||||
steps=[
|
||||
LiberoProcessorStep(),
|
||||
@@ -58,7 +58,7 @@ state = processed_obs["observation.state"]
|
||||
assert isinstance(state, torch.Tensor)
|
||||
assert state.dtype == torch.float32
|
||||
|
||||
assert state.shape[0] == 1
|
||||
assert state.shape[0] == B
|
||||
assert state.shape[1] == 8
|
||||
|
||||
assert "observation.images.image" in processed_obs
|
||||
@@ -67,5 +67,5 @@ assert "observation.images.image2" in processed_obs
|
||||
assert isinstance(processed_obs["observation.images.image"], torch.Tensor)
|
||||
assert isinstance(processed_obs["observation.images.image2"], torch.Tensor)
|
||||
|
||||
assert processed_obs["observation.images.image"].shape == (1, 3, 256, 256)
|
||||
assert processed_obs["observation.images.image2"].shape == (1, 3, 256, 256)
|
||||
assert processed_obs["observation.images.image"].shape == (B, 3, 256, 256)
|
||||
assert processed_obs["observation.images.image2"].shape == (B, 3, 256, 256)
|
||||
|
||||
Reference in New Issue
Block a user