#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
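"""
Port Open X-Embodiment (OXE) datasets stored in the RLDS/TFDS format to the
LeRobot dataset format, optionally pushing the result to the Hugging Face Hub.

Example invocation (script name, paths, and repo id are illustrative):

    python openx_rlds.py \
        --raw-dir /path/to/dataset/1.0.0 \
        --repo-id your_name/your_dataset \
        --push-to-hub
"""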

import argparse
import logging
import re
import time
from functools import partial
from pathlib import Path
from typing import Any

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from oxe_utils.configs import OXE_DATASET_CONFIGS, ActionEncoding, StateEncoding
from oxe_utils.transforms import OXE_STANDARDIZATION_TRANSFORMS

from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds

# Defaults for datasets without a specific OXE config
DEFAULT_FPS = 10
DEFAULT_ROBOT_TYPE = "unknown"


def determine_dataset_info(raw_dir: Path):
    """Determine dataset name and version from directory structure."""
    last_part = raw_dir.name
    # A trailing `X.Y.Z` directory name is treated as the TFDS version (e.g. `1.0.0`)
    if re.match(r"^\d+\.\d+\.\d+$", last_part):
        version = last_part
        dataset_name = raw_dir.parent.name
        data_dir = raw_dir.parent.parent
    else:
        version = ""
        dataset_name = last_part
        data_dir = raw_dir.parent
    return dataset_name, version, data_dir


def generate_features_from_builder(builder: tfds.core.DatasetBuilder, dataset_name: str) -> dict[str, Any]:
    """Generate LeRobot features schema from TFDS builder and dataset config."""
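
    # OXE_DATASET_CONFIGS is assumed to map a dataset name to a dict with the
    # keys read in this file: "state_encoding", "action_encoding",
    # "state_obs_keys", "control_frequency", and "robot_type".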

    # Generate state names based on encoding type
    state_names = [f"motor_{i}" for i in range(8)]
    if dataset_name in OXE_DATASET_CONFIGS:
        state_encoding = OXE_DATASET_CONFIGS[dataset_name]["state_encoding"]
        if state_encoding == StateEncoding.POS_EULER:
            state_names = ["x", "y", "z", "roll", "pitch", "yaw", "pad", "gripper"]
            if "libero" in dataset_name:
                state_names = [
                    "x",
                    "y",
                    "z",
                    "roll",
                    "pitch",
                    "yaw",
                    "gripper",
                    "gripper",
                ]  # 2D gripper state
        elif state_encoding == StateEncoding.POS_QUAT:
            state_names = ["x", "y", "z", "rx", "ry", "rz", "rw", "gripper"]
        elif state_encoding == StateEncoding.JOINT:
            state_names = [f"motor_{i}" for i in range(7)] + ["gripper"]
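
        # Rename padded slots so state names line up with `state_obs_keys`.
        # Illustrative: keys ["eef_state", None, None, "gripper_state"] mark
        # the two None slots before the last name as "pad".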
        state_obs_keys = OXE_DATASET_CONFIGS[dataset_name]["state_obs_keys"]
        pad_count = state_obs_keys[:-1].count(None)
        state_names[-pad_count - 1 : -1] = ["pad"] * pad_count
        state_names[-1] = "pad" if state_obs_keys[-1] is None else state_names[-1]

    # Generate action names based on encoding type
    action_names = [f"motor_{i}" for i in range(8)]
    if dataset_name in OXE_DATASET_CONFIGS:
        action_encoding = OXE_DATASET_CONFIGS[dataset_name]["action_encoding"]
        if action_encoding == ActionEncoding.EEF_POS:
            action_names = ["x", "y", "z", "roll", "pitch", "yaw", "gripper"]
        elif action_encoding == ActionEncoding.JOINT_POS:
            action_names = [f"motor_{i}" for i in range(7)] + ["gripper"]

    # Base features (state and action)
    features = {
        "observation.state": {
            "dtype": "float32",
            "shape": (len(state_names),),
            "names": {"axes": state_names},
        },
        "action": {
            "dtype": "float32",
            "shape": (len(action_names),),
            "names": {"axes": action_names},
        },
    }

    # Add image features from TFDS builder info
    obs_features = builder.info.features["steps"]["observation"]
    for key, value in obs_features.items():
        # Skip depth images and non-image features
        if "depth" in key or not any(x in key for x in ["image", "rgb"]):
            continue

        features[f"observation.images.{key}"] = {
            "dtype": "video",
            "shape": tuple(value.shape),
            "names": ["height", "width", "channels"],
        }

    return features


def transform_raw_dataset(episode, dataset_name: str):
    """Apply OXE standardization transforms to raw TFDS episode."""
    # Batch all steps in the episode into a single trajectory dict
    traj = next(iter(episode["steps"].batch(episode["steps"].cardinality())))

    # Apply dataset-specific transform if available
    if dataset_name in OXE_STANDARDIZATION_TRANSFORMS:
        traj = OXE_STANDARDIZATION_TRANSFORMS[dataset_name](traj)

    # Create consolidated state vector
    if dataset_name in OXE_DATASET_CONFIGS:
        state_obs_keys = OXE_DATASET_CONFIGS[dataset_name]["state_obs_keys"]
    else:
        state_obs_keys = [None for _ in range(8)]

    # Build proprio (proprioceptive state) vector
    proprio_components = []
    for key in state_obs_keys:
        if key is None:
            # Add padding for missing state components
            component = tf.zeros((tf.shape(traj["action"])[0], 1), dtype=tf.float32)
        else:
            component = tf.cast(traj["observation"][key], tf.float32)
            # Ensure component has right shape (add dimension if needed)
            if len(component.shape) == 1:
                component = component[:, None]
        proprio_components.append(component)

    proprio = tf.concat(proprio_components, axis=1)

    # Update trajectory with standardized format
    traj.update(
        {
            "proprio": proprio,
            "task": traj.get("language_instruction", ""),
            "action": tf.cast(traj["action"], tf.float32),
        }
    )

    episode["steps"] = traj
    return episode


def generate_lerobot_frames(tf_episode):
    """Generate LeRobot frames from transformed TFDS episode."""
    traj = tf_episode["steps"]

    # Get the task/language instruction
    if isinstance(traj["task"], tf.Tensor):
        if traj["task"].dtype == tf.string:
            task = traj["task"][0].numpy().decode() if len(traj["task"]) > 0 else ""
        else:
            task = str(traj["task"][0].numpy()) if len(traj["task"]) > 0 else ""
    else:
        task = str(traj["task"]) if traj["task"] else ""

    # Iterate through each timestep
    num_steps = tf.shape(traj["action"])[0].numpy()
    for i in range(num_steps):
        frame = {}

        # Add observation state
        frame["observation.state"] = traj["proprio"][i].numpy()

        # Add action
        frame["action"] = traj["action"][i].numpy()

        # Add images
        for key, value in traj["observation"].items():
            if any(x in key for x in ["image", "rgb"]) and "depth" not in key:
                frame[f"observation.images.{key}"] = value[i].numpy()

        # Add task
        frame["task"] = task

        # Cast fp64 to fp32
        for key in frame:
            if isinstance(frame[key], np.ndarray) and frame[key].dtype == np.float64:
                frame[key] = frame[key].astype(np.float32)

        yield frame


def port_rlds(
    raw_dir: Path,
    repo_id: str,
    push_to_hub: bool = False,
    num_shards: int | None = None,
    shard_index: int | None = None,
):
    """Port an RLDS dataset to the LeRobot format."""
    # Determine dataset info
    dataset_name, version, data_dir = determine_dataset_info(raw_dir)

    # Build TFDS dataset
    builder = tfds.builder(
        f"{dataset_name}/{version}" if version else dataset_name,
        data_dir=data_dir,
        version=version if version else None,
    )

    # Handle sharding if specified
    if num_shards is not None and shard_index is not None:
        if shard_index >= num_shards:
            raise ValueError(f"Shard index {shard_index} >= num_shards {num_shards}")

        # Calculate shard splits
        total_episodes = builder.info.splits["train"].num_examples
        episodes_per_shard = total_episodes // num_shards
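        # Illustrative split: 100 episodes over 8 shards -> shards 0-6 get
        # 100 // 8 = 12 episodes each; the last shard takes the rest
        # (100 - 7 * 12 = 16).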
        start_idx = shard_index * episodes_per_shard
        if shard_index == num_shards - 1:
            # Last shard gets remaining episodes
            end_idx = total_episodes
        else:
            end_idx = start_idx + episodes_per_shard

        split_str = f"train[{start_idx}:{end_idx}]"
        raw_dataset = builder.as_dataset(split=split_str)
    else:
        raw_dataset = builder.as_dataset(split="train")

    # Apply filtering (e.g., success filter for kuka)
    if dataset_name == "kuka":
        raw_dataset = raw_dataset.filter(lambda e: e["success"])

    # Apply transformations
    raw_dataset = raw_dataset.map(partial(transform_raw_dataset, dataset_name=dataset_name))

    # Get dataset configuration
    fps = DEFAULT_FPS
    robot_type = DEFAULT_ROBOT_TYPE

    if dataset_name in OXE_DATASET_CONFIGS:
        config = OXE_DATASET_CONFIGS[dataset_name]
        fps = config.get("control_frequency", DEFAULT_FPS)
        robot_type = config.get("robot_type", DEFAULT_ROBOT_TYPE)
        robot_type = robot_type.lower().replace(" ", "_").replace("-", "_")

    # Generate features schema
    features = generate_features_from_builder(builder, dataset_name)

    # Create LeRobot dataset
    lerobot_dataset = LeRobotDataset.create(
        repo_id=repo_id,
        robot_type=robot_type,
        fps=int(fps),
        features=features,
    )

    # Process episodes
    start_time = time.time()
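    # Note: tf.data reports unknown cardinality (-2) once filter() has been
    # applied (e.g. for kuka), so this count may be negative.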
    num_episodes = raw_dataset.cardinality().numpy().item()
    logging.info(f"Number of episodes: {num_episodes}")

    for episode_index, episode in enumerate(raw_dataset):
        elapsed_time = time.time() - start_time
        d, h, m, s = get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time)

        logging.info(
            f"{episode_index} / {num_episodes} episodes processed "
            f"(after {d} days, {h} hours, {m} minutes, {s:.3f} seconds)"
        )

        # Generate and add frames
        for frame in generate_lerobot_frames(episode):
            lerobot_dataset.add_frame(frame)

        lerobot_dataset.save_episode()
        logging.info("Saved episode")

    # Push to hub if requested
    if push_to_hub:
        tags = ["openx", dataset_name]
        if robot_type != "unknown":
            tags.append(robot_type)

        lerobot_dataset.push_to_hub(
            tags=tags,
            private=False,
        )


def validate_dataset(repo_id):
    """Sanity check that ensures metadata can be loaded and all files are present."""
    meta = LeRobotDatasetMetadata(repo_id)

    if meta.total_episodes == 0:
        raise ValueError("Number of episodes is 0.")

    for ep_idx in range(meta.total_episodes):
        data_path = meta.root / meta.get_data_file_path(ep_idx)

        if not data_path.exists():
            raise ValueError(f"Parquet file is missing in: {data_path}")

        for vid_key in meta.video_keys:
            vid_path = meta.root / meta.get_video_file_path(ep_idx, vid_key)
            if not vid_path.exists():
                raise ValueError(f"Video file is missing in: {vid_path}")


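# Note: validate_dataset() above is not called by main(); it can be run
# manually after porting, passing the same repo_id used for the port.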
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--raw-dir",
        type=Path,
        required=True,
        help="Directory containing the raw dataset (e.g. `path/to/dataset` or `path/to/dataset/version`).",
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        help="Repository identifier on Hugging Face: a community or user name, a `/`, then the dataset name. Required when --push-to-hub is set.",
    )
    parser.add_argument(
        "--push-to-hub",
        action="store_true",
        help="Upload the dataset to the Hugging Face Hub.",
    )
    parser.add_argument(
        "--num-shards",
        type=int,
        default=None,
        help="Number of shards to split the dataset into for parallel processing.",
    )
    parser.add_argument(
        "--shard-index",
        type=int,
        default=None,
        help="Index of the shard to process (0-indexed).",
    )

    args = parser.parse_args()

    port_rlds(**vars(args))


if __name__ == "__main__":
    main()