diff --git a/agibot2lerobot/agibot_h5.py b/agibot2lerobot/agibot_h5.py index 1b4c55f..1ae58a2 100644 --- a/agibot2lerobot/agibot_h5.py +++ b/agibot2lerobot/agibot_h5.py @@ -6,48 +6,113 @@ from concurrent.futures import ( as_completed, ) from pathlib import Path -from typing import Callable import numpy as np import ray import torch -from agibot_utils.agibot_utils import get_task_instruction, load_local_dataset +from agibot_utils.agibot_utils import get_task_info, load_local_dataset from agibot_utils.config import AgiBotWorld_TASK_TYPE from agibot_utils.lerobot_utils import compute_episode_stats, generate_features_from_config -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.compute_stats import aggregate_stats +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata from lerobot.common.datasets.utils import ( check_timestamps_sync, get_episode_data_index, validate_episode_buffer, validate_frame, + write_episode, + write_episode_stats, + write_info, ) +from lerobot.common.datasets.video_utils import get_safe_default_codec +from lerobot.common.robot_devices.robots.utils import Robot from ray.runtime_env import RuntimeEnv -class AgiBotDataset(LeRobotDataset): - def __init__( +class AgiBotDatasetMetadata(LeRobotDatasetMetadata): + def save_episode( self, + episode_index: int, + episode_length: int, + episode_tasks: list[str], + episode_stats: dict[str, dict], + action_config: list[dict], + ) -> None: + self.info["total_episodes"] += 1 + self.info["total_frames"] += episode_length + + chunk = self.get_episode_chunk(episode_index) + if chunk >= self.total_chunks: + self.info["total_chunks"] += 1 + + self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"} + self.info["total_videos"] += len(self.video_keys) + if len(self.video_keys) > 0: + self.update_video_info() + + write_info(self.info, self.root) + + episode_dict = { + "episode_index": episode_index, + "tasks": episode_tasks, + "length": episode_length, + "action_config": action_config, + } + self.episodes[episode_index] = episode_dict + write_episode(episode_dict, self.root) + + self.episodes_stats[episode_index] = episode_stats + self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats + write_episode_stats(episode_index, episode_stats, self.root) + + +class AgiBotDataset(LeRobotDataset): + @classmethod + def create( + cls, repo_id: str, + fps: int, root: str | Path | None = None, - episodes: list[int] | None = None, - image_transforms: Callable | None = None, - delta_timestamps: dict[list[float]] | None = None, + robot: Robot | None = None, + robot_type: str | None = None, + features: dict | None = None, + use_videos: bool = True, tolerance_s: float = 1e-4, - download_videos: bool = True, - local_files_only: bool = False, + image_writer_processes: int = 0, + image_writer_threads: int = 0, video_backend: str | None = None, - ): - super().__init__( + ) -> "LeRobotDataset": + """Create a LeRobot Dataset from scratch in order to record data.""" + obj = cls.__new__(cls) + obj.meta = AgiBotDatasetMetadata.create( repo_id=repo_id, + fps=fps, root=root, - episodes=episodes, - image_transforms=image_transforms, - delta_timestamps=delta_timestamps, - tolerance_s=tolerance_s, - download_videos=download_videos, - local_files_only=local_files_only, - video_backend=video_backend, + robot=robot, + robot_type=robot_type, + features=features, + use_videos=use_videos, ) + obj.repo_id = obj.meta.repo_id + obj.root = obj.meta.root + obj.revision = None + obj.tolerance_s = tolerance_s + obj.image_writer = None + + if image_writer_processes or image_writer_threads: + obj.start_image_writer(image_writer_processes, image_writer_threads) + + # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer + obj.episode_buffer = obj.create_episode_buffer() + + obj.episodes = None + obj.hf_dataset = obj.create_hf_dataset() + obj.image_transforms = None + obj.delta_timestamps = None + obj.delta_indices = None + obj.episode_data_index = None + obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec() + return obj def add_frame(self, frame: dict) -> None: """ @@ -88,7 +153,7 @@ class AgiBotDataset(LeRobotDataset): self.episode_buffer["size"] += 1 - def save_episode(self, episode_data: dict | None = None, videos: dict | None = None) -> None: + def save_episode(self, videos: dict, action_config: list, episode_data: dict | None = None) -> None: """ This will save to disk the current episode in self.episode_buffer. @@ -138,7 +203,8 @@ class AgiBotDataset(LeRobotDataset): self._save_episode_table(episode_buffer, episode_index) # `meta.save_episode` be executed after encoding the videos - self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats) + # add action_config to current episode + self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, action_config) ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index]) ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()} @@ -165,8 +231,13 @@ def save_as_lerobot_dataset(agibot_world_config, task: tuple[Path, Path], num_th json_file, local_dir = task print(f"processing {json_file.stem}, saving to {local_dir}") src_path = json_file.parent.parent - task_name = get_task_instruction(json_file) + task_info = get_task_info(json_file) + task_name = task_info[0]["task_name"] + task_init_scene = task_info[0]["init_scene_text"] + task_instruction = f"{task_name}.{task_init_scene}" task_id = json_file.stem.split("_")[-1] + task_info = {episode["episode_id"]: episode for episode in task_info} + features = generate_features_from_config(agibot_world_config) if local_dir.exists(): @@ -185,27 +256,31 @@ def save_as_lerobot_dataset(agibot_world_config, task: tuple[Path, Path], num_th all_subdir = [f.as_posix() for f in src_path.glob(f"observations/{task_id}/*") if f.is_dir()] - all_subdir_eids = [int(Path(path).name) for path in all_subdir] + all_subdir_eids = sorted([int(Path(path).name) for path in all_subdir]) if debug or not save_depth: for eid in all_subdir_eids: + if eid not in task_info: + print(f"{json_file.stem}, episode_{eid} not in task_info.json, skipping...") + continue try: + action_config = task_info[eid]["label_info"]["action_config"] raw_dataset = load_local_dataset( eid, src_path=src_path, task_id=task_id, - task_name=task_name, + task_instruction=task_instruction, save_depth=save_depth, AgiBotWorld_CONFIG=agibot_world_config, ) - frames, videos = raw_dataset + _, frames, videos = raw_dataset if not all([video_path.exists() for video_path in videos.values()]): - print(f"{json_file.stem}, episode_{eid}: some of the videos does not exist, skipping") + print(f"{json_file.stem}, episode_{eid}: some of the videos does not exist, skipping...") continue for frame_data in frames: dataset.add_frame(frame_data) - dataset.save_episode(videos=videos) + dataset.save_episode(videos=videos, action_config=action_config) except Exception as e: raise Exception(f"{json_file.stem}, {eid}") from e gc.collect() @@ -213,24 +288,28 @@ def save_as_lerobot_dataset(agibot_world_config, task: tuple[Path, Path], num_th else: with ThreadPoolExecutor(max_workers=num_threads) as executor: futures = [] - for episode_id in all_subdir_eids: + for eid in all_subdir_eids: + if eid not in task_info: + print(f"{json_file.stem}, episode_{eid} not in task_info.json, skipping...") + continue futures.append( executor.submit( load_local_dataset, - episode_id, + eid, src_path=src_path, task_id=task_id, - task_name=task_name, + task_instruction=task_instruction, save_depth=save_depth, AgiBotWorld_CONFIG=agibot_world_config, ) ) for raw_dataset in as_completed(futures): - frames, videos = raw_dataset.result() + eid, frames, videos = raw_dataset.result() + action_config = task_info[eid]["label_info"]["action_config"] for frame_data in frames: dataset.add_frame(frame_data) - dataset.save_episode(videos=videos) + dataset.save_episode(videos=videos, action_config=action_config) gc.collect() diff --git a/agibot2lerobot/agibot_utils/agibot_utils.py b/agibot2lerobot/agibot_utils/agibot_utils.py index 85d63ce..f425175 100644 --- a/agibot2lerobot/agibot_utils/agibot_utils.py +++ b/agibot2lerobot/agibot_utils/agibot_utils.py @@ -6,14 +6,11 @@ import numpy as np from PIL import Image -def get_task_instruction(task_json_path: str) -> dict: - """Get task language instruction""" +def get_task_info(task_json_path: str) -> dict: with open(task_json_path, "r") as f: - task_info = json.load(f) - task_name = task_info[0]["task_name"] - task_init_scene = task_info[0]["init_scene_text"] - task_instruction = f"{task_name}.{task_init_scene}" - return task_instruction + task_info: list = json.load(f) + task_info.sort(key=lambda episode: episode["episode_id"]) + return task_info def load_depths(root_dir: str, camera_name: str): @@ -23,7 +20,7 @@ def load_depths(root_dir: str, camera_name: str): def load_local_dataset( - episode_id: int, src_path: str, task_id: int, task_name: str, save_depth: bool, AgiBotWorld_CONFIG: dict + episode_id: int, src_path: str, task_id: int, task_instruction: str, save_depth: bool, AgiBotWorld_CONFIG: dict ) -> tuple[list, dict]: """Load local dataset and return a dict with observations and actions""" ob_dir = Path(src_path) / f"observations/{task_id}/{episode_id}" @@ -79,7 +76,7 @@ def load_local_dataset( ) for key, value in action.items() }, - "task": task_name, + "task": task_instruction, } for i in range(num_frames) ] @@ -91,4 +88,4 @@ def load_local_dataset( for key in AgiBotWorld_CONFIG["images"] if "depth" not in key } - return frames, videos + return episode_id, frames, videos