mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
more fixes
This commit is contained in:
@@ -230,7 +230,14 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
|||||||
Processes both image and robot_state observations from LIBERO.
|
Processes both image and robot_state observations from LIBERO.
|
||||||
"""
|
"""
|
||||||
processed_obs = observation.copy()
|
processed_obs = observation.copy()
|
||||||
|
for key in list(processed_obs.keys()):
|
||||||
|
if key.startswith(f"{OBS_IMAGES}."):
|
||||||
|
img = processed_obs[key]
|
||||||
|
|
||||||
|
# Flip both H and W
|
||||||
|
img = torch.flip(img, dims=[2, 3])
|
||||||
|
|
||||||
|
processed_obs[key] = img
|
||||||
# Process robot_state into a flat state vector
|
# Process robot_state into a flat state vector
|
||||||
if "observation.robot_state" in processed_obs:
|
if "observation.robot_state" in processed_obs:
|
||||||
robot_state = processed_obs.pop("observation.robot_state")
|
robot_state = processed_obs.pop("observation.robot_state")
|
||||||
@@ -241,8 +248,7 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
|||||||
gripper_qpos = robot_state["gripper"]["qpos"] # (2,)
|
gripper_qpos = robot_state["gripper"]["qpos"] # (2,)
|
||||||
|
|
||||||
# Convert quaternion to axis-angle
|
# Convert quaternion to axis-angle
|
||||||
eef_axisangle = self._quat2axisangle(eef_quat.squeeze(0)) # (3,)
|
eef_axisangle = self._quat2axisangle(eef_quat) # (B, 3)
|
||||||
eef_axisangle = eef_axisangle[np.newaxis, :] # (1, 3)
|
|
||||||
|
|
||||||
# Concatenate into a single state vector
|
# Concatenate into a single state vector
|
||||||
state = np.concatenate((eef_pos, eef_axisangle, gripper_qpos), axis=1)
|
state = np.concatenate((eef_pos, eef_axisangle, gripper_qpos), axis=1)
|
||||||
@@ -269,27 +275,40 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
|||||||
|
|
||||||
def _quat2axisangle(self, quat):
|
def _quat2axisangle(self, quat):
|
||||||
"""
|
"""
|
||||||
# Copied from robosuite.utils.transform_utils.quat2axisangle
|
Converts quaternion to axis-angle format (vectorized for batches).
|
||||||
Converts quaternion to axis-angle format.
|
|
||||||
Returns a unit vector direction scaled by its angle in radians.
|
Returns a unit vector direction scaled by its angle in radians.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
quat (np.array): (x,y,z,w) vec4 float angles
|
quat (np.array): (B, 4) or (4,) array of quaternions in (x,y,z,w) format
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
np.array: (ax,ay,az) axis-angle exponential coordinates
|
np.array: (B, 3) or (3,) axis-angle exponential coordinates
|
||||||
"""
|
"""
|
||||||
import math
|
# Handle both batched and single quaternion inputs
|
||||||
|
if quat.ndim == 1:
|
||||||
|
quat = quat[np.newaxis, :] # (1, 4)
|
||||||
|
single_input = True
|
||||||
|
else:
|
||||||
|
single_input = False
|
||||||
|
|
||||||
# clip quaternion
|
# clip quaternion w component to [-1, 1]
|
||||||
if quat[3] > 1.0:
|
quat = quat.copy()
|
||||||
quat[3] = 1.0
|
quat[:, 3] = np.clip(quat[:, 3], -1.0, 1.0)
|
||||||
elif quat[3] < -1.0:
|
|
||||||
quat[3] = -1.0
|
|
||||||
|
|
||||||
den = np.sqrt(1.0 - quat[3] * quat[3])
|
# compute denominator - sqrt(1 - w^2)
|
||||||
if math.isclose(den, 0.0):
|
den = np.sqrt(1.0 - quat[:, 3] ** 2)
|
||||||
# This is (close to) a zero degree rotation, immediately return
|
|
||||||
return np.zeros(3)
|
|
||||||
|
|
||||||
return (quat[:3] * 2.0 * math.acos(quat[3])) / den
|
# for near-zero rotations, return zeros
|
||||||
|
result = np.zeros((quat.shape[0], 3))
|
||||||
|
|
||||||
|
# only compute for non-zero rotations
|
||||||
|
non_zero_mask = den > 1e-10
|
||||||
|
if np.any(non_zero_mask):
|
||||||
|
result[non_zero_mask] = (
|
||||||
|
quat[non_zero_mask, :3]
|
||||||
|
* (2.0 * np.arccos(quat[non_zero_mask, 3]) / den[non_zero_mask])[:, np.newaxis]
|
||||||
|
)
|
||||||
|
|
||||||
|
if single_input:
|
||||||
|
return result[0]
|
||||||
|
return result
|
||||||
|
|||||||
@@ -22,31 +22,31 @@ import torch
|
|||||||
seed = 42
|
seed = 42
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
|
|
||||||
|
B = 5
|
||||||
obs1 = {
|
obs1 = {
|
||||||
"pixels": {
|
"pixels": {
|
||||||
"image": (np.random.rand(1, 256, 256, 3) * 255).astype(np.uint8),
|
"image": (np.random.rand(B, 256, 256, 3) * 255).astype(np.uint8),
|
||||||
"image2": (np.random.rand(1, 256, 256, 3) * 255).astype(np.uint8),
|
"image2": (np.random.rand(B, 256, 256, 3) * 255).astype(np.uint8),
|
||||||
},
|
},
|
||||||
"robot_state": {
|
"robot_state": {
|
||||||
"eef": {
|
"eef": {
|
||||||
"pos": np.random.randn(1, 3),
|
"pos": np.random.randn(B, 3),
|
||||||
"quat": np.random.randn(1, 4),
|
"quat": np.random.randn(B, 4),
|
||||||
"mat": np.random.randn(1, 3, 3),
|
"mat": np.random.randn(B, 3, 3),
|
||||||
"axisangle": np.random.randn(1, 3),
|
"axisangle": np.random.randn(B, 3),
|
||||||
},
|
},
|
||||||
"gripper": {
|
"gripper": {
|
||||||
"qpos": np.random.randn(1, 2),
|
"qpos": np.random.randn(B, 2),
|
||||||
"qvel": np.random.randn(1, 2),
|
"qvel": np.random.randn(B, 2),
|
||||||
},
|
},
|
||||||
"joints": {
|
"joints": {
|
||||||
"pos": np.random.randn(1, 7),
|
"pos": np.random.randn(B, 7),
|
||||||
"vel": np.random.randn(1, 7),
|
"vel": np.random.randn(B, 7),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
observation = preprocess_observation(obs1)
|
observation = preprocess_observation(obs1)
|
||||||
|
|
||||||
libero_preprocessor = PolicyProcessorPipeline(
|
libero_preprocessor = PolicyProcessorPipeline(
|
||||||
steps=[
|
steps=[
|
||||||
LiberoProcessorStep(),
|
LiberoProcessorStep(),
|
||||||
@@ -58,7 +58,7 @@ state = processed_obs["observation.state"]
|
|||||||
assert isinstance(state, torch.Tensor)
|
assert isinstance(state, torch.Tensor)
|
||||||
assert state.dtype == torch.float32
|
assert state.dtype == torch.float32
|
||||||
|
|
||||||
assert state.shape[0] == 1
|
assert state.shape[0] == B
|
||||||
assert state.shape[1] == 8
|
assert state.shape[1] == 8
|
||||||
|
|
||||||
assert "observation.images.image" in processed_obs
|
assert "observation.images.image" in processed_obs
|
||||||
@@ -67,5 +67,5 @@ assert "observation.images.image2" in processed_obs
|
|||||||
assert isinstance(processed_obs["observation.images.image"], torch.Tensor)
|
assert isinstance(processed_obs["observation.images.image"], torch.Tensor)
|
||||||
assert isinstance(processed_obs["observation.images.image2"], torch.Tensor)
|
assert isinstance(processed_obs["observation.images.image2"], torch.Tensor)
|
||||||
|
|
||||||
assert processed_obs["observation.images.image"].shape == (1, 3, 256, 256)
|
assert processed_obs["observation.images.image"].shape == (B, 3, 256, 256)
|
||||||
assert processed_obs["observation.images.image2"].shape == (1, 3, 256, 256)
|
assert processed_obs["observation.images.image2"].shape == (B, 3, 256, 256)
|
||||||
|
|||||||
Reference in New Issue
Block a user