From 8afce7fa23cbf35c1f81df2a2007ae0478922d97 Mon Sep 17 00:00:00 2001 From: Tavish Date: Mon, 22 Jun 2026 13:12:58 +0800 Subject: [PATCH] Refactor AgiBot converter onto generic pipeline Co-authored-by: Codex --- agibot2lerobot/README.md | 18 +- agibot2lerobot/agibot_h5.py | 456 ++++++++++++++------ agibot2lerobot/agibot_utils/agibot_utils.py | 88 +++- 3 files changed, 414 insertions(+), 148 deletions(-) diff --git a/agibot2lerobot/README.md b/agibot2lerobot/README.md index 5c0cdf5..cc0675c 100644 --- a/agibot2lerobot/README.md +++ b/agibot2lerobot/README.md @@ -30,6 +30,7 @@ In this dataset, we have made several key improvements: - **Preservation of Agibot’s Original Information** 🧠: We have preserved as much of Agibot’s original information as possible, with field names strictly adhering to the original dataset’s naming conventions to ensure compatibility and consistency. - **State and Action as Dictionaries** 🧾: The traditional one-dimensional state and action have been transformed into dictionaries, allowing for greater flexibility in designing custom states and actions, enabling modular and scalable handling. +- **Generic Conversion Pipeline**: Conversion now uses the shared `generic_converter` execution flow for local/Ray DataTrove execution, resumable logs, temporary per-task datasets, final aggregation, and optional Hub upload. Dataset Structure of `meta/info.json`: @@ -116,10 +117,11 @@ Dataset Structure of `meta/info.json`: Follow instructions in [official repo](https://github.com/huggingface/lerobot?tab=readme-ov-file#installation). 2. Install others: - We use ray for parallel conversion, significantly speeding up data processing tasks by distributing the workload across multiple cores or nodes (if any). + We use DataTrove for conversion. Install the Ray extra if you want distributed execution across multiple cores or nodes. 
```bash pip install h5py - pip install -U "ray[default]" + pip install -U datatrove + pip install -U "datatrove[ray]" # optional, for --executor ray ``` ## Get started @@ -161,12 +163,20 @@ git clone https://github.com/Tavish9/any4lerobot.git There are three types of end-effector, `gripper`, `dexhand` and `tactile`, specify the type before converting +`--output-path` is the final aggregated LeRobot dataset root. Temporary per-task +datasets are written next to it under `_temp` and removed after +aggregation. + +`--episodes-per-task` controls AgiBot conversion granularity. The default is `10`; +pass `1` to create one temporary dataset per raw episode for better Ray load balancing. + ```bash -python convert.py \ +python agibot_h5.py \ --src-path /path/to/AgiBotWorld-Beta \ --output-path /path/to/local \ --eef-type gripper \ - --num-cpus-per-task 3 + --cpus-per-task 3 \ + --episodes-per-task 1 ``` ### Execute the script: diff --git a/agibot2lerobot/agibot_h5.py b/agibot2lerobot/agibot_h5.py index f067e44..e000e1a 100644 --- a/agibot2lerobot/agibot_h5.py +++ b/agibot2lerobot/agibot_h5.py @@ -1,22 +1,47 @@ import argparse +import importlib import inspect import shutil +import sys import tempfile +from collections.abc import Sequence from pathlib import Path +from typing import Any import numpy as np import pyarrow as pa import pyarrow.parquet as pq -import ray import torch -from agibot_utils.agibot_utils import get_task_info, load_local_dataset -from agibot_utils.config import AgiBotWorld_TASK_TYPE -from agibot_utils.lerobot_utils import compute_episode_stats, generate_features_from_config from lerobot.datasets import LeRobotDataset, LeRobotDatasetMetadata from lerobot.datasets.dataset_writer import DatasetWriter -from lerobot.datasets.feature_utils import get_hf_features_from_features, validate_episode_buffer, validate_frame +from lerobot.datasets.feature_utils import ( + get_hf_features_from_features, + validate_episode_buffer, + validate_frame, +) from lerobot.datasets.utils import 
DEFAULT_EPISODES_PATH -from ray.runtime_env import RuntimeEnv + +AGIBOT_DIR = Path(__file__).resolve().parent +REPO_ROOT = AGIBOT_DIR.parent +for import_path in (REPO_ROOT, AGIBOT_DIR): + import_path_str = str(import_path) + if import_path_str not in sys.path: + sys.path.insert(0, import_path_str) + +from agibot_utils.agibot_utils import ( # noqa: E402 + get_episode_ids, + get_task_id, + get_task_info, + has_episode_videos, + load_local_dataset, +) +from agibot_utils.config import AgiBotWorld_TASK_TYPE # noqa: E402 +from agibot_utils.lerobot_utils import ( # noqa: E402 + compute_episode_stats, + generate_features_from_config, +) + +from generic_converter import BaseAdapter, ConversionTask, run_converter # noqa: E402 class AgiBotDatasetMetadata(LeRobotDatasetMetadata): @@ -33,7 +58,9 @@ class AgiBotDatasetMetadata(LeRobotDatasetMetadata): # Extract value and serialize numpy arrays # because PyArrow's from_pydict function doesn't support numpy arrays val = value[0] if isinstance(value, list) else value - combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val) + combined_dict[key].append( + val.tolist() if isinstance(val, np.ndarray) else val + ) first_ep = self._metadata_buffer[0] chunk_idx = first_ep["meta/episodes/chunk_index"][0] @@ -43,9 +70,16 @@ class AgiBotDatasetMetadata(LeRobotDatasetMetadata): table = pa.Table.from_pydict(combined_dict, schema=schema) if not self._pq_writer: - path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)) + path = Path( + self.root + / DEFAULT_EPISODES_PATH.format( + chunk_index=chunk_idx, file_index=file_idx + ) + ) path.parent.mkdir(parents=True, exist_ok=True) - self._pq_writer = pq.ParquetWriter(path, schema=table.schema, compression="snappy", use_dictionary=True) + self._pq_writer = pq.ParquetWriter( + path, schema=table.schema, compression="snappy", use_dictionary=True + ) self._pq_writer.write_table(table) @@ -75,7 +109,9 @@ class 
AgiBotDatasetWriter(DatasetWriter): frame[name] = frame[name].numpy() features = { - key: value for key, value in self._meta.features.items() if key in self.hf_features + key: value + for key, value in self._meta.features.items() + if key in self.hf_features } # remove video keys validate_frame(frame, features) @@ -101,12 +137,20 @@ class AgiBotDatasetWriter(DatasetWriter): self.episode_buffer["size"] += 1 def save_episode( - self, videos: dict, action_config: list, episode_data: dict | None = None, parallel_encoding: bool = True + self, + videos: dict, + action_config: list, + episode_data: dict | None = None, + parallel_encoding: bool = True, ) -> None: """Save the current episode in self.episode_buffer to disk.""" - episode_buffer = episode_data if episode_data is not None else self.episode_buffer + episode_buffer = ( + episode_data if episode_data is not None else self.episode_buffer + ) - validate_episode_buffer(episode_buffer, self._meta.total_episodes, self._meta.features) + validate_episode_buffer( + episode_buffer, self._meta.total_episodes, self._meta.features + ) # size and task are special cases that won't be added to hf_dataset episode_length = episode_buffer.pop("size") @@ -114,19 +158,25 @@ class AgiBotDatasetWriter(DatasetWriter): episode_tasks = list(set(tasks)) episode_index = episode_buffer["episode_index"] - episode_buffer["index"] = np.arange(self._meta.total_frames, self._meta.total_frames + episode_length) + episode_buffer["index"] = np.arange( + self._meta.total_frames, self._meta.total_frames + episode_length + ) episode_buffer["episode_index"] = np.full((episode_length,), episode_index) # Update tasks and task indices with new tasks if any self._meta.save_episode_tasks(episode_tasks) # Given tasks in natural language, find their corresponding task indices - episode_buffer["task_index"] = np.array([self._meta.get_task_index(task) for task in tasks]) + episode_buffer["task_index"] = np.array( + [self._meta.get_task_index(task) for task in 
tasks] + ) for key, ft in self._meta.features.items(): # index, episode_index, task_index are already processed above, and image and video # are processed separately by storing image path and frame info as meta data - if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["video"]: + if key in ["index", "episode_index", "task_index"] or ft["dtype"] in [ + "video" + ]: continue episode_buffer[key] = np.stack(episode_buffer[key]).squeeze() @@ -145,7 +195,9 @@ class AgiBotDatasetWriter(DatasetWriter): ep_metadata.update(self._save_episode_video(video_key, episode_index)) ep_metadata.update({"action_config": action_config}) - self._meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata) + self._meta.save_episode( + episode_index, episode_length, episode_tasks, ep_stats, ep_metadata + ) if has_video_keys and use_batched_encoding: self._episodes_since_last_encoding += 1 @@ -158,13 +210,18 @@ class AgiBotDatasetWriter(DatasetWriter): if not episode_data: self.clear_episode_buffer(delete_images=len(self._meta.image_keys) > 0) - def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path: + def _encode_temporary_episode_video( + self, video_key: str, episode_index: int + ) -> Path: """ Use ffmpeg to convert frames stored as png into mp4 videos. Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, since video encoding with ffmpeg is already using multithreading. 
""" - temp_path = Path(tempfile.mkdtemp(dir=self._root)) / f"{video_key}_{episode_index:03d}.mp4" + temp_path = ( + Path(tempfile.mkdtemp(dir=self._root)) + / f"{video_key}_{episode_index:03d}.mp4" + ) shutil.copy(self.current_videos[video_key], temp_path) return temp_path @@ -188,153 +245,300 @@ class AgiBotDataset(LeRobotDataset): root=params["root"], use_videos=params["use_videos"], metadata_buffer_size=params["metadata_buffer_size"], + video_files_size_in_mb=params["video_files_size_in_mb"], + data_files_size_in_mb=params["data_files_size_in_mb"], ) obj.writer: AgiBotDatasetWriter = AgiBotDatasetWriter( meta=obj.meta, root=obj.root, - vcodec=obj._vcodec, - encoder_threads=obj._encoder_threads, - batch_encoding_size=obj._batch_encoding_size, + camera_encoder=obj.writer._camera_encoder, + encoder_threads=obj.writer._encoder_threads, + batch_encoding_size=obj.writer._batch_encoding_size, + streaming_encoder=obj.writer._streaming_encoder, ) return obj def save_episode( - self, videos: dict, action_config: list, episode_data: dict | None = None, parallel_encoding: bool = True + self, + videos: dict, + action_config: list, + episode_data: dict | None = None, + parallel_encoding: bool = True, ) -> None: self._require_writer("save_episode") self.writer.save_episode(videos, action_config, episode_data, parallel_encoding) -def get_all_tasks(src_path: Path, output_path: Path): - json_files = src_path.glob("task_info/*.json") - for json_file in json_files: - local_dir = output_path / "agibotworld" / json_file.stem - yield (json_file, local_dir.resolve()) +class AgiBotAdapter(BaseAdapter): + dataset_type = "agibot" + fps = 30 + robot_type = "a2d" + tags = ("agibot-world", "a2d") + def __init__( + self, + src_path: Path, + output_path: Path, + eef_type: str, + task_ids: Sequence[str], + save_depth: bool, + episodes_per_task: int, + ): + super().__init__(output_path) + if episodes_per_task < 1: + raise ValueError("--episodes-per-task must be >= 1") + self.src_path = 
src_path.expanduser().resolve() + self.eef_type = eef_type + self.task_ids = set(task_ids) + self.save_depth = save_depth + self.episodes_per_task = episodes_per_task + self.agibot_world_config = AgiBotWorld_TASK_TYPE[eef_type]["task_config"] + self.type_task_ids = set(AgiBotWorld_TASK_TYPE[eef_type]["task_ids"]) + self.features = generate_features_from_config(self.agibot_world_config) + if not save_depth: + self.features.pop("observation.images.head_depth", None) -def save_as_lerobot_dataset(agibot_world_config, task: tuple[Path, Path], save_depth): - json_file, local_dir = task - print(f"processing {json_file.stem}, saving to {local_dir}") - src_path = json_file.parent.parent - task_info = get_task_info(json_file) - task_name = task_info[0]["task_name"] - task_init_scene = task_info[0]["init_scene_text"] - task_instruction = f"{task_name} | {task_init_scene}" - task_id = json_file.stem.split("_")[-1] - task_info = {episode["episode_id"]: episode for episode in task_info} + def load_tasks(self) -> list[ConversionTask]: + tasks = [] + for json_file in sorted(self.src_path.glob("task_info/*.json")): + if not self._include_task(json_file.stem): + continue + _, _, _, task_info = self._load_task_context(json_file) + episode_ids = sorted(task_info) - features = generate_features_from_config(agibot_world_config) + for chunk_index, chunk_episode_ids in enumerate( + self._chunk_episode_ids(episode_ids) + ): + task_name = self._format_task_name( + json_file.stem, chunk_index, chunk_episode_ids + ) + tasks.append( + ConversionTask( + input_path=json_file.resolve(), + output_path=( + self.temp_output_path / "agibotworld" / task_name + ).resolve(), + local_repo_id=task_name, + metadata={"episode_ids": tuple(chunk_episode_ids)}, + ) + ) + return tasks - if local_dir.exists(): - shutil.rmtree(local_dir) - - if not save_depth: - features.pop("observation.images.head_depth") - - dataset: AgiBotDataset = AgiBotDataset.create( - repo_id=json_file.stem, - root=local_dir, - fps=30, - 
robot_type="a2d", - features=features, - ) - - all_subdir = [f.as_posix() for f in src_path.glob(f"observations/{task_id}/*") if f.is_dir()] - - all_subdir_eids = sorted([int(Path(path).name) for path in all_subdir]) - - for eid in all_subdir_eids: - if eid not in task_info: - print(f"{json_file.stem}, episode_{eid} not in task_info.json, skipping...") - continue - action_config = task_info[eid]["label_info"]["action_config"] - raw_dataset = load_local_dataset( - eid, - src_path=src_path, - task_id=task_id, - save_depth=save_depth, - AgiBotWorld_CONFIG=agibot_world_config, + def load_subset(self, task: ConversionTask): + json_file = task.input_path + print(f"processing {json_file.stem}, saving to {task.output_path}") + src_path, task_id, task_instruction, task_info = self._load_task_context( + json_file ) - _, frames, videos = raw_dataset - if not all([video_path.exists() for video_path in videos.values()]): - print(f"{json_file.stem}, episode_{eid}: some of the videos does not exist, skipping...") - continue - for frame_data in frames: - frame_data["task"] = task_instruction + task_episode_ids = task.metadata.get("episode_ids") + if task_episode_ids is None: + task_episode_ids = get_episode_ids(src_path, task_id) + + for eid in task_episode_ids: + if not self._is_convertible_episode( + json_file.stem, src_path, task_id, eid, task_info + ): + continue + action_config = task_info[eid]["label_info"]["action_config"] + raw_dataset = load_local_dataset( + eid, + src_path=src_path, + task_id=task_id, + save_depth=self.save_depth, + AgiBotWorld_CONFIG=self.agibot_world_config, + ) + if raw_dataset is None: + continue + _, frames, videos = raw_dataset + + for frame_data in frames: + frame_data["task"] = task_instruction + + yield { + "episode_id": eid, + "frames": frames, + "videos": videos, + "action_config": action_config, + } + + def create_dataset(self, task: ConversionTask) -> AgiBotDataset: + return AgiBotDataset.create( + repo_id=task.local_repo_id, + 
root=task.output_path, + fps=self.fps, + robot_type=self.robot_type, + features=self.features, + ) + + def save_episode( + self, + dataset: AgiBotDataset, + episode_data: dict[str, Any], + task: ConversionTask, + ) -> bool: + for frame_data in episode_data["frames"]: dataset.add_frame(frame_data) try: - dataset.save_episode(videos=videos, action_config=action_config) + dataset.save_episode( + videos=episode_data["videos"], + action_config=episode_data["action_config"], + ) except Exception as e: - print(f"{json_file.stem}, episode_{eid}: there are some corrupted mp4s\nException details: {str(e)}") + print( + f"{task.input_path.stem}, episode_{episode_data['episode_id']}: " + f"there are some corrupted mp4s\nException details: {str(e)}" + ) dataset.clear_episode_buffer(delete_images=False) - continue + return False + return True - print(f"process done for {json_file.stem}, episode_id {eid}, len {len(frames)}") + def get_episode_length(self, episode_data: dict[str, Any]) -> int: + return len(episode_data["frames"]) - dataset.finalize() + def _chunk_episode_ids(self, episode_ids: list[int]): + for start in range(0, len(episode_ids), self.episodes_per_task): + yield episode_ids[start : start + self.episodes_per_task] + + def _format_task_name( + self, task_name: str, chunk_index: int, episode_ids: Sequence[int] + ) -> str: + if len(episode_ids) == 1: + return f"{task_name}_episode_{episode_ids[0]}" + return ( + f"{task_name}_chunk_{chunk_index:06d}_episodes_" + f"{episode_ids[0]}_{episode_ids[-1]}" + ) + + def _load_task_context( + self, json_file: Path + ) -> tuple[Path, str, str, dict[int, dict[str, Any]]]: + src_path = json_file.parent.parent + task_id = get_task_id(json_file) + task_info = get_task_info(json_file) + task_name = task_info[0]["task_name"] + task_init_scene = task_info[0]["init_scene_text"] + task_instruction = f"{task_name} | {task_init_scene}" + task_info_by_episode = {episode["episode_id"]: episode for episode in task_info} + return src_path, 
task_id, task_instruction, task_info_by_episode + + def _is_convertible_episode( + self, + task_name: str, + src_path: Path, + task_id: str, + episode_id: int, + task_info: dict[int, dict[str, Any]], + ) -> bool: + if episode_id not in task_info: + print( + f"{task_name}, episode_{episode_id} not in task_info.json, skipping..." + ) + return False + if not has_episode_videos( + src_path, task_id, episode_id, self.agibot_world_config + ): + print( + f"{task_name}, episode_{episode_id}: " + "some of the videos does not exist, skipping..." + ) + return False + return True + + def _include_task(self, task_id: str) -> bool: + if self.task_ids and task_id not in self.task_ids: + return False + if self.eef_type == "gripper": + remaining_ids = set(AgiBotWorld_TASK_TYPE["dexhand"]["task_ids"]) + remaining_ids.update(AgiBotWorld_TASK_TYPE["tactile"]["task_ids"]) + return task_id not in remaining_ids + return task_id in self.type_task_ids def main( - src_path: str, - output_path: str, + src_path: Path, + output_path: Path, eef_type: str, - task_ids: list, + task_ids: list[str], + executor: str, cpus_per_task: int, + tasks_per_job: int, + workers: int, save_depth: bool, + episodes_per_task: int, + resume_dir: Path | None = None, debug: bool = False, + repo_id: str | None = None, + push_to_hub: bool = False, ): - tasks = get_all_tasks(src_path, output_path) - - agibot_world_config, type_task_ids = ( - AgiBotWorld_TASK_TYPE[eef_type]["task_config"], - AgiBotWorld_TASK_TYPE[eef_type]["task_ids"], + adapter = AgiBotAdapter( + src_path=src_path, + output_path=output_path, + eef_type=eef_type, + task_ids=task_ids, + save_depth=save_depth, + episodes_per_task=episodes_per_task, ) - if eef_type == "gripper": - remaining_ids = AgiBotWorld_TASK_TYPE["dexhand"]["task_ids"] + AgiBotWorld_TASK_TYPE["tactile"]["task_ids"] - tasks = filter(lambda task: task[0].stem not in remaining_ids, tasks) - else: - tasks = filter(lambda task: task[0].stem in type_task_ids, tasks) - - if task_ids: - tasks 
= filter(lambda task: task[0].stem in task_ids, tasks) - - if debug: - save_as_lerobot_dataset(agibot_world_config, next(tasks), save_depth) - else: - runtime_env = RuntimeEnv( - env_vars={"HDF5_USE_FILE_LOCKING": "FALSE", "HF_DATASETS_DISABLE_PROGRESS_BARS": "TRUE"} - ) - ray.init(runtime_env=runtime_env) - resources = ray.available_resources() - cpus = int(resources["CPU"]) - - print(f"Available CPUs: {cpus}, num_cpus_per_task: {cpus_per_task}") - - remote_task = ray.remote(save_as_lerobot_dataset).options(num_cpus=cpus_per_task) - futures = [] - for task in tasks: - futures.append((task[0].stem, remote_task.remote(agibot_world_config, task, save_depth))) - - for task, future in futures: - try: - ray.get(future) - except Exception as e: - print(f"Exception occurred for {task}") - with open("output.txt", "a") as f: - f.write(f"{task}, exception details: {str(e)}\n") + run_converter( + adapter=adapter, + executor=executor, + cpus_per_task=cpus_per_task, + tasks_per_job=tasks_per_job, + workers=workers, + resume_dir=resume_dir, + debug=debug, + local_repo_id=repo_id, + hub_repo_id=repo_id, + push_to_hub=push_to_hub, + extra_tags=[eef_type], + ) -if __name__ == "__main__": +def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--src-path", type=Path, required=True) parser.add_argument("--output-path", type=Path, required=True) - parser.add_argument("--eef-type", type=str, choices=["gripper", "dexhand", "tactile"], default="gripper") - parser.add_argument("--task-ids", type=str, nargs="+", help="task_327 task_351 ...", default=[]) - parser.add_argument("--cpus-per-task", type=int, default=3) + parser.add_argument( + "--eef-type", + type=str, + choices=["gripper", "dexhand", "tactile"], + default="gripper", + ) + parser.add_argument( + "--task-ids", type=str, nargs="+", help="task_327 task_351 ...", default=[] + ) + parser.add_argument("--executor", type=str, choices=["local", "ray"], default="ray") + parser.add_argument("--cpus-per-task", type=int, 
default=1) + parser.add_argument( + "--tasks-per-job", + type=int, + default=1, + help="number of concurrent tasks per job, only used for ray", + ) + parser.add_argument( + "--episodes-per-task", + type=int, + default=10, + help="number of AgiBot episodes grouped into one conversion task", + ) + parser.add_argument( + "--workers", type=int, default=-1, help="number of concurrent jobs to run" + ) + parser.add_argument("--resume-dir", type=Path, help="logs directory to resume") parser.add_argument("--save-depth", action="store_true") parser.add_argument("--debug", action="store_true") - args = parser.parse_args() + parser.add_argument("--repo-id", type=str, help="required when push-to-hub is True") + parser.add_argument("--push-to-hub", action="store_true", help="upload to hub") + return parser.parse_args() - main(**vars(args)) + +def cli(): + args = parse_args() + module_name = "agibot2lerobot.agibot_h5" + module = importlib.import_module(module_name) + module.main(**vars(args)) + + +if __name__ == "__main__": + cli() diff --git a/agibot2lerobot/agibot_utils/agibot_utils.py b/agibot2lerobot/agibot_utils/agibot_utils.py index 54fd3c5..d72f13c 100644 --- a/agibot2lerobot/agibot_utils/agibot_utils.py +++ b/agibot2lerobot/agibot_utils/agibot_utils.py @@ -6,22 +6,67 @@ import numpy as np from PIL import Image -def get_task_info(task_json_path: str) -> dict: +def get_task_info(task_json_path: str | Path) -> list[dict]: with open(task_json_path, "r") as f: task_info: list = json.load(f) task_info.sort(key=lambda episode: episode["episode_id"]) return task_info +def get_task_id(task_json_path: str | Path) -> str: + return Path(task_json_path).stem.split("_")[-1] + + +def get_episode_ids(src_path: str | Path, task_id: str | int) -> list[int]: + observations_dir = Path(src_path) / "observations" / str(task_id) + return sorted( + int(path.name) + for path in observations_dir.glob("*") + if path.is_dir() and path.name.isdigit() + ) + + +def get_episode_videos( + src_path: str | 
Path, + task_id: str | int, + episode_id: int, + agibot_world_config: dict, +) -> dict[str, Path]: + ob_dir = Path(src_path) / f"observations/{task_id}/{episode_id}" + return { + f"observation.images.{key}": ob_dir / "videos" / f"{key}_color.mp4" + if "sensor" not in key + else ob_dir / "tactile" / f"{key}.mp4" # HACK: handle tactile videos + for key in agibot_world_config["images"] + if "depth" not in key + } + + +def has_episode_videos( + src_path: str | Path, + task_id: str | int, + episode_id: int, + agibot_world_config: dict, +) -> bool: + videos = get_episode_videos(src_path, task_id, episode_id, agibot_world_config) + return all(video_path.exists() for video_path in videos.values()) + + def load_depths(root_dir: str, camera_name: str): cam_path = Path(root_dir) all_imgs = sorted(list(cam_path.glob(f"{camera_name}*"))) - return [np.array(Image.open(f)).astype(np.float32)[:, :, None] / 1000 for f in all_imgs] + return [ + np.array(Image.open(f)).astype(np.float32)[:, :, None] / 1000 for f in all_imgs + ] def load_local_dataset( - episode_id: int, src_path: str, task_id: int, save_depth: bool, AgiBotWorld_CONFIG: dict -) -> tuple[list, dict]: + episode_id: int, + src_path: str | Path, + task_id: str | int, + save_depth: bool, + AgiBotWorld_CONFIG: dict, +) -> tuple[int, list[dict], dict[str, Path]] | None: """Load local dataset and return a dict with observations and actions""" ob_dir = Path(src_path) / f"observations/{task_id}/{episode_id}" proprio_dir = Path(src_path) / f"proprio_stats/{task_id}/{episode_id}" @@ -30,9 +75,13 @@ def load_local_dataset( action = {} with h5py.File(proprio_dir / "proprio_stats.h5", "r") as f: for key in AgiBotWorld_CONFIG["states"]: - state[f"observation.states.{key}"] = np.array(f["state/" + key.replace(".", "/")], dtype=np.float32) + state[f"observation.states.{key}"] = np.array( + f["state/" + key.replace(".", "/")], dtype=np.float32 + ) for key in AgiBotWorld_CONFIG["actions"]: - action[f"actions.{key}"] = 
np.array(f["action/" + key.replace(".", "/")], dtype=np.float32) + action[f"actions.{key}"] = np.array( + f["action/" + key.replace(".", "/")], dtype=np.float32 + ) # HACK: agibot team forgot to pad or filter some of the values num_frames = len(next(iter(state.values()))) @@ -42,7 +91,10 @@ def load_local_dataset( elif len(action_value) < num_frames: state_key = action_key.replace("actions", "state").replace(".", "/") new_action_value = np.array(f[state_key], dtype=np.float32).copy() - action_index_key = "/".join(list(action_key.replace("actions", "action").split(".")[:-1]) + ["index"]) + action_index_key = "/".join( + list(action_key.replace("actions", "action").split(".")[:-1]) + + ["index"] + ) action_index = np.array(f[action_index_key]) # agibot lost end index, replace it with joint if not action_index.size: @@ -52,11 +104,13 @@ def load_local_dataset( action[action_key] = new_action_value elif len(action_value) > num_frames: print("corrupt data, skipping") - return episode_id, [], {"dummy_video": Path("/path/to/no_exist")} + return None if save_depth: depth_imgs = load_depths(ob_dir / "depth", "head_depth") - assert num_frames == len(depth_imgs), "Number of images and states are not equal" + assert num_frames == len(depth_imgs), ( + "Number of images and states are not equal" + ) state_key_prefix_len = len("observation.states.") action_key_prefix_len = len("actions.") @@ -68,7 +122,9 @@ def load_local_dataset( if value.size else np.zeros( AgiBotWorld_CONFIG["states"][key[state_key_prefix_len:]]["shape"], - dtype=AgiBotWorld_CONFIG["states"][key[state_key_prefix_len:]]["dtype"], + dtype=AgiBotWorld_CONFIG["states"][key[state_key_prefix_len:]][ + "dtype" + ], ) for key, value in state.items() }, @@ -77,7 +133,9 @@ def load_local_dataset( if value.size else np.zeros( AgiBotWorld_CONFIG["actions"][key[action_key_prefix_len:]]["shape"], - dtype=AgiBotWorld_CONFIG["actions"][key[action_key_prefix_len:]]["dtype"], + 
dtype=AgiBotWorld_CONFIG["actions"][key[action_key_prefix_len:]][ + "dtype" + ], ) for key, value in action.items() }, @@ -85,11 +143,5 @@ def load_local_dataset( for i in range(num_frames) ] - videos = { - f"observation.images.{key}": ob_dir / "videos" / f"{key}_color.mp4" - if "sensor" not in key - else ob_dir / "tactile" / f"{key}.mp4" # HACK: handle tactile videos - for key in AgiBotWorld_CONFIG["images"] - if "depth" not in key - } + videos = get_episode_videos(src_path, task_id, episode_id, AgiBotWorld_CONFIG) return episode_id, frames, videos