fix bugs in feature shape

This commit is contained in:
Tavish
2025-02-20 20:54:24 +08:00
parent d2456c1506
commit 6e7bdc3e9e
+3 -4
View File
@@ -89,7 +89,7 @@ def generate_features_from_raw(builder: tfds.core.DatasetBuilder, use_videos: bo
if state_encoding == StateEncoding.POS_EULER:
state_names = ["x", "y", "z", "roll", "pitch", "yaw", "pad", "gripper"]
if "libero" in dataset_name:
state_names = ["x", "y", "z", "roll", "pitch", "yaw", "gripper", "gripper"] # 2D gripper state
state_names = ["x", "y", "z", "roll", "pitch", "yaw", "gripper", "gripper"] # 2D gripper state
elif state_encoding == StateEncoding.POS_QUAT:
state_names = ["x", "y", "z", "rx", "ry", "rz", "rw", "gripper"]
elif state_encoding == StateEncoding.JOINT:
@@ -107,16 +107,15 @@ def generate_features_from_raw(builder: tfds.core.DatasetBuilder, use_videos: bo
elif action_encoding == ActionEncoding.JOINT_POS:
action_names = [f"motor_{i}" for i in range(7)] + ["gripper"]
DEFAULT_FEATURES = {
"observation.state": {
"dtype": "float32",
"shape": (8,),
"shape": (len(state_names),),
"names": {"motors": state_names},
},
"action": {
"dtype": "float32",
"shape": (7,),
"shape": (len(action_names),),
"names": {"motors": action_names},
},
}