mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 10:10:08 +00:00
444 lines
14 KiB
Python
444 lines
14 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""
|
|
Distributed EgoDex dataset porting using SLURM and datatrove.
|
|
|
|
EgoDex is a large-scale dataset for egocentric dexterous manipulation collected
|
|
with ARKit on Apple Vision Pro. This script converts EgoDex data to LeRobot format.
|
|
|
|
Reference: https://arxiv.org/abs/2505.11709, https://github.com/apple/ml-egodex
|
|
"""
|
|
|
|
import argparse
|
|
from pathlib import Path
|
|
|
|
import cv2
|
|
import h5py
|
|
import mediapy as mpy
|
|
import numpy as np
|
|
from datatrove.executor import LocalPipelineExecutor
|
|
from datatrove.executor.slurm import SlurmPipelineExecutor
|
|
from datatrove.pipeline.base import PipelineStep
|
|
|
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
|
|
|
# Target frame size (height, width) in pixels; every video frame is resized to
# this resolution before being stored in the LeRobot dataset.
DEFAULT_IMAGE_HEIGHT, DEFAULT_IMAGE_WIDTH = 1080, 1920
|
|
|
|
class PortEgoDexShards(PipelineStep):
    """Datatrove pipeline step that converts one shard of EgoDex into a LeRobot dataset.

    Every worker (``rank`` of ``world_size``) takes a strided slice of the sorted
    ``*.hdf5`` annotation files found under ``raw_dir``. It pairs each file with the
    ``.mp4`` video of the same stem and writes one episode per demo into its own
    dataset ``{repo_id}_world_{world_size}_rank_{rank}`` under ``local_dir``.
    """

    def __init__(
        self,
        raw_dir: Path | str,
        repo_id: str,
        local_dir: Path | str | None = None,
        percentage: float = 100.0,
    ):
        """Store the conversion settings.

        Args:
            raw_dir: Directory searched recursively for EgoDex ``*.hdf5`` files.
            repo_id: Base repository id; ``run`` appends a per-shard suffix.
            local_dir: Parent directory of the shard datasets. Falls back to
                ``data/local_datasets`` when ``None`` or empty.
            percentage: Percentage of the sorted file list to convert. Values
                >= 100 keep every file, and at least one file is always kept.
        """
        super().__init__()
        self.raw_dir = Path(raw_dir)
        self.repo_id = repo_id
        self.local_dir = Path(local_dir) if local_dir else Path("data/local_datasets")
        self.percentage = percentage

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        """Convert this worker's share of EgoDex demos into a LeRobot shard dataset.

        Args:
            data: Unused; present to satisfy the datatrove ``PipelineStep`` interface.
            rank: Index of this worker; it processes ``files[rank::world_size]``.
            world_size: Total number of workers sharing the file list.

        Demos with a missing video, no frames, or a conversion error are logged
        and skipped. Every other demo is saved as one episode, and the shard
        dataset is finalized at the end. Returns ``None``.
        """
        # Dependencies are imported inside ``run`` rather than taken from module
        # scope, presumably so they resolve inside the worker process that
        # datatrove launches (kept from the original design).
        import logging

        import cv2
        import h5py
        import mediapy as mpy
        import numpy as np

        from lerobot.datasets.lerobot_dataset import LeRobotDataset
        from lerobot.utils.utils import init_logging

        init_logging()
        logger = logging.getLogger(__name__)

        default_task = "manipulation task"

        # EgoDex transform names of the five fingertips of each hand, in state order.
        fingertip_joints = {
            "left": [
                "leftThumbTip",
                "leftIndexFingerTip",
                "leftMiddleFingerTip",
                "leftRingFingerTip",
                "leftLittleFingerTip",
            ],
            "right": [
                "rightThumbTip",
                "rightIndexFingerTip",
                "rightMiddleFingerTip",
                "rightRingFingerTip",
                "rightLittleFingerTip",
            ],
        }
        # Short finger labels used for feature names; same order as fingertip_joints.
        finger_labels = ("thumb", "index", "middle", "ring", "little")

        def _attr_to_str(value):
            """Return an HDF5 attribute as ``str``.

            h5py can return fixed-length string attributes as ``bytes``. Without
            decoding, string comparisons against them silently fail.
            """
            if isinstance(value, bytes):
                return value.decode("utf-8")
            return str(value)

        def _hand_state_names(side):
            """Feature names for one hand's 24-D state, in construction order."""
            names = [f"{side}_wrist_{axis}" for axis in "xyz"]
            names += [f"{side}_rot_{i}" for i in range(6)]
            names += [f"{side}_{finger}_{axis}" for finger in finger_labels for axis in "xyz"]
            return names

        def get_state_and_action_from_egodex_annotations(demo):
            """Build per-frame hand states, next-state actions and camera extrinsics.

            Each ``transforms/<name>`` dataset is read in a single call as a stack
            of 4x4 transforms, one per frame. All datasets are assumed to share the
            same frame count (TODO confirm for every EgoDex file).

            State layout per hand (24-D, left hand first, 48-D total):
              - wrist 3-D position (translation column of the wrist transform),
              - wrist orientation as 6-D rotation (first two rotation-matrix columns),
              - 3-D positions of the five fingertips.

            Returns:
                state: ``(T, 48)`` float32.
                action: ``(T, 48)`` float32, where ``action[t] = state[t + 1]``. The
                    final frame has no successor, so it repeats its own state.
                extrinsics: ``(T, 16)`` float32, the row-major flattened camera
                    transform of each frame.
            """
            transforms = demo["transforms"]

            def _load(name):
                # One bulk read per dataset instead of one disk read per frame.
                return np.asarray(transforms[name][()])

            columns = []
            for side in ("left", "right"):
                wrist = _load(f"{side}Hand")
                columns += [
                    wrist[:, :3, 3],  # wrist position
                    wrist[:, :3, 0],  # rotation column 0
                    wrist[:, :3, 1],  # rotation column 1
                ]
                columns += [_load(joint)[:, :3, 3] for joint in fingertip_joints[side]]

            state = np.concatenate(columns, axis=1).astype(np.float32)
            extrinsics = _load("camera").reshape(-1, 16).astype(np.float32)

            # Shift by one timestep WITHOUT wrap-around. np.roll would make the
            # last action equal to the first state of the episode.
            action = np.concatenate([state[1:], state[-1:]], axis=0)
            return state, action, extrinsics

        def get_task_description(attrs):
            """Pick the natural-language instruction stored in the demo's HDF5 attributes.

            Reversible demos store two descriptions. ``which_llm_description``
            selects ``llm_description`` ("1") or ``llm_description2``.
            """
            if _attr_to_str(attrs.get("llm_type", "")) == "reversible":
                direction = _attr_to_str(attrs.get("which_llm_description", "1"))
                key = "llm_description" if direction == "1" else "llm_description2"
            else:
                key = "llm_description"
            value = attrs.get(key)
            return default_task if value is None else _attr_to_str(value)

        def iter_demo_frames(hdf5_file_path, video_path):
            """Yield LeRobot frame dicts for one EgoDex demo.

            Frames are produced lazily so the caller can hand each one to the
            dataset without materializing every resized image at once. If the
            video and the annotations disagree in length, only their common
            prefix is used and a warning is logged.
            """
            video = np.asarray(mpy.read_video(str(video_path)))
            with h5py.File(hdf5_file_path, "r") as demo:
                state, action, extrinsics = get_state_and_action_from_egodex_annotations(demo)
                task = get_task_description(demo.attrs)

            num_frames = min(len(video), len(state))
            if len(video) != len(state):
                logger.warning(
                    "%s: video has %d frames but annotations have %d; using the first %d.",
                    hdf5_file_path,
                    len(video),
                    len(state),
                    num_frames,
                )

            for step_idx in range(num_frames):
                # cv2 expects (width, height); INTER_AREA is OpenCV's recommended
                # interpolation for shrinking.
                image = cv2.resize(
                    video[step_idx],
                    (DEFAULT_IMAGE_WIDTH, DEFAULT_IMAGE_HEIGHT),
                    interpolation=cv2.INTER_AREA,
                )
                yield {
                    "task": task,
                    "observation.image": image,
                    "observation.state": state[step_idx],
                    "observation.extrinsics": extrinsics[step_idx],
                    "action": action[step_idx],
                }

        state_names = _hand_state_names("left") + _hand_state_names("right")
        egodex_features = {
            "observation.image": {
                "dtype": "video",
                "shape": (DEFAULT_IMAGE_HEIGHT, DEFAULT_IMAGE_WIDTH, 3),
                "names": ["height", "width", "rgb"],
            },
            "observation.state": {
                "dtype": "float32",
                "shape": (len(state_names),),  # 48
                "names": state_names,
            },
            "observation.extrinsics": {
                "dtype": "float32",
                "shape": (16,),
                "names": [f"extrinsic_{i}" for i in range(16)],
            },
            "action": {
                "dtype": "float32",
                "shape": (len(state_names),),  # 48
                "names": [f"action_{i}" for i in range(len(state_names))],
            },
        }

        # 1. Discover all annotation files (sorted so every rank sees the same order).
        files = sorted(self.raw_dir.rglob("*.hdf5"))
        if not files:
            logger.warning("No HDF5 files found in %s", self.raw_dir)
            return

        # 2. Optionally keep only a prefix of the dataset.
        if self.percentage < 100:
            num_files = max(1, int(len(files) * self.percentage / 100))
            files = files[:num_files]
            logger.info("Processing %s%% of dataset: %d files", self.percentage, num_files)

        # 3. Strided assignment of files to this worker.
        my_files = files[rank::world_size]
        if not my_files:
            logger.info("Rank %d has no files to process.", rank)
            return
        logger.info("Rank %d processing %d files out of %d total.", rank, len(my_files), len(files))

        # 4. One LeRobot dataset per shard.
        shard_repo_id = f"{self.repo_id}_world_{world_size}_rank_{rank}"
        dataset = LeRobotDataset.create(
            repo_id=shard_repo_id,
            fps=30,
            robot_type="hand",
            features=egodex_features,
            root=self.local_dir / shard_repo_id,
        )

        # 5. Convert each demo into one episode.
        for input_h5 in my_files:
            video_path = input_h5.with_suffix(".mp4")
            if not video_path.exists():
                logger.warning("Video file not found for %s, skipping.", input_h5)
                continue

            try:
                num_added = 0
                for frame in iter_demo_frames(input_h5, video_path):
                    dataset.add_frame(frame)
                    num_added += 1
                if num_added == 0:
                    logger.warning("No frames found for %s, skipping.", input_h5)
                    continue
                dataset.save_episode()
            except Exception:
                logger.exception("Error processing %s, skipping.", input_h5)
                # Discard frames already buffered for this demo so they do not
                # leak into the next episode.
                dataset.clear_episode_buffer()

        # 6. Finalize the shard.
        dataset.finalize()
|
|
|
|
|
|
def make_port_executor(
    raw_dir,
    repo_id,
    job_name,
    logs_dir,
    workers,
    partition,
    cpus_per_task,
    mem_per_cpu,
    local_dir,
    percentage,
    slurm=True,
):
    """Build the datatrove executor that runs ``PortEgoDexShards`` over ``workers`` tasks.

    Logs go to ``logs_dir / job_name``. With ``slurm=True`` the tasks are
    submitted as a SLURM job (10 h time limit, ``workers`` parallel jobs).
    Otherwise a local executor runs every task sequentially with a single
    worker, which is handy for debugging.
    """
    port_step = PortEgoDexShards(raw_dir, repo_id, local_dir, percentage)
    shared_options = {
        "pipeline": [port_step],
        "logging_dir": str(logs_dir / job_name),
        "tasks": workers,
    }

    if not slurm:
        # Run locally sequentially for debugging.
        return LocalPipelineExecutor(workers=1, **shared_options)

    return SlurmPipelineExecutor(
        job_name=job_name,
        workers=workers,
        time="10:00:00",  # EgoDex is large, allow more time
        partition=partition,
        cpus_per_task=cpus_per_task,
        sbatch_args={"mem-per-cpu": mem_per_cpu},
        **shared_options,
    )
|
|
|
|
|
|
def main():
    """Command-line entry point: parse the options and run the EgoDex porting executor.

    ``--slurm 1`` (the default) submits the job through SLURM. Any other value
    runs it locally and sequentially, which is useful for debugging.
    """
    parser = argparse.ArgumentParser(
        description="Convert EgoDex dataset to LeRobot format using SLURM."
    )

    # (flag, add_argument keyword options), declared in the order shown by --help.
    cli_spec = (
        (
            "--raw-dir",
            {
                "type": Path,
                "required": True,
                "help": "Directory containing input EgoDex data (HDF5 + MP4 files).",
            },
        ),
        (
            "--repo-id",
            {
                "type": str,
                "required": True,
                "help": "Repository identifier (e.g., user/egodex-lerobot).",
            },
        ),
        (
            "--logs-dir",
            {"type": Path, "default": Path("logs"), "help": "Path to logs directory."},
        ),
        (
            "--job-name",
            {"type": str, "default": "port_egodex", "help": "Job name used in SLURM."},
        ),
        (
            "--slurm",
            {
                "type": int,
                "default": 1,
                "help": "Launch over SLURM. Use --slurm 0 to launch sequentially (useful for debugging).",
            },
        ),
        (
            "--workers",
            {"type": int, "default": 50, "help": "Number of SLURM workers."},
        ),
        (
            "--partition",
            {"type": str, "help": "SLURM partition."},
        ),
        (
            "--cpus-per-task",
            {"type": int, "default": 4, "help": "Number of CPUs per worker."},
        ),
        (
            "--mem-per-cpu",
            {"type": str, "default": "4G", "help": "Memory per CPU."},
        ),
        (
            "--percentage",
            {
                "type": float,
                "default": 100.0,
                "help": "Percentage of dataset to process (e.g., 1.0 for 1%%). Useful for testing.",
            },
        ),
        (
            "--local-dir",
            {
                "type": Path,
                "default": None,
                "help": "Local directory to save the LeRobot dataset. Defaults to data/local_datasets.",
            },
        ),
    )
    for flag, options in cli_spec:
        parser.add_argument(flag, **options)

    executor_kwargs = vars(parser.parse_args())
    # make_port_executor takes a boolean, not the raw 0/1 CLI integer.
    executor_kwargs["slurm"] = executor_kwargs.pop("slurm") == 1

    executor = make_port_executor(**executor_kwargs)
    executor.run()
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    main()
|
|
|