diff --git a/openx_rlds.py b/openx_rlds.py index cf873a8..060a714 100644 --- a/openx_rlds.py +++ b/openx_rlds.py @@ -190,7 +190,12 @@ def create_lerobot_dataset( builder = tfds.builder(dataset_name, data_dir=data_dir, version=version) features = generate_features_from_raw(builder, use_videos) - raw_dataset = builder.as_dataset(split="train").map(partial(transform_raw_dataset, dataset_name=dataset_name)) + filter_fn = lambda e: e["success"] if dataset_name == "kuka" else True + raw_dataset = ( + builder.as_dataset(split="train") + .filter(filter_fn) + .map(partial(transform_raw_dataset, dataset_name=dataset_name)) + ) if fps is None: if dataset_name in OXE_DATASET_CONFIGS: diff --git a/oxe_utils/transforms.py b/oxe_utils/transforms.py index c845fc6..31a09a6 100644 --- a/oxe_utils/transforms.py +++ b/oxe_utils/transforms.py @@ -192,19 +192,6 @@ def kuka_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: ), axis=-1, ) - # decode compressed state - eef_value = tf.io.decode_compressed( - trajectory["observation"]["clip_function_input/base_pose_tool_reached"], - compression_type="ZLIB", - ) - eef_value = tf.io.decode_raw(eef_value, tf.float32) - trajectory["observation"]["clip_function_input/base_pose_tool_reached"] = tf.reshape(eef_value, (-1, 7)) - gripper_value = tf.io.decode_compressed(trajectory["observation"]["gripper_closed"], compression_type="ZLIB") - gripper_value = tf.io.decode_raw(gripper_value, tf.float32) - trajectory["observation"]["gripper_closed"] = tf.reshape(gripper_value, (-1, 1)) - # trajectory["language_instruction"] = tf.fill( - # tf.shape(trajectory["observation"]["natural_language_instruction"]), "" - # ) # delete uninformative language instruction trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] return trajectory