diff --git a/examples/port_datasets/convert_rt1_example.sh b/examples/port_datasets/convert_rt1_example.sh
new file mode 100644
index 000000000..d8ea10e15
--- /dev/null
+++ b/examples/port_datasets/convert_rt1_example.sh
@@ -0,0 +1,47 @@
+#!/bin/bash
+
+# Example script for converting the RT-1 dataset using SLURM
+# Make sure to modify the paths and parameters according to your setup
+
+# Configuration
+RAW_DIR="/path/to/datasets/fractal20220817_data/0.1.0"
+REPO_ID="your_username/rt1_lerobot"
+LOGS_DIR="/path/to/logs"
+PARTITION="cpu"  # Your SLURM partition name
+
+# Step 1: Convert dataset using distributed processing
+echo "Starting RT-1 dataset conversion..."
+python examples/port_datasets/slurm_port_shards.py \
+    --raw-dir "$RAW_DIR" \
+    --repo-id "$REPO_ID" \
+    --dataset-type rlds \
+    --logs-dir "$LOGS_DIR" \
+    --job-name rt1_conversion \
+    --workers 32 \
+    --num-shards 32 \
+    --partition "$PARTITION" \
+    --cpus-per-task 4 \
+    --mem-per-cpu 2G \
+    --slurm 1
+
+# Step 2: Wait for jobs to complete (you can monitor with squeue)
+echo "Conversion jobs submitted. Monitor with 'squeue -u \$USER'"
+echo "Once all jobs complete, run the aggregation step:"
+echo ""
+echo "python examples/port_datasets/slurm_aggregate_shards.py \\"
+echo "    --repo-id $REPO_ID \\"
+echo "    --push-to-hub"
+
+# Uncomment the following lines if you want to automatically aggregate
+# (but make sure all shards are complete first)
+
+# echo "Waiting for jobs to complete..."
+# while [ $(squeue -u $USER -h | wc -l) -gt 0 ]; do
+#     echo "Jobs still running, waiting 60 seconds..."
+#     sleep 60
+# done
+
+# echo "All jobs completed. Starting aggregation..."
+# python examples/port_datasets/slurm_aggregate_shards.py \
+#     --repo-id "$REPO_ID" \
+#     --push-to-hub
diff --git a/examples/port_datasets/oxe_utils/__init__.py b/examples/port_datasets/oxe_utils/__init__.py
new file mode 100644
index 000000000..b1c802465
--- /dev/null
+++ b/examples/port_datasets/oxe_utils/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Open X-Embodiment utilities for dataset conversion."""
diff --git a/examples/port_datasets/oxe_utils/configs.py b/examples/port_datasets/oxe_utils/configs.py
new file mode 100644
index 000000000..d57be25a1
--- /dev/null
+++ b/examples/port_datasets/oxe_utils/configs.py
@@ -0,0 +1,854 @@
+"""
+Adapted from https://github.com/openvla/openvla/blob/main/prismatic/vla/datasets/rlds/oxe/configs.py
+configs.py
+
+Defines per-dataset configuration (kwargs) for each dataset in Open-X Embodiment.
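+
+As a quick orientation (an illustrative sketch, not code from the upstream file), an
+entry is consumed by indexing this module's `OXE_DATASET_CONFIGS` dict:
+
+    from oxe_utils.configs import OXE_DATASET_CONFIGS
+
+    cfg = OXE_DATASET_CONFIGS["fractal20220817_data"]
+    cfg["image_obs_keys"]["primary"]   # -> "image" (RLDS key of the main RGB camera)
+    cfg["control_frequency"]           # -> 3 (Hz)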
+
+Configuration adopts the following structure:
+    image_obs_keys:
+        primary: primary external RGB
+        secondary: secondary external RGB
+        wrist: wrist RGB
+
+    depth_obs_keys:
+        primary: primary external depth
+        secondary: secondary external depth
+        wrist: wrist depth
+
+    # Always 8-dim =>> changes based on `StateEncoding`
+    state_obs_keys:
+        StateEncoding.POS_EULER: EEF XYZ (3) + Roll-Pitch-Yaw (3) + <PAD> (1) + Gripper Open/Close (1)
+        StateEncoding.POS_QUAT: EEF XYZ (3) + Quaternion (4) + Gripper Open/Close (1)
+        StateEncoding.JOINT: Joint Angles (7, <PAD> if fewer) + Gripper Open/Close (1)
+
+    state_encoding: Type of `StateEncoding`
+    action_encoding: Type of action encoding (e.g., EEF Position vs. Joint Position)
+"""
+
+from enum import IntEnum
+
+import tensorflow as tf
+
+
+def zero_action_filter(traj: dict) -> bool:
+    """
+    Filters transitions whose actions are all-0 (only relative actions, no gripper action).
+    Note: this filter is applied *after* action normalization, so we need to compare to the "normalized 0".
+    """
+    DROID_Q01 = tf.convert_to_tensor(  # NOQA: N806
+        [
+            -0.7776297926902771,
+            -0.5803514122962952,
+            -0.5795090794563293,
+            -0.6464047729969025,
+            -0.7041108310222626,
+            -0.8895104378461838,
+        ]
+    )
+    DROID_Q99 = tf.convert_to_tensor(  # NOQA: N806
+        [
+            0.7597932070493698,
+            0.5726242214441299,
+            0.7351000607013702,
+            0.6705610305070877,
+            0.6464948207139969,
+            0.8897542208433151,
+        ]
+    )
+    DROID_NORM_0_ACT = (  # NOQA: N806
+        2 * (tf.zeros_like(traj["action"][:, :6]) - DROID_Q01) / (DROID_Q99 - DROID_Q01 + 1e-8) - 1
+    )
+
+    return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - DROID_NORM_0_ACT) > 1e-5)
+
+
+# Defines Proprioceptive State Encoding Schemes
+class StateEncoding(IntEnum):
+    # fmt: off
+    NONE = -1               # No Proprioceptive State
+    POS_EULER = 1           # EEF XYZ (3) + Roll-Pitch-Yaw (3) + <PAD> (1) + Gripper Open/Close (1)
+    POS_QUAT = 2            # EEF XYZ (3) + Quaternion (4) + Gripper Open/Close (1)
+    JOINT = 3               # Joint Angles (7, <PAD> if fewer) + Gripper Open/Close (1)
+    JOINT_BIMANUAL = 4      # Joint Angles (2 x [ Joint Angles (6) + Gripper Open/Close (1) ])
+    # fmt: on
+
+
+# Defines Action Encoding Schemes
+class ActionEncoding(IntEnum):
+    # fmt: off
+    EEF_POS = 1             # EEF Delta XYZ (3) + Roll-Pitch-Yaw (3) + Gripper Open/Close (1)
+    EEF_POS_QUAT = 5        # EEF Delta XYZ (3) + Quaternion (4) + Gripper Open/Close (1)
+    JOINT_POS = 2           # Joint Delta Position (7) + Gripper Open/Close (1)
+    JOINT_POS_BIMANUAL = 3  # Joint Delta Position (2 x [ Joint Delta Position (6) + Gripper Open/Close (1) ])
+    EEF_R6 = 4              # EEF Delta XYZ (3) + R6 (6) + Gripper Open/Close (1)
+    # fmt: on
+
+
+# === Individual Dataset Configs ===
+OXE_DATASET_CONFIGS = {
+    "fractal20220817_data": {
+        "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
+        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
+        "state_obs_keys": ["base_pose_tool_reached", "gripper_closed"],
+        "state_encoding": StateEncoding.POS_QUAT,
+        "action_encoding": ActionEncoding.EEF_POS,
+        "control_frequency": 3,
+        "robot_type": "Google Robot",
+    },
+    "kuka": {
+        "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
+        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
+        "state_obs_keys": [
+            "clip_function_input/base_pose_tool_reached",
+            "gripper_closed",
+        ],
+        "state_encoding": StateEncoding.POS_QUAT,
+        "action_encoding": ActionEncoding.EEF_POS,
+        "control_frequency": 10,
+        "robot_type": "Kuka iiwa",
+    },
+    "bridge_oxe": {  # Version of Bridge V2 in Open X-Embodiment mixture
+        "image_obs_keys": 
{"primary": "image", "secondary": "image_1", "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 5, + "robot_type": "WidowX", + }, + "bridge_orig": { # Original version of Bridge V2 from project website + "image_obs_keys": {"primary": "image_0", "secondary": "image_1", "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 5, + "robot_type": "WidowX", + }, + "bridge_dataset": { # Original version of Bridge V2 from project website + "image_obs_keys": {"primary": "image_0", "secondary": "image_1", "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 5, + "robot_type": "WidowX", + }, + "taco_play": { + "image_obs_keys": { + "primary": "rgb_static", + "secondary": None, + "wrist": "rgb_gripper", + }, + "depth_obs_keys": { + "primary": "depth_static", + "secondary": None, + "wrist": "depth_gripper", + }, + "state_obs_keys": ["state_eef", None, "state_gripper"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 15, + "robot_type": "Franka", + }, + "jaco_play": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "image_wrist", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state_eef", None, "state_gripper"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 10, + "robot_type": "Jaco 2", + }, + "berkeley_cable_routing": { + "image_obs_keys": { + "primary": "image", + "secondary": "top_image", + "wrist": "wrist45_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["robot_state", None], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 10, + "robot_type": "Franka", + }, + "roboturk": { + "image_obs_keys": {"primary": "front_rgb", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": [None, None, None, None, None, None, None, None], + "state_encoding": StateEncoding.NONE, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 10, + "robot_type": "Sawyer", + }, + "nyu_door_opening_surprising_effectiveness": { + "image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": [None, None, None, None, None, None, None, None], + "state_encoding": StateEncoding.NONE, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 3, + "robot_type": "Hello Stretch", + }, + "viola": { + "image_obs_keys": { + "primary": "agentview_rgb", + "secondary": None, + "wrist": "eye_in_hand_rgb", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["joint_states", "gripper_states"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": 
ActionEncoding.EEF_POS, + "control_frequency": 20, + "robot_type": "Franka", + }, + "berkeley_autolab_ur5": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "hand_image", + }, + "depth_obs_keys": {"primary": "depth", "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 5, + "robot_type": "UR5", + }, + "toto": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 30, + "robot_type": "Franka", + }, + "language_table": { + "image_obs_keys": {"primary": "rgb", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["effector_translation", None, None, None, None, None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 10, + "robot_type": "xArm", + }, + "columbia_cairlab_pusht_real": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["robot_state", None, None, None, None, None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 10, + "robot_type": "UR5", + }, + "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": "depth_image", "secondary": None, "wrist": None}, + "state_obs_keys": ["ee_position", "ee_orientation", None], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 20, + "robot_type": "Kuka iiwa", + }, + "nyu_rot_dataset_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["eef_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 3, + "robot_type": "xArm", + }, + "stanford_hydra_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["eef_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 10, + "robot_type": "Franka", + }, + "austin_buds_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 20, + "robot_type": "Franka", + }, + "nyu_franka_play_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": "image_additional_view", + "wrist": None, + }, + "depth_obs_keys": { + "primary": "depth", + "secondary": "depth_additional_view", + "wrist": 
None, + }, + "state_obs_keys": ["eef_state", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 3, + "robot_type": "Franka", + }, + "maniskill_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": { + "primary": "depth", + "secondary": None, + "wrist": "wrist_depth", + }, + "state_obs_keys": ["tcp_pose", "gripper_state"], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 20, + "robot_type": "Franka", + }, + "furniture_bench_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 10, + "robot_type": "Franka", + }, + "cmu_franka_exploration_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "highres_image", + "secondary": None, + "wrist": None, + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": [None, None, None, None, None, None, None, None], + "state_encoding": StateEncoding.NONE, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 10, + "robot_type": "Franka", + }, + "ucsd_kitchen_dataset_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["joint_state", None], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 2, + "robot_type": "xArm", + }, + "ucsd_pick_and_place_dataset_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["eef_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 3, + "robot_type": "xArm", + }, + "austin_sailor_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 20, + "robot_type": "Franka", + }, + "austin_sirius_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 20, + "robot_type": "Franka", + }, + "bc_z": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": [ + "present/xyz", + "present/axis_angle", + None, + "present/sensed_close", + ], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 10, + "robot_type": "Google Robot", + }, + 
"utokyo_pr2_opening_fridge_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["eef_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 10, + "robot_type": "PR2", + }, + "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["eef_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 10, + "robot_type": "PR2", + }, + "utokyo_xarm_pick_and_place_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": "image2", + "wrist": "hand_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["end_effector_pose", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 10, + "robot_type": "xArm", + }, + "utokyo_xarm_bimanual_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["pose_r", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 10, + "robot_type": "xArm Bimanual", + }, + "robo_net": { + "image_obs_keys": {"primary": "image", "secondary": "image1", "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["eef_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 1, + "robot_type": "Multi-Robot", + }, + "berkeley_mvp_converted_externally_to_rlds": { + "image_obs_keys": {"primary": None, "secondary": None, "wrist": "hand_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["pose", "gripper"], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.JOINT_POS, + "control_frequency": 5, + "robot_type": "xArm", + }, + "berkeley_rpt_converted_externally_to_rlds": { + "image_obs_keys": {"primary": None, "secondary": None, "wrist": "hand_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["joint_pos", "gripper"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.JOINT_POS, + "control_frequency": 30, + "robot_type": "Franka", + }, + "kaist_nonprehensile_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 10, + "robot_type": "Franka", + }, + "stanford_mask_vit_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["eef_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": 
ActionEncoding.EEF_POS, + "control_frequency": None, + "robot_type": "Sawyer", + }, + "tokyo_u_lsmo_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["eef_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 10, + "robot_type": "Cobotta", + }, + "dlr_sara_pour_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 10, + "robot_type": "DLR SARA", + }, + "dlr_sara_grid_clamp_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 10, + "robot_type": "DLR SARA", + }, + "dlr_edan_shared_control_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 5, + "robot_type": "DLR EDAN", + }, + "asu_table_top_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["eef_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 12.5, + "robot_type": "UR5", + }, + "stanford_robocook_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image_1", "secondary": "image_2", "wrist": None}, + "depth_obs_keys": {"primary": "depth_1", "secondary": "depth_2", "wrist": None}, + "state_obs_keys": ["eef_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 5, + "robot_type": "Franka", + }, + "imperialcollege_sawyer_wrist_cam": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": [None, None, None, None, None, None, None, "state"], + "state_encoding": StateEncoding.NONE, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 10, + "robot_type": "Sawyer", + }, + "iamlab_cmu_pickup_insert_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["joint_state", "gripper_state"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 20, + "robot_type": "Franka", + }, + "uiuc_d3field": { + "image_obs_keys": {"primary": "image_1", "secondary": "image_2", "wrist": None}, + "depth_obs_keys": {"primary": "depth_1", "secondary": "depth_2", "wrist": None}, + 
"state_obs_keys": [None, None, None, None, None, None, None, None], + "state_encoding": StateEncoding.NONE, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 1, + "robot_type": "Kinova Gen3", + }, + "utaustin_mutex": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 20, + "robot_type": "Franka", + }, + "berkeley_fanuc_manipulation": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["joint_state", None, "gripper_state"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 10, + "robot_type": "Fanuc Mate", + }, + "cmu_playing_with_food": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "finger_vision_1", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 10, + "robot_type": "Franka", + }, + "cmu_play_fusion": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 5, + "robot_type": "Franka", + }, + "cmu_stretch": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["eef_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 10, + "robot_type": "Hello Stretch", + }, + "berkeley_gnm_recon": { + "image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 3, + "robot_type": "Jackal", + }, + "berkeley_gnm_cory_hall": { + "image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 5, + "robot_type": "RC Car", + }, + "berkeley_gnm_sac_son": { + "image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 10, + "robot_type": "TurtleBot 2", + }, + # NOTE: modified + "droid": { + "image_obs_keys": { + "primary": "exterior_image_1_left", + "secondary": "exterior_image_2_left", + "wrist": "wrist_image_left", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", None, "gripper_state"], + "state_encoding": 
StateEncoding.POS_EULER,
+        "action_encoding": ActionEncoding.EEF_POS,
+        "control_frequency": 15,
+        "robot_type": "Franka",
+        "aux_kwargs": {
+            "dataset_frame_transform_kwargs": {
+                "chunk_filter_fn": zero_action_filter,
+            },
+        },
+    },
+    "fmb_dataset": {
+        "image_obs_keys": {
+            "primary": "image_side_1",
+            "secondary": "image_side_2",
+            "wrist": "image_wrist_1",
+        },
+        "depth_obs_keys": {
+            "primary": "image_side_1_depth",
+            "secondary": "image_side_2_depth",
+            "wrist": "image_wrist_1_depth",
+        },
+        "state_obs_keys": ["proprio"],
+        "state_encoding": StateEncoding.POS_EULER,
+        "action_encoding": ActionEncoding.EEF_POS,
+        "control_frequency": 10,
+        "robot_type": "Franka",
+    },
+    # NOTE: modified
+    "dobbe": {
+        "image_obs_keys": {"primary": "wrist_image", "secondary": None, "wrist": None},
+        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
+        "state_obs_keys": ["EEF_state", None, "gripper_state"],
+        "state_encoding": StateEncoding.POS_EULER,
+        "action_encoding": ActionEncoding.EEF_POS,
+        "control_frequency": 3.75,
+        "robot_type": "Hello Stretch",
+    },
+    "roboset": {
+        "image_obs_keys": {
+            "primary": "image_left",
+            "secondary": "image_right",
+            "wrist": "image_wrist",
+        },
+        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
+        "state_obs_keys": ["proprio"],
+        "state_encoding": StateEncoding.JOINT,
+        "action_encoding": ActionEncoding.JOINT_POS,
+        "control_frequency": 5,
+        "robot_type": "Franka",
+    },
+    "rh20t": {
+        "image_obs_keys": {
+            "primary": "image_front",
+            "secondary": "image_side_right",
+            "wrist": "image_wrist",
+        },
+        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
+        "state_obs_keys": ["proprio"],
+        "state_encoding": StateEncoding.POS_EULER,
+        "action_encoding": ActionEncoding.EEF_POS,
+        "control_frequency": 10,
+        "robot_type": "Flexiv",
+    },
+    ### T-DROID datasets
+    "tdroid_carrot_in_bowl": {  # "put carrot in bowl" task, 50 demos @ 5 Hz control
+        "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
+        "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
+        "state_obs_keys": ["EEF_state", None, "gripper_state"],
+        "state_encoding": StateEncoding.POS_EULER,
+        "action_encoding": ActionEncoding.EEF_POS,
+        "control_frequency": 5,
+        "robot_type": "Franka",
+    },
+    "tdroid_pour_corn_in_pot": {  # "pour corn from red bowl into steel pot" task, 50 demos @ 5 Hz control
+        "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
+        "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
+        "state_obs_keys": ["EEF_state", None, "gripper_state"],
+        "state_encoding": StateEncoding.POS_EULER,
+        "action_encoding": ActionEncoding.EEF_POS,
+        "control_frequency": 5,
+        "robot_type": "Franka",
+    },
+    "tdroid_flip_pot_upright": {  # "flip pot upright" task, 10 demos @ 5 Hz control
+        "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
+        "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
+        "state_obs_keys": ["EEF_state", None, "gripper_state"],
+        "state_encoding": StateEncoding.POS_EULER,
+        "action_encoding": ActionEncoding.EEF_POS,
+        "control_frequency": 5,
+        "robot_type": "Franka",
+    },
+    "tdroid_move_object_onto_plate": {  # "move <object> onto plate" task, 150 demos @ 5 Hz control
+        "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None},
+        "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None},
+        "state_obs_keys": ["EEF_state", None, "gripper_state"],
["EEF_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 5, + "robot_type": "Franka", + }, + "tdroid_knock_object_over": { # "knock over" task, 70 demos @ 5 Hz control + "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 5, + "robot_type": "Franka", + }, + "tdroid_cover_object_with_towel": { # "cover with towel" task, 45 demos @ 5 Hz control + "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", None, "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 5, + "robot_type": "Franka", + }, + ### DROID Finetuning datasets + "droid_wipe": { + "image_obs_keys": { + "primary": "exterior_image_2_left", + "secondary": None, + "wrist": "wrist_image_left", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["proprio"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 15, + "robot_type": "Franka", + }, + # NOTE: modified + ### LIBERO datasets (modified versions) + "libero_spatial_no_noops": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 20, + "robot_type": "Franka", + }, + "libero_object_no_noops": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 20, + "robot_type": "Franka", + }, + "libero_goal_no_noops": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 20, + "robot_type": "Franka", + }, + "libero_10_no_noops": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + "control_frequency": 20, + "robot_type": "Franka", + }, +} diff --git a/examples/port_datasets/oxe_utils/transform_utils.py b/examples/port_datasets/oxe_utils/transform_utils.py new file mode 100644 index 000000000..8133614fe --- /dev/null +++ b/examples/port_datasets/oxe_utils/transform_utils.py @@ -0,0 +1,76 @@ +""" +Copied from https://github.com/openvla/openvla/blob/main/prismatic/vla/datasets/rlds/utils/data_utils.py +""" + +from typing import Any + +import 
tensorflow as tf + + +def binarize_gripper_actions(actions: tf.Tensor) -> tf.Tensor: + """ + Converts gripper actions from continuous to binary values (0 and 1). + + We exploit that fact that most of the time, the gripper is fully open (near 1.0) or fully closed (near 0.0). As it + transitions between the two, it sometimes passes through a few intermediate values. We relabel those intermediate + values based on the state that is reached _after_ those intermediate values. + + In the edge case that the trajectory ends with an intermediate value, we give up on binarizing and relabel that + chunk of intermediate values as the last action in the trajectory. + + The `scan_fn` implements the following logic: + new_actions = np.empty_like(actions) + carry = actions[-1] + for i in reversed(range(actions.shape[0])): + if in_between_mask[i]: + carry = carry + else: + carry = float(open_mask[i]) + new_actions[i] = carry + """ + open_mask, closed_mask = actions > 0.95, actions < 0.05 + in_between_mask = tf.logical_not(tf.logical_or(open_mask, closed_mask)) + is_open_float = tf.cast(open_mask, tf.float32) + + def scan_fn(carry, i): + return tf.cond(in_between_mask[i], lambda: tf.cast(carry, tf.float32), lambda: is_open_float[i]) + + return tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), actions[-1], reverse=True) + + +def invert_gripper_actions(actions: tf.Tensor) -> tf.Tensor: + return 1 - actions + + +def rel2abs_gripper_actions(actions: tf.Tensor) -> tf.Tensor: + """ + Converts relative gripper actions (+1 for closing, -1 for opening) to absolute actions (0 = closed; 1 = open). + + Assumes that the first relative gripper is not redundant (i.e. close when already closed)! + """ + # Note =>> -1 for closing, 1 for opening, 0 for no change + opening_mask, closing_mask = actions < -0.1, actions > 0.1 + thresholded_actions = tf.where(opening_mask, 1, tf.where(closing_mask, -1, 0)) + + def scan_fn(carry, i): + return tf.cond(thresholded_actions[i] == 0, lambda: carry, lambda: thresholded_actions[i]) + + # If no relative grasp, assumes open for whole trajectory + start = -1 * thresholded_actions[tf.argmax(thresholded_actions != 0, axis=0)] + start = tf.cond(start == 0, lambda: 1, lambda: start) + + # Note =>> -1 for closed, 1 for open + new_actions = tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), start) + new_actions = tf.cast(new_actions, tf.float32) / 2 + 0.5 + + return new_actions + + +# === Bridge-V2 =>> Dataset-Specific Transform === +def relabel_bridge_actions(traj: dict[str, Any]) -> dict[str, Any]: + """Relabels actions to use reached proprioceptive state; discards last timestep (no-action).""" + movement_actions = traj["observation"]["state"][1:, :6] - traj["observation"]["state"][:-1, :6] + traj_truncated = tf.nest.map_structure(lambda x: x[:-1], traj) + traj_truncated["action"] = tf.concat([movement_actions, traj["action"][:-1, -1:]], axis=1) + + return traj_truncated diff --git a/examples/port_datasets/oxe_utils/transforms.py b/examples/port_datasets/oxe_utils/transforms.py new file mode 100644 index 000000000..be5189397 --- /dev/null +++ b/examples/port_datasets/oxe_utils/transforms.py @@ -0,0 +1,1006 @@ +""" +Adapt from https://github.com/openvla/openvla/blob/main/prismatic/vla/datasets/rlds/oxe/transforms.py +transforms.py + +Defines a registry of per-dataset standardization transforms for each dataset in Open-X Embodiment. 
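+
+As an illustrative sketch (not part of the upstream file), each transform is a plain
+function applied to a full RLDS trajectory dict before conversion:
+
+    from oxe_utils.transforms import rt1_dataset_transform
+
+    traj = rt1_dataset_transform(traj)  # traj: batched RLDS trajectory dict
+    traj["action"]                # 7-dim EEF_POS action (delta XYZ + RPY + gripper)
+    traj["language_instruction"]  # copied from observation/natural_language_instruction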
+
+Transforms adopt the following structure:
+    Input: Dictionary of *batched* features (i.e., has leading time dimension)
+    Output: Dictionary `step` =>> {
+        "observation": {
+            <image_keys, depth_image_keys>
+            State (in chosen state representation)
+        },
+        "action": Action (in chosen action representation),
+        "language_instruction": str
+    }
+"""
+
+from typing import Any
+
+import tensorflow as tf
+
+from oxe_utils.transform_utils import (
+    binarize_gripper_actions,
+    invert_gripper_actions,
+    rel2abs_gripper_actions,
+    relabel_bridge_actions,
+)
+
+
+def droid_baseact_transform(trajectory: dict[str, Any]) -> dict[str, Any]:
+    """
+    DROID dataset transformation for actions expressed in the *base* frame of the robot.
+    """
+
+    def rand_swap_exterior_images(img1, img2):
+        """
+        Randomly swaps the two exterior images (for training with a single exterior input).
+        """
+        return tf.cond(tf.random.uniform(shape=[]) > 0.5, lambda: (img1, img2), lambda: (img2, img1))
+
+    dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
+    dr = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]
+
+    trajectory["action"] = tf.concat(
+        (
+            dt,
+            dr,
+            1 - trajectory["action_dict"]["gripper_position"],
+        ),
+        axis=-1,
+    )
+    trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = (
+        rand_swap_exterior_images(
+            trajectory["observation"]["exterior_image_1_left"],
+            trajectory["observation"]["exterior_image_2_left"],
+        )
+    )
+    # trajectory["observation"]["proprio"] = tf.concat(
+    #     (
+    #         trajectory["observation"]["cartesian_position"],
+    #         trajectory["observation"]["gripper_position"],
+    #     ),
+    #     axis=-1,
+    # )
+    trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"]
+    trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"]
+    return trajectory
+
+
+def droid_finetuning_transform(trajectory: dict[str, Any]) -> dict[str, Any]:
+    """
+    DROID dataset transformation for actions expressed in the *base* frame of the robot.
+    """
+    dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
+    dr = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]
+    trajectory["action"] = tf.concat(
+        (
+            dt,
+            dr,
+            1 - trajectory["action_dict"]["gripper_position"],
+        ),
+        axis=-1,
+    )
+    trajectory["observation"]["proprio"] = tf.concat(
+        (
+            trajectory["observation"]["cartesian_position"],
+            trajectory["observation"]["gripper_position"],
+        ),
+        axis=-1,
+    )
+    return trajectory
+
+
+def bridge_oxe_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]:
+    """
+    Applies to the version of Bridge V2 in the Open X-Embodiment mixture.
+
+    Note =>> In the original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
+ """ + for key in trajectory.keys(): + if key == "traj_metadata": + continue + elif key in ["observation", "action"]: + for key2 in trajectory[key]: + trajectory[key][key2] = trajectory[key][key2][1:] + else: + trajectory[key] = trajectory[key][1:] + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32), + ), + axis=-1, + ) + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + trajectory = relabel_bridge_actions(trajectory) + trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + return trajectory + + +def bridge_orig_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + """ + Applies to original version of Bridge V2 from the official project website. + + Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it! + """ + + for key in trajectory.keys(): + if key == "traj_metadata": + continue + elif key == "observation": + for key2 in trajectory[key]: + trajectory[key][key2] = trajectory[key][key2][1:] + else: + trajectory[key] = trajectory[key][1:] + + trajectory["action"] = tf.concat( + [ + trajectory["action"][:, :6], + binarize_gripper_actions(trajectory["action"][:, -1])[:, None], + ], + axis=1, + ) + trajectory = relabel_bridge_actions(trajectory) + trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + return trajectory + + +def ppgm_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["action"] = tf.concat( + [ + trajectory["action"][:, :6], + binarize_gripper_actions(trajectory["action"][:, -1])[:, None], + ], + axis=1, + ) + trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"][:, -1:] + return trajectory + + +def rt1_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + # make gripper action absolute action, +1 = open, 0 = close + gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0] + gripper_action = rel2abs_gripper_actions(gripper_action) + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + gripper_action[:, None], + ), + axis=-1, + ) + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def kuka_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + # make gripper action absolute action, +1 = open, 0 = close + gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0] + gripper_action = rel2abs_gripper_actions(gripper_action) + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + gripper_action[:, None], + ), + axis=-1, + ) + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def taco_play_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["observation"]["state_eef"] = trajectory["observation"]["robot_obs"][:, :6] + trajectory["observation"]["state_gripper"] = 
trajectory["observation"]["robot_obs"][:, 7:8] + trajectory["action"] = trajectory["action"]["rel_actions_world"] + + # invert gripper action + clip, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + tf.clip_by_value(trajectory["action"][:, -1:], 0, 1), + ), + axis=-1, + ) + + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def jaco_play_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["observation"]["state_eef"] = trajectory["observation"]["end_effector_cartesian_pos"][:, :6] + trajectory["observation"]["state_gripper"] = trajectory["observation"]["end_effector_cartesian_pos"][ + :, -1: + ] + + # make gripper action absolute action, +1 = open, 0 = close + gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0] + gripper_action = rel2abs_gripper_actions(gripper_action) + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + tf.zeros_like(trajectory["action"]["world_vector"]), + gripper_action[:, None], + ), + axis=-1, + ) + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def berkeley_cable_routing_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + tf.zeros_like(trajectory["action"]["world_vector"][:, :1]), + ), + axis=-1, + ) + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["observation"]["natural_language_instruction"]), "" + # ) # delete uninformative language instruction + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def roboturk_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + # invert absolute gripper action, +1 = open, 0 = close + gripper_action = invert_gripper_actions( + tf.clip_by_value(trajectory["action"]["gripper_closedness_action"], 0, 1) + ) + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + gripper_action, + ), + axis=-1, + ) + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["observation"]["natural_language_instruction"]), "" + # ) # delete uninformative language instruction + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def nyu_door_opening_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + # make gripper action absolute action, +1 = open, 0 = close + gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0] + gripper_action = rel2abs_gripper_actions(gripper_action) + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + gripper_action[:, None], + ), + axis=-1, + ) + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["observation"]["natural_language_instruction"]), "" + # ) # delete uninformative language instruction + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def viola_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + # make gripper action, +1 = open, 0 = close + gripper_action = trajectory["action"]["gripper_closedness_action"][:, None] + gripper_action = 
tf.clip_by_value(gripper_action, 0, 1) + gripper_action = invert_gripper_actions(gripper_action) + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + gripper_action, + ), + axis=-1, + ) + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["observation"]["natural_language_instruction"]), "" + # ) # delete uninformative language instruction + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def berkeley_autolab_ur5_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + # flip wrist_image from bgr to rgb + trajectory["observation"]["hand_image"] = trajectory["observation"]["hand_image"][..., ::-1] + + trajectory["observation"]["state"] = trajectory["observation"]["robot_state"][:, 6:14] + trajectory["observation"]["depth"] = trajectory["observation"].pop("image_with_depth") + + # make gripper action absolute action, +1 = open, 0 = close + gripper_action = trajectory["action"]["gripper_closedness_action"] + gripper_action = rel2abs_gripper_actions(gripper_action) + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + gripper_action[:, None], + ), + axis=-1, + ) + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def toto_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32), + ), + axis=-1, + ) + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["observation"]["natural_language_instruction"]), "" + # ) # delete uninformative language instruction + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def language_table_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + # default to "open" gripper + trajectory["action"] = tf.concat( + ( + trajectory["action"], + tf.zeros_like(trajectory["action"]), + tf.zeros_like(trajectory["action"]), + tf.ones_like(trajectory["action"][:, :1]), + ), + axis=-1, + ) + + # decode language instruction + instruction_bytes = trajectory["observation"]["instruction"] + instruction_encoded = tf.strings.unicode_encode(instruction_bytes, output_encoding="UTF-8") + # Remove trailing padding --> convert RaggedTensor to regular Tensor. 
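+    # Each decoded string is NUL-padded to a fixed length; splitting on "\x00" and
+    # keeping only the first piece recovers the raw instruction per timestep, and
+    # `.to_tensor()[:, 0]` densifies the RaggedTensor and drops the split axis.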
+ trajectory["language_instruction"] = tf.strings.split(instruction_encoded, "\x00")[:, :1].to_tensor()[ + :, 0 + ] + return trajectory + + +def pusht_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + trajectory["action"]["gripper_closedness_action"][:, None], + ), + axis=-1, + ) + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def stanford_kuka_multimodal_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["observation"]["depth_image"] = trajectory["observation"]["depth_image"][..., 0] + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :3], + tf.zeros_like(trajectory["action"][:, :3]), + trajectory["action"][:, -1:], + ), + axis=-1, + ) + return trajectory + + +def nyu_rot_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][..., :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., -1:] + trajectory["action"] = trajectory["action"][..., :7] + return trajectory + + +def stanford_hydra_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + # flip image & wrist_image from bgr to rgb + trajectory["observation"]["image"] = trajectory["observation"]["image"][..., ::-1] + trajectory["observation"]["wrist_image"] = trajectory["observation"]["wrist_image"][..., ::-1] + + # invert gripper action, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + invert_gripper_actions(trajectory["action"][:, -1:]), + ), + axis=-1, + ) + + trajectory["observation"]["eef_state"] = tf.concat( + ( + trajectory["observation"]["state"][:, :3], + trajectory["observation"]["state"][:, 7:10], + ), + axis=-1, + ) + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -3:-2] + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["language_instruction"]), "" + # ) # delete uninformative language instruction + return trajectory + + +def austin_buds_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + # invert gripper action + clip, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), + ), + axis=-1, + ) + + trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8] + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["language_instruction"]), "" + # ) # delete uninformative language instruction + return trajectory + + +def nyu_franka_play_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["observation"]["depth"] = tf.cast(trajectory["observation"]["depth"][..., 0], tf.float32) + trajectory["observation"]["depth_additional_view"] = tf.cast( + trajectory["observation"]["depth_additional_view"][..., 0], tf.float32 + ) + trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, -6:] + + # clip gripper action, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, -8:-2], + tf.clip_by_value(trajectory["action"][:, -2:-1], 0, 1), + ), + axis=-1, + ) + + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["language_instruction"]), "" + # ) # delete uninformative language instruction + 
return trajectory + + +def maniskill_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., 7:8] + return trajectory + + +def furniture_bench_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + import tensorflow_graphics.geometry.transformation as tft + + trajectory["observation"]["state"] = tf.concat( + ( + trajectory["observation"]["state"][:, :7], + trajectory["observation"]["state"][:, -1:], + ), + axis=-1, + ) + + # invert gripper action + clip, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :3], + tft.euler.from_quaternion(trajectory["action"][:, 3:7]), + invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), + ), + axis=-1, + ) + return trajectory + + +def cmu_franka_exploration_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["action"] = trajectory["action"][..., :-1] + return trajectory + + +def ucsd_kitchen_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7] + trajectory["action"] = trajectory["action"][..., :-1] + return trajectory + + +def ucsd_pick_place_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :3], + tf.zeros_like(trajectory["action"][:, :3]), + trajectory["action"][:, -1:], + ), + axis=-1, + ) + return trajectory + + +def austin_sailor_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + # invert gripper action + clip, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), + ), + axis=-1, + ) + + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["language_instruction"]), "" + # ) # delete uninformative language instruction + return trajectory + + +def austin_sirius_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + # invert gripper action + clip, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), + ), + axis=-1, + ) + + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["language_instruction"]), "" + # ) # delete uninformative language instruction + return trajectory + + +def bc_z_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["action"] = tf.concat( + ( + trajectory["action"]["future/xyz_residual"][:, :3], + trajectory["action"]["future/axis_angle_residual"][:, :3], + invert_gripper_actions(tf.cast(trajectory["action"]["future/target_close"][:, :1], tf.float32)), + ), + axis=-1, + ) + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def tokyo_pr2_opening_fridge_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + trajectory["action"] = trajectory["action"][..., :-1] + return trajectory + + +def 
tokyo_pr2_tabletop_manipulation_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + trajectory["action"] = trajectory["action"][..., :-1] + return trajectory + + +def utokyo_xarm_pick_place_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + return trajectory + + +def utokyo_xarm_bimanual_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["action"] = trajectory["action"][..., -7:] + return trajectory + + +def robo_net_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["observation"]["eef_state"] = tf.concat( + ( + trajectory["observation"]["state"][:, :4], + tf.zeros_like(trajectory["observation"]["state"][:, :2]), + ), + axis=-1, + ) + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :4], + tf.zeros_like(trajectory["action"][:, :2]), + trajectory["action"][:, -1:], + ), + axis=-1, + ) + return trajectory + + +def berkeley_mvp_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["observation"]["gripper"] = trajectory["observation"]["gripper"][:, None] + return trajectory + + +def berkeley_rpt_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["observation"]["gripper"] = trajectory["observation"]["gripper"][:, None] + return trajectory + + +def kaist_nonprehensible_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["observation"]["state"] = trajectory["observation"]["state"][:, -7:] + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + tf.zeros_like(trajectory["action"][:, :1]), + ), + axis=-1, + ) + return trajectory + + +def stanford_mask_vit_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["observation"]["eef_state"] = tf.concat( + ( + trajectory["observation"]["end_effector_pose"][:, :4], + tf.zeros_like(trajectory["observation"]["end_effector_pose"][:, :2]), + ), + axis=-1, + ) + trajectory["observation"]["gripper_state"] = trajectory["observation"]["end_effector_pose"][:, -1:] + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :4], + tf.zeros_like(trajectory["action"][:, :2]), + trajectory["action"][:, -1:], + ), + axis=-1, + ) + return trajectory + + +def tokyo_lsmo_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + return trajectory + + +def dlr_sara_pour_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + return trajectory + + +def dlr_sara_grid_clamp_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :6] + return trajectory + + +def dlr_edan_shared_control_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + # invert gripper action, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + invert_gripper_actions(trajectory["action"][:, -1:]), + ), + axis=-1, + ) + return trajectory + + +def asu_table_top_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["observation"]["eef_state"] = 
trajectory["ground_truth_states"]["EE"] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + return trajectory + + +def robocook_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + return trajectory + + +def imperial_wristcam_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["action"] = trajectory["action"][..., :-1] + return trajectory + + +def iamlab_pick_insert_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + import tensorflow_graphics.geometry.transformation as tft + + trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 7:8] + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :3], + tft.euler.from_quaternion(trajectory["action"][:, 3:7]), + trajectory["action"][:, 7:8], + ), + axis=-1, + ) + return trajectory + + +def uiuc_d3field_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["action"] = tf.concat( + ( + trajectory["action"], + tf.zeros_like(trajectory["action"]), + tf.zeros_like(trajectory["action"][:, :1]), + ), + axis=-1, + ) + return trajectory + + +def utaustin_mutex_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + # flip image & wrist_image from bgr to rgb + trajectory["observation"]["image"] = trajectory["observation"]["image"][..., ::-1] + trajectory["observation"]["wrist_image"] = trajectory["observation"]["wrist_image"][..., ::-1] + + trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8] + + # invert gripper action + clip, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), + ), + axis=-1, + ) + + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["language_instruction"]), "" + # ) # delete uninformative language instruction + return trajectory + + +def berkeley_fanuc_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + # flip image & wrist_image from bgr to rgb + trajectory["observation"]["image"] = trajectory["observation"]["image"][..., ::-1] + trajectory["observation"]["wrist_image"] = trajectory["observation"]["wrist_image"][..., ::-1] + + trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 6:7] + + # dataset does not store gripper actions, so use gripper state info, invert so +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"], + invert_gripper_actions(trajectory["observation"]["gripper_state"]), + ), + axis=-1, + ) + return trajectory + + +def cmu_playing_with_food_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + import tensorflow_graphics.geometry.transformation as tft + + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :3], + tft.euler.from_quaternion(trajectory["action"][:, 3:7]), + trajectory["action"][:, -1:], + ), + axis=-1, + ) + return trajectory + + +def playfusion_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :3], + trajectory["action"][:, -4:], + ), + axis=-1, + ) + 
return trajectory + + +def cmu_stretch_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["observation"]["eef_state"] = tf.concat( + ( + trajectory["observation"]["state"][:, :3], + tf.zeros_like(trajectory["observation"]["state"][:, :3]), + ), + axis=-1, + ) + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + trajectory["action"] = trajectory["action"][..., :-1] + return trajectory + + +def gnm_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["observation"]["state"] = tf.concat( + ( + trajectory["observation"]["position"], + tf.zeros_like(trajectory["observation"]["state"][:, :3]), + trajectory["observation"]["yaw"], + ), + axis=-1, + ) + trajectory["action"] = tf.concat( + ( + trajectory["action"], + tf.zeros_like(trajectory["action"]), + tf.zeros_like(trajectory["action"]), + tf.zeros_like(trajectory["action"][:, :1]), + ), + axis=-1, + ) + return trajectory + + +def fmb_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + # flip image from bgr to rgb + trajectory["observation"]["image_wrist_1"] = trajectory["observation"]["image_wrist_1"][..., ::-1] + trajectory["observation"]["image_wrist_2"] = trajectory["observation"]["image_wrist_2"][..., ::-1] + trajectory["observation"]["image_side_1"] = trajectory["observation"]["image_side_1"][..., ::-1] + trajectory["observation"]["image_side_2"] = trajectory["observation"]["image_side_2"][..., ::-1] + + # every input feature is batched, ie has leading batch dimension + trajectory["observation"]["proprio"] = tf.concat( + ( + trajectory["observation"]["eef_pose"], + trajectory["observation"]["state_gripper_pose"][..., None], + ), + axis=-1, + ) + return trajectory + + +def dobbe_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + # every input feature is batched, ie has leading batch dimension + trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + return trajectory + + +def roboset_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + # every input feature is batched, ie has leading batch dimension + trajectory["observation"]["proprio"] = trajectory["observation"]["state"] + + # gripper action is in -1...1 --> clip to 0...1, flip + gripper_action = trajectory["action"][:, -1:] + gripper_action = invert_gripper_actions(tf.clip_by_value(gripper_action, 0, 1)) + + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :7], + gripper_action, + ), + axis=-1, + ) + return trajectory + + +def rh20t_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["action"] = tf.concat( + ( + trajectory["action"]["tcp_base"], + tf.cast(trajectory["action"]["gripper"][:, None], tf.float32), + ), + axis=-1, + ) + trajectory["observation"]["proprio"] = tf.concat( + ( + trajectory["observation"]["tcp_base"], + trajectory["observation"]["gripper_width"][..., None], + ), + axis=-1, + ) + return trajectory + + +def tdroid_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + trajectory["action"] = tf.concat( + [ + trajectory["action"][:, :6], + binarize_gripper_actions(trajectory["action"][:, -1])[:, None], + ], + axis=1, + ) + trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"][:, -1:] + return trajectory + + +def 
libero_dataset_transform(trajectory: dict[str, Any]) -> dict[str, Any]: + # gripper action is in -1 (open)...1 (close) --> clip to 0...1, flip --> +1 = open, 0 = close + gripper_action = trajectory["action"][:, -1:] + gripper_action = invert_gripper_actions(tf.clip_by_value(gripper_action, 0, 1)) + + trajectory["action"] = tf.concat( + [ + trajectory["action"][:, :6], + gripper_action, + ], + axis=1, + ) + trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][ + :, -2: + ] # 2D gripper state + return trajectory + + +# === Registry === +OXE_STANDARDIZATION_TRANSFORMS = { + "bridge_oxe": bridge_oxe_dataset_transform, + "bridge_orig": bridge_orig_dataset_transform, + "bridge_dataset": bridge_orig_dataset_transform, + "ppgm": ppgm_dataset_transform, + "ppgm_static": ppgm_dataset_transform, + "ppgm_wrist": ppgm_dataset_transform, + "fractal20220817_data": rt1_dataset_transform, + "kuka": kuka_dataset_transform, + "taco_play": taco_play_dataset_transform, + "jaco_play": jaco_play_dataset_transform, + "berkeley_cable_routing": berkeley_cable_routing_dataset_transform, + "roboturk": roboturk_dataset_transform, + "nyu_door_opening_surprising_effectiveness": nyu_door_opening_dataset_transform, + "viola": viola_dataset_transform, + "berkeley_autolab_ur5": berkeley_autolab_ur5_dataset_transform, + "toto": toto_dataset_transform, + "language_table": language_table_dataset_transform, + "columbia_cairlab_pusht_real": pusht_dataset_transform, + "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": stanford_kuka_multimodal_dataset_transform, + "nyu_rot_dataset_converted_externally_to_rlds": nyu_rot_dataset_transform, + "stanford_hydra_dataset_converted_externally_to_rlds": stanford_hydra_dataset_transform, + "austin_buds_dataset_converted_externally_to_rlds": austin_buds_dataset_transform, + "nyu_franka_play_dataset_converted_externally_to_rlds": nyu_franka_play_dataset_transform, + "maniskill_dataset_converted_externally_to_rlds": maniskill_dataset_transform, + "furniture_bench_dataset_converted_externally_to_rlds": furniture_bench_dataset_transform, + "cmu_franka_exploration_dataset_converted_externally_to_rlds": cmu_franka_exploration_dataset_transform, + "ucsd_kitchen_dataset_converted_externally_to_rlds": ucsd_kitchen_dataset_transform, + "ucsd_pick_and_place_dataset_converted_externally_to_rlds": ucsd_pick_place_dataset_transform, + "austin_sailor_dataset_converted_externally_to_rlds": austin_sailor_dataset_transform, + "austin_sirius_dataset_converted_externally_to_rlds": austin_sirius_dataset_transform, + "bc_z": bc_z_dataset_transform, + "utokyo_pr2_opening_fridge_converted_externally_to_rlds": tokyo_pr2_opening_fridge_dataset_transform, + "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": tokyo_pr2_tabletop_manipulation_dataset_transform, + "utokyo_xarm_pick_and_place_converted_externally_to_rlds": utokyo_xarm_pick_place_dataset_transform, + "utokyo_xarm_bimanual_converted_externally_to_rlds": utokyo_xarm_bimanual_dataset_transform, + "robo_net": robo_net_dataset_transform, + "berkeley_mvp_converted_externally_to_rlds": berkeley_mvp_dataset_transform, + "berkeley_rpt_converted_externally_to_rlds": berkeley_rpt_dataset_transform, + "kaist_nonprehensile_converted_externally_to_rlds": kaist_nonprehensible_dataset_transform, + "stanford_mask_vit_converted_externally_to_rlds": stanford_mask_vit_dataset_transform, + "tokyo_u_lsmo_converted_externally_to_rlds": 
tokyo_lsmo_dataset_transform, + "dlr_sara_pour_converted_externally_to_rlds": dlr_sara_pour_dataset_transform, + "dlr_sara_grid_clamp_converted_externally_to_rlds": dlr_sara_grid_clamp_dataset_transform, + "dlr_edan_shared_control_converted_externally_to_rlds": dlr_edan_shared_control_dataset_transform, + "asu_table_top_converted_externally_to_rlds": asu_table_top_dataset_transform, + "stanford_robocook_converted_externally_to_rlds": robocook_dataset_transform, + "imperialcollege_sawyer_wrist_cam": imperial_wristcam_dataset_transform, + "iamlab_cmu_pickup_insert_converted_externally_to_rlds": iamlab_pick_insert_dataset_transform, + "uiuc_d3field": uiuc_d3field_dataset_transform, + "utaustin_mutex": utaustin_mutex_dataset_transform, + "berkeley_fanuc_manipulation": berkeley_fanuc_dataset_transform, + "cmu_playing_with_food": cmu_playing_with_food_dataset_transform, + "cmu_play_fusion": playfusion_dataset_transform, + "cmu_stretch": cmu_stretch_dataset_transform, + "berkeley_gnm_recon": gnm_dataset_transform, + "berkeley_gnm_cory_hall": gnm_dataset_transform, + "berkeley_gnm_sac_son": gnm_dataset_transform, + "droid": droid_baseact_transform, + "fmb_dataset": fmb_dataset_transform, + "dobbe": dobbe_dataset_transform, + "roboset": roboset_dataset_transform, + "rh20t_rlds": rh20t_dataset_transform, + ### T-DROID datasets + "tdroid_carrot_in_bowl": tdroid_dataset_transform, + "tdroid_pour_corn_in_pot": tdroid_dataset_transform, + "tdroid_flip_pot_upright": tdroid_dataset_transform, + "tdroid_move_object_onto_plate": tdroid_dataset_transform, + "tdroid_knock_object_over": tdroid_dataset_transform, + "tdroid_cover_object_with_towel": tdroid_dataset_transform, + ### DROID Finetuning datasets + "droid_wipe": droid_finetuning_transform, + ### LIBERO datasets (modified versions) + "libero_spatial_no_noops": libero_dataset_transform, + "libero_object_no_noops": libero_dataset_transform, + "libero_goal_no_noops": libero_dataset_transform, + "libero_10_no_noops": libero_dataset_transform, +} diff --git a/examples/port_datasets/port_rlds.py b/examples/port_datasets/port_rlds.py new file mode 100644 index 000000000..43047d679 --- /dev/null +++ b/examples/port_datasets/port_rlds.py @@ -0,0 +1,359 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
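+
+"""Port a raw RLDS / Open X-Embodiment dataset to the LeRobot format.
+
+`--raw-dir` may point either at the dataset directory itself or at a versioned
+subdirectory (`<data_dir>/<dataset_name>/<version>`); the TFDS name and version
+are inferred from the directory structure (see `determine_dataset_info` below).
+
+Illustrative invocation (paths and repo id are placeholders):
+
+    python examples/port_datasets/port_rlds.py \
+        --raw-dir /path/to/<dataset_name>/<version> \
+        --repo-id <hf_username>/<dataset_name> \
+        --push-to-hub
+
+Passing `--num-shards N --shard-index i` ports only the i-th of N contiguous
+slices of the `train` split, so several processes (e.g. the SLURM pipeline in
+`slurm_port_shards.py`) can convert a large dataset in parallel.
+"""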
+ +import argparse +import logging +import re +import time +from functools import partial +from pathlib import Path +from typing import Any + +import numpy as np +import tensorflow as tf +import tensorflow_datasets as tfds +from oxe_utils.configs import OXE_DATASET_CONFIGS, ActionEncoding, StateEncoding +from oxe_utils.transforms import OXE_STANDARDIZATION_TRANSFORMS + +from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata +from lerobot.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds + +# Default FPS for datasets without specific config +DEFAULT_FPS = 10 +DEFAULT_ROBOT_TYPE = "unknown" + + +def determine_dataset_info(raw_dir: Path): + """Determine dataset name and version from directory structure.""" + last_part = raw_dir.name + if re.match(r"^\d+\.\d+\.\d+$", last_part): + version = last_part + dataset_name = raw_dir.parent.name + data_dir = raw_dir.parent.parent + else: + version = "" + dataset_name = last_part + data_dir = raw_dir.parent + return dataset_name, version, data_dir + + +def generate_features_from_builder(builder: tfds.core.DatasetBuilder, dataset_name: str) -> dict[str, Any]: + """Generate LeRobot features schema from TFDS builder and dataset config.""" + + # Generate state names based on encoding type + state_names = [f"motor_{i}" for i in range(8)] + if dataset_name in OXE_DATASET_CONFIGS: + state_encoding = OXE_DATASET_CONFIGS[dataset_name]["state_encoding"] + if state_encoding == StateEncoding.POS_EULER: + state_names = ["x", "y", "z", "roll", "pitch", "yaw", "pad", "gripper"] + if "libero" in dataset_name: + state_names = [ + "x", + "y", + "z", + "roll", + "pitch", + "yaw", + "gripper", + "gripper", + ] # 2D gripper state + elif state_encoding == StateEncoding.POS_QUAT: + state_names = ["x", "y", "z", "rx", "ry", "rz", "rw", "gripper"] + elif state_encoding == StateEncoding.JOINT: + state_names = [f"motor_{i}" for i in range(7)] + ["gripper"] + state_obs_keys = OXE_DATASET_CONFIGS[dataset_name]["state_obs_keys"] + pad_count = state_obs_keys[:-1].count(None) + state_names[-pad_count - 1 : -1] = ["pad"] * pad_count + state_names[-1] = "pad" if state_obs_keys[-1] is None else state_names[-1] + + # Generate action names based on encoding type + action_names = [f"motor_{i}" for i in range(8)] + if dataset_name in OXE_DATASET_CONFIGS: + action_encoding = OXE_DATASET_CONFIGS[dataset_name]["action_encoding"] + if action_encoding == ActionEncoding.EEF_POS: + action_names = ["x", "y", "z", "roll", "pitch", "yaw", "gripper"] + elif action_encoding == ActionEncoding.JOINT_POS: + action_names = [f"motor_{i}" for i in range(7)] + ["gripper"] + + # Base features (state and action) + features = { + "observation.state": { + "dtype": "float32", + "shape": (len(state_names),), + "names": {"axes": state_names}, + }, + "action": { + "dtype": "float32", + "shape": (len(action_names),), + "names": {"axes": action_names}, + }, + } + + # Add image features from TFDS builder info + obs_features = builder.info.features["steps"]["observation"] + for key, value in obs_features.items(): + # Skip depth images and non-image features + if "depth" in key or not any(x in key for x in ["image", "rgb"]): + continue + + features[f"observation.images.{key}"] = { + "dtype": "video", + "shape": tuple(value.shape), + "names": ["height", "width", "channels"], + } + + return features + + +def transform_raw_dataset(episode, dataset_name: str): + """Apply OXE standardization transforms to raw TFDS episode.""" + # Batch all steps in the episode + traj = 
next(iter(episode["steps"].batch(episode["steps"].cardinality()))) + + # Apply dataset-specific transform if available + if dataset_name in OXE_STANDARDIZATION_TRANSFORMS: + traj = OXE_STANDARDIZATION_TRANSFORMS[dataset_name](traj) + + # Create consolidated state vector + if dataset_name in OXE_DATASET_CONFIGS: + state_obs_keys = OXE_DATASET_CONFIGS[dataset_name]["state_obs_keys"] + else: + state_obs_keys = [None for _ in range(8)] + + # Build proprio (proprioceptive state) vector + proprio_components = [] + for key in state_obs_keys: + if key is None: + # Add padding for missing state components + component = tf.zeros((tf.shape(traj["action"])[0], 1), dtype=tf.float32) + else: + component = tf.cast(traj["observation"][key], tf.float32) + # Ensure component has right shape (add dimension if needed) + if len(component.shape) == 1: + component = component[:, None] + proprio_components.append(component) + + proprio = tf.concat(proprio_components, axis=1) + + # Update trajectory with standardized format + traj.update( + { + "proprio": proprio, + "task": traj.get("language_instruction", ""), + "action": tf.cast(traj["action"], tf.float32), + } + ) + + episode["steps"] = traj + return episode + + +def generate_lerobot_frames(tf_episode): + """Generate LeRobot frames from transformed TFDS episode.""" + traj = tf_episode["steps"] + + # Get the task/language instruction + if isinstance(traj["task"], tf.Tensor): + if traj["task"].dtype == tf.string: + task = traj["task"][0].numpy().decode() if len(traj["task"]) > 0 else "" + else: + task = str(traj["task"][0].numpy()) if len(traj["task"]) > 0 else "" + else: + task = str(traj["task"]) if traj["task"] else "" + + # Iterate through each timestep + num_steps = tf.shape(traj["action"])[0].numpy() + for i in range(num_steps): + frame = {} + + # Add observation state + frame["observation.state"] = traj["proprio"][i].numpy() + + # Add action + frame["action"] = traj["action"][i].numpy() + + # Add images + for key, value in traj["observation"].items(): + if any(x in key for x in ["image", "rgb"]) and "depth" not in key: + frame[f"observation.images.{key}"] = value[i].numpy() + + # Add task + frame["task"] = task + + # Cast fp64 to fp32 + for key in frame: + if isinstance(frame[key], np.ndarray) and frame[key].dtype == np.float64: + frame[key] = frame[key].astype(np.float32) + + yield frame + + +def port_rlds( + raw_dir: Path, + repo_id: str, + push_to_hub: bool = False, + num_shards: int | None = None, + shard_index: int | None = None, +): + """Port RLDS dataset to LeRobot format.""" + + # Determine dataset info + dataset_name, version, data_dir = determine_dataset_info(raw_dir) + + # Build TFDS dataset + builder = tfds.builder( + f"{dataset_name}/{version}" if version else dataset_name, data_dir=data_dir, version=version + ) + + # Handle sharding if specified + if num_shards is not None and shard_index is not None: + if shard_index >= num_shards: + raise ValueError(f"Shard index {shard_index} >= num_shards {num_shards}") + + # Calculate shard splits + total_episodes = builder.info.splits["train"].num_examples + episodes_per_shard = total_episodes // num_shards + start_idx = shard_index * episodes_per_shard + if shard_index == num_shards - 1: + # Last shard gets remaining episodes + end_idx = total_episodes + else: + end_idx = start_idx + episodes_per_shard + + split_str = f"train[{start_idx}:{end_idx}]" + raw_dataset = builder.as_dataset(split=split_str) + else: + raw_dataset = builder.as_dataset(split="train") + + # Apply filtering (e.g., success filter 
for kuka)
+    if dataset_name == "kuka":
+        raw_dataset = raw_dataset.filter(lambda e: e["success"])
+
+    # Apply transformations
+    raw_dataset = raw_dataset.map(partial(transform_raw_dataset, dataset_name=dataset_name))
+
+    # Get dataset configuration
+    fps = DEFAULT_FPS
+    robot_type = DEFAULT_ROBOT_TYPE
+
+    if dataset_name in OXE_DATASET_CONFIGS:
+        config = OXE_DATASET_CONFIGS[dataset_name]
+        fps = config.get("control_frequency", DEFAULT_FPS)
+        robot_type = config.get("robot_type", DEFAULT_ROBOT_TYPE)
+        robot_type = robot_type.lower().replace(" ", "_").replace("-", "_")
+
+    # Generate features schema
+    features = generate_features_from_builder(builder, dataset_name)
+
+    # Create LeRobot dataset
+    lerobot_dataset = LeRobotDataset.create(
+        repo_id=repo_id,
+        robot_type=robot_type,
+        fps=int(fps),
+        features=features,
+    )
+
+    # Process episodes
+    start_time = time.time()
+    # NOTE: cardinality() returns -2 (unknown) once filter() has been applied (e.g. the kuka success filter)
+    num_episodes = raw_dataset.cardinality().numpy().item()
+    logging.info(f"Number of episodes: {num_episodes}")
+
+    for episode_index, episode in enumerate(raw_dataset):
+        elapsed_time = time.time() - start_time
+        d, h, m, s = get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time)
+
+        logging.info(
+            f"{episode_index} / {num_episodes} episodes processed "
+            f"(after {d} days, {h} hours, {m} minutes, {s:.3f} seconds)"
+        )
+
+        # Generate and add frames
+        for frame in generate_lerobot_frames(episode):
+            lerobot_dataset.add_frame(frame)
+
+        lerobot_dataset.save_episode()
+        logging.info("Episode saved")
+
+    # Push to hub if requested
+    if push_to_hub:
+        tags = ["openx", dataset_name]
+        if robot_type != "unknown":
+            tags.append(robot_type)
+
+        lerobot_dataset.push_to_hub(
+            tags=tags,
+            private=False,
+        )
+
+
+def validate_dataset(repo_id):
+    """Sanity check that ensures metadata can be loaded and all files are present."""
+    meta = LeRobotDatasetMetadata(repo_id)
+
+    if meta.total_episodes == 0:
+        raise ValueError("Number of episodes is 0.")
+
+    for ep_idx in range(meta.total_episodes):
+        data_path = meta.root / meta.get_data_file_path(ep_idx)
+
+        if not data_path.exists():
+            raise ValueError(f"Parquet file is missing in: {data_path}")
+
+        for vid_key in meta.video_keys:
+            vid_path = meta.root / meta.get_video_file_path(ep_idx, vid_key)
+            if not vid_path.exists():
+                raise ValueError(f"Video file is missing in: {vid_path}")
+
+
+def main():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--raw-dir",
+        type=Path,
+        required=True,
+        help="Directory containing input raw datasets (e.g. 
`path/to/dataset` or `path/to/dataset/version`).",
+    )
+    parser.add_argument(
+        "--repo-id",
+        type=str,
+        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset. Required when --push-to-hub is set.",
+    )
+    parser.add_argument(
+        "--push-to-hub",
+        action="store_true",
+        help="Upload to hub.",
+    )
+    parser.add_argument(
+        "--num-shards",
+        type=int,
+        default=None,
+        help="Number of shards to split the dataset into for parallel processing.",
+    )
+    parser.add_argument(
+        "--shard-index",
+        type=int,
+        default=None,
+        help="Index of the shard to process (0-indexed).",
+    )
+
+    args = parser.parse_args()
+
+    port_rlds(**vars(args))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/port_datasets/slurm_port_shards.py b/examples/port_datasets/slurm_port_shards.py
index 3bb4c135c..3021d7e88 100644
--- a/examples/port_datasets/slurm_port_shards.py
+++ b/examples/port_datasets/slurm_port_shards.py
@@ -61,13 +61,71 @@ class PortDroidShards(PipelineStep):
         validate_dataset(shard_repo_id)
 
 
+class PortRLDSShards(PipelineStep):
+    def __init__(
+        self,
+        raw_dir: Path | str,
+        repo_id: str | None = None,
+        num_shards: int | None = None,
+    ):
+        super().__init__()
+        self.raw_dir = Path(raw_dir)
+        self.repo_id = repo_id
+        self.num_shards = num_shards
+
+    def run(self, data=None, rank: int = 0, world_size: int = 1):
+        from datasets.utils.tqdm import disable_progress_bars
+        from port_datasets.port_rlds import port_rlds, validate_dataset
+
+        from lerobot.utils.utils import init_logging
+
+        init_logging()
+        disable_progress_bars()
+
+        shard_repo_id = f"{self.repo_id}_world_{world_size}_rank_{rank}"
+
+        try:
+            validate_dataset(shard_repo_id)
+            return
+        except Exception:
+            pass  # nosec B110 - Dataset doesn't exist yet, continue with porting
+
+        port_rlds(
+            self.raw_dir,
+            shard_repo_id,
+            push_to_hub=False,
+            num_shards=world_size,
+            shard_index=rank,
+        )
+
+        validate_dataset(shard_repo_id)
+
+
 def make_port_executor(
-    raw_dir, repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
+    raw_dir,
+    repo_id,
+    job_name,
+    logs_dir,
+    workers,
+    partition,
+    cpus_per_task,
+    mem_per_cpu,
+    slurm=True,
+    dataset_type="droid",
+    num_shards=None,
 ):
+    # Select appropriate pipeline step based on dataset type
+    if dataset_type.lower() == "droid":
+        pipeline_step = PortDroidShards(raw_dir, repo_id)
+        default_shards = DROID_SHARDS
+    elif dataset_type.lower() == "rlds":
+        pipeline_step = PortRLDSShards(raw_dir, repo_id, num_shards)
+        default_shards = num_shards or workers  # Use num_shards, or fall back to workers
+    else:
+        raise ValueError(f"Unsupported dataset type: {dataset_type}")
+
     kwargs = {
-        "pipeline": [
-            PortDroidShards(raw_dir, repo_id),
-        ],
+        "pipeline": [pipeline_step],
         "logging_dir": str(logs_dir / job_name),
     }
 
@@ -75,7 +133,7 @@ def make_port_executor(
     kwargs.update(
         {
             "job_name": job_name,
-            "tasks": DROID_SHARDS,
+            "tasks": default_shards,
             "workers": workers,
             "time": "08:00:00",
             "partition": partition,
@@ -115,11 +173,18 @@ def main():
         type=Path,
         help="Path to logs directory for `datatrove`.",
     )
+    parser.add_argument(
+        "--dataset-type",
+        type=str,
+        choices=["droid", "rlds"],
+        default="droid",
+        help="Type of dataset to process: 'droid' for DROID datasets or 'rlds' for RLDS/OpenX datasets.",
+    )
     parser.add_argument(
         "--job-name",
         type=str,
-        default="port_droid",
-        help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
+        default=None,
+        help="Job name used in slurm, and name of the 
directory created inside the provided logs directory. Defaults to 'port_{dataset_type}'.", ) parser.add_argument( "--slurm", @@ -130,8 +195,14 @@ def main(): parser.add_argument( "--workers", type=int, - default=2048, - help="Number of slurm workers. It should be less than the maximum number of shards.", + default=None, + help="Number of slurm workers. Defaults: 2048 for DROID, 64 for RLDS datasets.", + ) + parser.add_argument( + "--num-shards", + type=int, + default=None, + help="Number of shards to split the dataset into. For DROID datasets, this is fixed at 2048. For RLDS datasets, defaults to number of workers.", ) parser.add_argument( "--partition", @@ -152,8 +223,21 @@ def main(): ) args = parser.parse_args() + + # Set defaults based on dataset type + if args.job_name is None: + args.job_name = f"port_{args.dataset_type}" + + if args.workers is None: + if args.dataset_type == "droid": + args.workers = 2048 + else: # rlds + args.workers = 64 + + # Convert args to kwargs and process kwargs = vars(args) kwargs["slurm"] = kwargs.pop("slurm") == 1 + port_executor = make_port_executor(**kwargs) port_executor.run()
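+
+# Sharding sketch (illustrative summary of the behavior above): with
+# `--dataset-type rlds --num-shards N`, datatrove schedules N tasks, and each
+# rank r in [0, N) runs PortRLDSShards.run(). That call ports the contiguous
+# slice train[r * (E // N) : (r + 1) * (E // N)] of the E source episodes
+# (the last rank also takes the remainder) into its own
+# `{repo_id}_world_{N}_rank_{r}` dataset, which is validated independently
+# before any later shard-aggregation step.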