diff --git a/openx_rlds.py b/openx_rlds.py index 96ff994..24346f9 100644 --- a/openx_rlds.py +++ b/openx_rlds.py @@ -89,7 +89,7 @@ def generate_features_from_raw(builder: tfds.core.DatasetBuilder, use_videos: bo if state_encoding == StateEncoding.POS_EULER: state_names = ["x", "y", "z", "roll", "pitch", "yaw", "pad", "gripper"] if "libero" in dataset_name: - state_names = ["x", "y", "z", "roll", "pitch", "yaw", "gripper", "gripper"] # 2D gripper state + state_names = ["x", "y", "z", "roll", "pitch", "yaw", "gripper", "gripper"] # 2D gripper state elif state_encoding == StateEncoding.POS_QUAT: state_names = ["x", "y", "z", "rx", "ry", "rz", "rw", "gripper"] elif state_encoding == StateEncoding.JOINT: @@ -107,16 +107,15 @@ def generate_features_from_raw(builder: tfds.core.DatasetBuilder, use_videos: bo elif action_encoding == ActionEncoding.JOINT_POS: action_names = [f"motor_{i}" for i in range(7)] + ["gripper"] - DEFAULT_FEATURES = { "observation.state": { "dtype": "float32", - "shape": (8,), + "shape": (len(state_names),), "names": {"motors": state_names}, }, "action": { "dtype": "float32", - "shape": (7,), + "shape": (len(action_names),), "names": {"motors": action_names}, }, }