mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
184 lines
5.5 KiB
Python
184 lines
5.5 KiB
Python
|
|
|
|
import json
|
|
import logging
|
|
from pathlib import Path
|
|
import shutil
|
|
from huggingface_hub import snapshot_download
|
|
import tarfile
|
|
|
|
import tqdm
|
|
from examples.port_datasets.agibot_hdf5.port_agibot import port_agibot
|
|
from lerobot.common.constants import HF_LEROBOT_HOME
|
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
|
from lerobot.common.utils.utils import init_logging
|
|
from huggingface_hub import HfApi, HfFileSystem
|
|
|
|
# Hub repository id of the raw dataset; `download` passes it to
# `snapshot_download` with repo_type="dataset".
RAW_REPO_ID = "agibot-world/AgiBotWorld-Alpha"
|
|
|
|
|
|
def download(raw_dir, allow_patterns=None, ignore_patterns=None):
    """Fetch files of the raw dataset repository (`RAW_REPO_ID`) into `raw_dir`.

    Args:
        raw_dir: Local destination directory; converted to `str` before being
            handed to `snapshot_download` as `local_dir`.
        allow_patterns: Forwarded unchanged to `snapshot_download`.
        ignore_patterns: Forwarded unchanged to `snapshot_download`.
    """
    snapshot_options = {
        "repo_type": "dataset",
        "local_dir": str(raw_dir),
        "allow_patterns": allow_patterns,
        "ignore_patterns": ignore_patterns,
    }
    snapshot_download(RAW_REPO_ID, **snapshot_options)
|
|
|
|
def download_proprio_stats(raw_dir):
|
|
proprio_stats_dir = raw_dir / "proprio_stats"
|
|
|
|
if proprio_stats_dir.exists():
|
|
logging.info("Skipping download proprio stats")
|
|
return
|
|
|
|
download(raw_dir, allow_patterns="proprio_stats/*.tar")
|
|
|
|
for path in proprio_stats_dir.glob("*.tar"):
|
|
logging.info(f"Untar-ing {path}...")
|
|
with tarfile.open(path, 'r') as tar:
|
|
tar.extractall(path=proprio_stats_dir)
|
|
|
|
logging.info(f"Deleting {path}...")
|
|
path.unlink()
|
|
|
|
|
|
def download_parameters(raw_dir):
|
|
params_dir = raw_dir / "parameters"
|
|
|
|
if params_dir.exists():
|
|
logging.info("Skipping download parameters")
|
|
return
|
|
|
|
download(raw_dir, allow_patterns="parameters/*.tar")
|
|
|
|
for path in params_dir.glob("*.tar"):
|
|
logging.info(f"Untar-ing {path}...")
|
|
with tarfile.open(path, 'r') as tar:
|
|
tar.extractall(path=params_dir)
|
|
|
|
logging.info(f"Deleting {path}...")
|
|
path.unlink()
|
|
|
|
|
|
def get_observations_files(raw_dir, raw_repo_id):
    """Return the observation archives of a dataset repo and their sizes.

    Results are cached in `raw_dir` as `observations_files.json` and
    `observations_sizes.json`; when both exist they are loaded and returned
    without contacting the Hub.

    Args:
        raw_dir: `pathlib.Path` of the local directory holding the cache files.
        raw_repo_id: Hub dataset repository id to list.

    Returns:
        A `(files, sizes)` pair of lists of equal length, sorted by ascending
        size. `sizes` holds the reported byte size divided by 1000**3.
        Both lists are empty when no repository file contains "observations/".
    """
    files_json_path = raw_dir / "observations_files.json"
    sizes_json_path = raw_dir / "observations_sizes.json"
    if files_json_path.exists() and sizes_json_path.exists():
        with open(files_json_path) as f:
            files = json.load(f)
        with open(sizes_json_path) as f:
            sizes = json.load(f)
        return files, sizes

    api = HfApi()
    files = api.list_repo_files(repo_id=raw_repo_id, repo_type="dataset")
    files = [file for file in files if "observations/" in file]

    fs = HfFileSystem()
    sizes = []
    for file in tqdm.tqdm(files, desc="Downloading file sizes"):
        file_info = fs.info(f"datasets/{raw_repo_id}/{file}")
        sizes.append(file_info["size"] / 1000**3)

    # Sort ASC to start with smaller size files. Building the lists explicitly
    # (instead of `zip(*sorted(...))`) keeps an empty listing from raising and
    # returns lists, matching what the cached branch yields via `json.load`.
    ordered = sorted(zip(sizes, files))
    sizes = [size for size, _ in ordered]
    files = [file for _, file in ordered]

    with open(files_json_path, "w") as f:
        json.dump(files, f)
    with open(sizes_json_path, "w") as f:
        json.dump(sizes, f)
    return files, sizes
|
|
|
|
def display_observations_sizes(files, sizes):
    """Log the size of every file, the total per task, and the grand total.

    Args:
        files: Paths whose second "/"-separated component is an integer task
            id (e.g. "observations/327/x.tar").
        sizes: Size of each file, in the same order as `files`; logged with a
            "GB" suffix.

    Raises:
        ValueError: If the second path component of a file is not an integer.
        IndexError: If a file path contains no "/".
    """
    num_files = len(files)
    size_per_task = {}
    for i, (file, size) in enumerate(zip(files, sizes)):
        logging.info(f"{i}/{num_files}: {file} {size:.2f}GB")

        task = int(file.split("/")[1])
        size_per_task[task] = size_per_task.get(task, 0) + size

    for task, size in size_per_task.items():
        logging.info(f"{task} {size:.2f}GB")

    total_size = sum(size_per_task.values())
    logging.info(f"Total size: {total_size:.2f}GB")
|
|
|
|
|
|
def download_meta_data(raw_dir):
    """Fetch everything except the observation archives into `raw_dir`.

    Runs, in order: the task info JSON download, `download_parameters`
    and `download_proprio_stats`.
    """
    # Task data
    download(raw_dir, allow_patterns="task_info/task_*.json")

    # All camera parameters (~170 GB) first, then all proprio stats (~26 GB)
    for fetch_archives in (download_parameters, download_proprio_stats):
        fetch_archives(raw_dir)
|
|
|
|
def no_depth(tarinfo, path):
    """Tar extraction filter that skips members whose name contains "depth".

    Args:
        tarinfo: The archive member being considered.
        path: Extraction destination; required by the filter signature, unused.

    Returns:
        `None` to skip the member, otherwise `tarinfo` unchanged.
    """
    return None if "depth" in tarinfo.name else tarinfo
|
|
|
|
def main():
    """Download raw observation shards one at a time and port each of them.

    For every processed shard (one tar archive of episodes) this:
    removes any existing output directory for that shard, downloads the
    archive, extracts it without depth data if its episode directories are
    missing, calls `port_agibot`, then deletes the extracted episodes and
    the archive.
    """
    init_logging()

    repo_id = "cadene/agibot_alpha_v30"
    raw_dir = Path("/fsx/remi_cadene/data/AgiBotWorld-Alpha")
    download_meta_data(raw_dir)

    # Get list of tar files containing observation data (containing several episodes each)
    obs_files, obs_sizes = get_observations_files(raw_dir, RAW_REPO_ID)
    display_observations_sizes(obs_files, obs_sizes)

    # Shard names embed the total shard count, so record it before the
    # debug truncation below.
    num_shards = len(obs_files)

    # TODO: remove this debug cap so that every shard is ported.
    obs_files = obs_files[:2]

    # Iterate on each subset of episodes
    for shard_index, obs_file in enumerate(obs_files):
        shard_repo_id = f"{repo_id}_world_{num_shards}_rank_{shard_index}"

        # Start from a clean output directory for this shard.
        dataset_dir = HF_LEROBOT_HOME / shard_repo_id
        if dataset_dir.exists():
            shutil.rmtree(dataset_dir)

        # Download subset
        download(raw_dir, allow_patterns=obs_file)

        tar_path = raw_dir / obs_file
        with tarfile.open(tar_path, "r") as tar:
            extracted_files = tar.getnames()

        # The archive's parent directory name is the task index; top-level
        # members (no "/") are the episode directories.
        task_index = int(tar_path.parent.name)
        episode_names = [int(p) for p in extracted_files if "/" not in p]
        episode_dirs = [tar_path.parent / f"{ep_name}" for ep_name in episode_names]

        # Untar if needed
        if not all(ep_dir.exists() for ep_dir in episode_dirs):
            logging.info(f"Untar-ing {tar_path}...")
            with tarfile.open(tar_path, "r") as tar:
                tar.extractall(path=tar_path.parent, filter=no_depth)

        port_agibot(raw_dir, shard_repo_id, task_index, episode_names, push_to_hub=False)

        # Free disk space before moving on to the next shard.
        for ep_dir in episode_dirs:
            shutil.rmtree(ep_dir)

        tar_path.unlink()
|
|
|
|
|
|
|
|
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()