📝 update libero state & action notation (#71)

This commit is contained in:
Qizhi Chen
2025-10-29 16:51:15 +08:00
committed by GitHub
parent e2db7df495
commit 940120d7ed
2 changed files with 15 additions and 4 deletions
+3 -3
View File
@@ -12,12 +12,12 @@ LIBERO_FEATURES = {
"observation.state": {
"dtype": "float32",
"shape": (8,),
"names": {"motors": ["x", "y", "z", "roll", "pitch", "yaw", "gripper", "gripper"]},
"names": {"motors": ["x", "y", "z", "axis_angle1", "axis_angle2", "axis_angle3", "gripper", "gripper"]},
},
"observation.states.ee_state": {
"dtype": "float32",
"shape": (6,),
"names": {"motors": ["x", "y", "z", "roll", "pitch", "yaw"]},
"names": {"motors": ["x", "y", "z", "axis_angle1", "axis_angle2", "axis_angle3"]},
},
"observation.states.joint_state": {
"dtype": "float32",
@@ -32,6 +32,6 @@ LIBERO_FEATURES = {
"action": {
"dtype": "float32",
"shape": (7,),
"names": {"motors": ["x", "y", "z", "roll", "pitch", "yaw", "gripper"]},
"names": {"motors": ["x", "y", "z", "axis_angle1", "axis_angle2", "axis_angle3", "gripper"]},
},
}
+12 -1
View File
@@ -89,7 +89,16 @@ def generate_features_from_raw(builder: tfds.core.DatasetBuilder, use_videos: bo
if state_encoding == StateEncoding.POS_EULER:
state_names = ["x", "y", "z", "roll", "pitch", "yaw", "pad", "gripper"]
if "libero" in dataset_name:
state_names = ["x", "y", "z", "roll", "pitch", "yaw", "gripper", "gripper"] # 2D gripper state
state_names = [
"x",
"y",
"z",
"axis_angle1",
"axis_angle2",
"axis_angle3",
"gripper",
"gripper",
] # 2D gripper state
elif state_encoding == StateEncoding.POS_QUAT:
state_names = ["x", "y", "z", "rx", "ry", "rz", "rw", "gripper"]
elif state_encoding == StateEncoding.JOINT:
@@ -104,6 +113,8 @@ def generate_features_from_raw(builder: tfds.core.DatasetBuilder, use_videos: bo
action_encoding = OXE_DATASET_CONFIGS[dataset_name]["action_encoding"]
if action_encoding == ActionEncoding.EEF_POS:
action_names = ["x", "y", "z", "roll", "pitch", "yaw", "gripper"]
if "libero" in dataset_name:
action_names = ["x", "y", "z", "axis_angle1", "axis_angle2", "axis_angle3", "gripper"]
elif action_encoding == ActionEncoding.JOINT_POS:
action_names = [f"motor_{i}" for i in range(7)] + ["gripper"]