fix kuka_dataset_transform & filter

This commit is contained in:
Tavish
2025-02-21 19:51:27 +08:00
parent a96845809b
commit 43fe83852c
2 changed files with 6 additions and 14 deletions
-13
View File
@@ -192,19 +192,6 @@ def kuka_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
),
axis=-1,
)
# decode compressed state
eef_value = tf.io.decode_compressed(
trajectory["observation"]["clip_function_input/base_pose_tool_reached"],
compression_type="ZLIB",
)
eef_value = tf.io.decode_raw(eef_value, tf.float32)
trajectory["observation"]["clip_function_input/base_pose_tool_reached"] = tf.reshape(eef_value, (-1, 7))
gripper_value = tf.io.decode_compressed(trajectory["observation"]["gripper_closed"], compression_type="ZLIB")
gripper_value = tf.io.decode_raw(gripper_value, tf.float32)
trajectory["observation"]["gripper_closed"] = tf.reshape(gripper_value, (-1, 1))
# trajectory["language_instruction"] = tf.fill(
# tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
# ) # delete uninformative language instruction
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory