Files
2025-10-29 16:51:15 +08:00

306 lines
10 KiB
Python

#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
For all datasets in the RLDS format.
For https://github.com/google-deepmind/open_x_embodiment (OPENX) datasets.
NOTE: You need to install tensorflow and tensorflow_datasets before running this script.
Example:
python openx_rlds.py \
--raw-dir /path/to/bridge_orig/1.0.0 \
--local-dir /path/to/local_dir \
--repo-id your_id \
--use-videos \
--push-to-hub
"""
import argparse
import re
import shutil
from functools import partial
from pathlib import Path
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import HF_LEROBOT_HOME
from oxe_utils.configs import OXE_DATASET_CONFIGS, ActionEncoding, StateEncoding
from oxe_utils.transforms import OXE_STANDARDIZATION_TRANSFORMS
# Print numpy floats with 2 decimal places; this only affects printed/debug output, not saved data.
np.set_printoptions(precision=2)
def transform_raw_dataset(episode, dataset_name):
    """Collapse an RLDS episode's ``steps`` dataset into one batched trajectory dict.

    Intended to be used via ``tf.data.Dataset.map``. The episode's steps are gathered into a
    single batch. Then:

    1. The dataset-specific OXE standardization transform is applied, if one is registered.
    2. A float32 ``proprio`` matrix is built by concatenating the configured state observation
       keys along axis 1. A ``None`` key becomes a zero padding column.
    3. ``language_instruction`` is renamed to ``task``.
    4. ``action`` is cast to float32.

    Args:
        episode: RLDS episode dict whose ``"steps"`` entry is a ``tf.data.Dataset``.
        dataset_name: Name used to look up ``OXE_STANDARDIZATION_TRANSFORMS`` and
            ``OXE_DATASET_CONFIGS``. Unknown datasets get 8 zero padding columns.

    Returns:
        The same ``episode`` dict, with ``"steps"`` replaced by the batched trajectory dict.
    """
    steps = episode["steps"]
    # A batch as large as the dataset's cardinality yields exactly one element holding every step.
    traj = next(iter(steps.batch(steps.cardinality())))

    if dataset_name in OXE_STANDARDIZATION_TRANSFORMS:
        standardize = OXE_STANDARDIZATION_TRANSFORMS[dataset_name]
        traj = standardize(traj)

    state_obs_keys = (
        OXE_DATASET_CONFIGS[dataset_name]["state_obs_keys"]
        if dataset_name in OXE_DATASET_CONFIGS
        else [None] * 8
    )

    num_steps = tf.shape(traj["action"])[0]
    columns = []
    for obs_key in state_obs_keys:
        if obs_key is None:
            # Missing state dimension: pad with a zero column so the layout stays fixed.
            columns.append(tf.zeros((num_steps, 1), dtype=tf.float32))
        else:
            columns.append(tf.cast(traj["observation"][obs_key], tf.float32))
    proprio = tf.concat(columns, axis=1)

    task = traj.pop("language_instruction")
    traj["proprio"] = proprio
    traj["task"] = task
    traj["action"] = tf.cast(traj["action"], tf.float32)

    episode["steps"] = traj
    return episode
def generate_features_from_raw(builder: tfds.core.DatasetBuilder, use_videos: bool = True):
    """Build the LeRobot ``features`` spec for the RLDS dataset wrapped by ``builder``.

    The dataset name is the parent directory name of ``builder.data_dir``. When that name has an
    entry in ``OXE_DATASET_CONFIGS``, the ``state_encoding`` / ``action_encoding`` /
    ``state_obs_keys`` from the config choose the per-dimension names. Otherwise generic
    ``motor_i`` names (8 each) are used.

    NOTE(review): the name lists assume the state/action widths implied by each encoding
    (8 state dims; 7 or 8 action dims). Confirm they match the actual tensors.

    Args:
        builder: TFDS builder for the raw dataset; its observation feature spec is inspected.
        use_videos: Store camera streams as ``"video"`` features when True, else as ``"image"``.

    Returns:
        Dict mapping feature keys to specs:

        - one ``observation.images.<key>`` per observation key containing ``"image"`` or
          ``"rgb"`` but not ``"depth"``;
        - ``observation.state``;
        - ``action``.
    """
    dataset_name = Path(builder.data_dir).parent.name
    is_libero = "libero" in dataset_name

    state_names = [f"motor_{i}" for i in range(8)]
    action_names = [f"motor_{i}" for i in range(8)]

    if dataset_name in OXE_DATASET_CONFIGS:
        cfg = OXE_DATASET_CONFIGS[dataset_name]

        state_encoding = cfg["state_encoding"]
        if state_encoding == StateEncoding.POS_EULER:
            if is_libero:
                # LIBERO: axis-angle rotation and a 2D gripper state.
                state_names = ["x", "y", "z", "axis_angle1", "axis_angle2", "axis_angle3", "gripper", "gripper"]
            else:
                state_names = ["x", "y", "z", "roll", "pitch", "yaw", "pad", "gripper"]
        elif state_encoding == StateEncoding.POS_QUAT:
            state_names = ["x", "y", "z", "rx", "ry", "rz", "rw", "gripper"]
        elif state_encoding == StateEncoding.JOINT:
            state_names = [f"motor_{i}" for i in range(7)] + ["gripper"]

        # transform_raw_dataset emits a zero column for every None key. Label those slots "pad":
        # the None entries before the last key are relabelled just before the final slot, and the
        # final slot is relabelled only if the last key itself is None.
        obs_keys = cfg["state_obs_keys"]
        n_pad = obs_keys[:-1].count(None)
        state_names[-n_pad - 1 : -1] = ["pad"] * n_pad
        if obs_keys[-1] is None:
            state_names[-1] = "pad"

        action_encoding = cfg["action_encoding"]
        if action_encoding == ActionEncoding.EEF_POS:
            if is_libero:
                action_names = ["x", "y", "z", "axis_angle1", "axis_angle2", "axis_angle3", "gripper"]
            else:
                action_names = ["x", "y", "z", "roll", "pitch", "yaw", "gripper"]
        elif action_encoding == ActionEncoding.JOINT_POS:
            action_names = [f"motor_{i}" for i in range(7)] + ["gripper"]

    image_dtype = "video" if use_videos else "image"
    features = {}
    for obs_key, spec in builder.info.features["steps"]["observation"].items():
        is_camera = any(tag in obs_key for tag in ["image", "rgb"])
        if "depth" in obs_key or not is_camera:
            continue
        features[f"observation.images.{obs_key}"] = {
            "dtype": image_dtype,
            "shape": spec.shape,
            "names": ["height", "width", "rgb"],
        }

    features["observation.state"] = {
        "dtype": "float32",
        "shape": (len(state_names),),
        "names": {"motors": state_names},
    }
    features["action"] = {
        "dtype": "float32",
        "shape": (len(action_names),),
        "names": {"motors": action_names},
    }
    return features
def _is_rgb_image_key(key):
    """Return True for observation keys naming an RGB camera stream (contains image/rgb, not depth)."""
    return "depth" not in key and any(tag in key for tag in ("image", "rgb"))


def save_as_lerobot_dataset(lerobot_dataset: LeRobotDataset, raw_dataset: tf.data.Dataset, **kwargs):
    """Write every episode of ``raw_dataset`` into ``lerobot_dataset``.

    Each step becomes one ``add_frame`` call, and each non-empty episode ends with one
    ``save_episode`` call.

    Args:
        lerobot_dataset: Destination dataset.
        raw_dataset: Episodes already mapped through ``transform_raw_dataset``. Each element's
            ``"steps"`` holds batched ``action`` / ``proprio`` / ``task`` arrays plus the
            ``observation`` dict.
        **kwargs: Accepted for caller compatibility (e.g. ``keep_images``); currently unused.
    """
    for episode in raw_dataset.as_numpy_iterator():
        traj = episode["steps"]
        num_frames = traj["action"].shape[0]
        if num_frames == 0:
            # Nothing would be buffered; calling save_episode() on an empty episode is an error.
            continue

        # These do not change within an episode, so compute them once instead of per frame.
        observation = traj["observation"]
        image_keys = [key for key in observation if _is_rgb_image_key(key)]
        task = traj["task"][0].decode()

        for i in range(num_frames):
            frame = {f"observation.images.{key}": observation[key][i] for key in image_keys}
            frame["observation.state"] = traj["proprio"][i]
            frame["action"] = traj["action"][i]
            frame["task"] = task
            lerobot_dataset.add_frame(frame)
        lerobot_dataset.save_episode()
def _split_raw_dir(raw_dir: Path):
    """Split ``raw_dir`` into ``(version, dataset_name, data_dir)`` as expected by ``tfds.builder``."""
    last_part = raw_dir.name
    if re.match(r"^\d+\.\d+\.\d+$", last_part):
        # e.g. path/to/bridge_orig/1.0.0 -> ("1.0.0", "bridge_orig", path/to)
        return last_part, raw_dir.parent.name, raw_dir.parent.parent
    # e.g. path/to/bridge_orig -> ("", "bridge_orig", path/to)
    return "", last_part, raw_dir.parent


def create_lerobot_dataset(
    raw_dir: Path,
    repo_id: str = None,
    local_dir: Path = None,
    push_to_hub: bool = False,
    fps: int = None,
    robot_type: str = None,
    use_videos: bool = True,
    image_writer_process: int = 5,
    image_writer_threads: int = 10,
    keep_images: bool = True,
):
    """Convert an RLDS (Open-X) dataset on disk into a LeRobotDataset and optionally push it.

    Args:
        raw_dir: ``path/to/<dataset>`` or ``path/to/<dataset>/<x.y.z>``. A trailing
            semantic-version component is used as the TFDS version.
        repo_id: Hugging Face repo id. Required when ``push_to_hub`` is True.
        local_dir: Parent output directory; defaults to ``HF_LEROBOT_HOME``. The dataset is
            written to ``<local_dir>/<dataset>_<version>_lerobot``, which is deleted first if
            it already exists.
        push_to_hub: Upload the converted dataset (tagged, public, apache-2.0).
        fps: Frame rate. Defaults to the OXE config ``control_frequency``, or 10 if unknown.
        robot_type: Defaults to the OXE config ``robot_type`` (lower-cased, spaces and dashes
            replaced by underscores), or ``"unknown"``.
        use_videos: Encode camera streams as videos instead of individual images.
        image_writer_process: Number of image-writer processes.
        image_writer_threads: Number of image-writer threads.
        keep_images: Forwarded to ``save_as_lerobot_dataset``.
            NOTE(review): that function does not appear to use it.

    Raises:
        ValueError: If ``push_to_hub`` is True but ``repo_id`` is None. This is checked before
            any conversion work, so a long job does not fail only at the very end.
    """
    if push_to_hub and repo_id is None:
        raise ValueError("`repo_id` is required when `push_to_hub` is True.")

    version, dataset_name, data_dir = _split_raw_dir(raw_dir)

    if local_dir is None:
        local_dir = Path(HF_LEROBOT_HOME)
    local_dir = local_dir / f"{dataset_name}_{version}_lerobot"
    if local_dir.exists():
        shutil.rmtree(local_dir)

    # NOTE(review): version is "" when raw_dir has no version component; confirm tfds treats
    # an empty string like "no explicit version".
    builder = tfds.builder(dataset_name, data_dir=data_dir, version=version)
    features = generate_features_from_raw(builder, use_videos)

    raw_dataset = builder.as_dataset(split="train")
    if dataset_name == "kuka":
        # Keep only episodes whose `success` flag is set.
        raw_dataset = raw_dataset.filter(lambda episode: episode["success"])
    raw_dataset = raw_dataset.map(partial(transform_raw_dataset, dataset_name=dataset_name))

    if fps is None:
        if dataset_name in OXE_DATASET_CONFIGS:
            fps = OXE_DATASET_CONFIGS[dataset_name]["control_frequency"]
        else:
            fps = 10

    if robot_type is None:
        if dataset_name in OXE_DATASET_CONFIGS:
            robot_type = OXE_DATASET_CONFIGS[dataset_name]["robot_type"]
            robot_type = robot_type.lower().replace(" ", "_").replace("-", "_")
        else:
            robot_type = "unknown"

    lerobot_dataset = LeRobotDataset.create(
        repo_id=repo_id,
        robot_type=robot_type,
        root=local_dir,
        fps=int(fps),
        use_videos=use_videos,
        features=features,
        image_writer_threads=image_writer_threads,
        image_writer_processes=image_writer_process,
    )

    save_as_lerobot_dataset(lerobot_dataset, raw_dataset, keep_images=keep_images)

    if push_to_hub:
        tags = ["LeRobot", dataset_name, "rlds"]
        if dataset_name in OXE_DATASET_CONFIGS:
            tags.append("openx")
        if robot_type != "unknown":
            tags.append(robot_type)
        lerobot_dataset.push_to_hub(
            tags=tags,
            private=False,
            push_videos=True,
            license="apache-2.0",
        )
def main():
    """Parse command-line arguments and run ``create_lerobot_dataset`` with them.

    ``--push-to-hub`` without ``--repo-id`` is rejected at parse time (argparse exits with
    status 2) rather than after the whole conversion has run.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--raw-dir",
        type=Path,
        required=True,
        help="Directory containing input raw datasets (e.g. `path/to/dataset` or `path/to/dataset/version`).",
    )
    parser.add_argument(
        "--local-dir",
        type=Path,
        default=None,
        help="When provided, writes the dataset converted to LeRobotDataset format under this directory "
        "(e.g. `data/lerobot/aloha_mobile_chair`). Defaults to HF_LEROBOT_HOME.",
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True",
    )
    parser.add_argument(
        "--push-to-hub",
        action="store_true",
        help="Upload to hub.",
    )
    parser.add_argument(
        "--robot-type",
        type=str,
        default=None,
        help="Robot type of this dataset.",
    )
    parser.add_argument(
        "--fps",
        type=int,
        default=None,
        help="Frame rate used to collect videos. Defaults to the control frequency from the dataset's OXE config (10 if unknown).",
    )
    parser.add_argument(
        "--use-videos",
        action="store_true",
        help="Convert each episode of the raw dataset to an mp4 video. This option allows 60 times lower disk space consumption and 25 times faster loading time during training.",
    )
    parser.add_argument(
        "--image-writer-process",
        type=int,
        default=5,
        help="Number of processes of image writer for saving images.",
    )
    parser.add_argument(
        "--image-writer-threads",
        type=int,
        default=10,
        help="Number of threads per process of image writer for saving images.",
    )
    args = parser.parse_args()
    if args.push_to_hub and args.repo_id is None:
        # Fail fast: otherwise the problem only surfaces after the full conversion.
        parser.error("--repo-id is required when --push-to-hub is set.")
    create_lerobot_dataset(**vars(args))
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()