From 43fe83852cde3b4c7ec38761eb5c208f3b5c9da7 Mon Sep 17 00:00:00 2001 From: Tavish Date: Fri, 21 Feb 2025 19:51:27 +0800 Subject: [PATCH] fix kuka_dataset_transform & filter --- openx_rlds.py | 7 ++++++- oxe_utils/transforms.py | 13 ------------- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/openx_rlds.py b/openx_rlds.py index cf873a8..060a714 100644 --- a/openx_rlds.py +++ b/openx_rlds.py @@ -190,7 +190,12 @@ def create_lerobot_dataset( builder = tfds.builder(dataset_name, data_dir=data_dir, version=version) features = generate_features_from_raw(builder, use_videos) - raw_dataset = builder.as_dataset(split="train").map(partial(transform_raw_dataset, dataset_name=dataset_name)) + filter_fn = lambda e: e["success"] if dataset_name == "kuka" else True + raw_dataset = ( + builder.as_dataset(split="train") + .filter(filter_fn) + .map(partial(transform_raw_dataset, dataset_name=dataset_name)) + ) if fps is None: if dataset_name in OXE_DATASET_CONFIGS: diff --git a/oxe_utils/transforms.py b/oxe_utils/transforms.py index c845fc6..31a09a6 100644 --- a/oxe_utils/transforms.py +++ b/oxe_utils/transforms.py @@ -192,19 +192,6 @@ def kuka_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: ), axis=-1, ) - # decode compressed state - eef_value = tf.io.decode_compressed( - trajectory["observation"]["clip_function_input/base_pose_tool_reached"], - compression_type="ZLIB", - ) - eef_value = tf.io.decode_raw(eef_value, tf.float32) - trajectory["observation"]["clip_function_input/base_pose_tool_reached"] = tf.reshape(eef_value, (-1, 7)) - gripper_value = tf.io.decode_compressed(trajectory["observation"]["gripper_closed"], compression_type="ZLIB") - gripper_value = tf.io.decode_raw(gripper_value, tf.float32) - trajectory["observation"]["gripper_closed"] = tf.reshape(gripper_value, (-1, 1)) - # trajectory["language_instruction"] = tf.fill( - # tf.shape(trajectory["observation"]["natural_language_instruction"]), "" - # ) # delete uninformative language instruction trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] return trajectory