diff --git a/agibot2lerobot/agibot_h5.py b/agibot2lerobot/agibot_h5.py index b096244..53a1687 100644 --- a/agibot2lerobot/agibot_h5.py +++ b/agibot2lerobot/agibot_h5.py @@ -1,21 +1,78 @@ import argparse -import gc +import inspect import shutil import tempfile from pathlib import Path import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq import ray import torch from agibot_utils.agibot_utils import get_task_info, load_local_dataset from agibot_utils.config import AgiBotWorld_TASK_TYPE from agibot_utils.lerobot_utils import compute_episode_stats, generate_features_from_config -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.datasets.utils import validate_episode_buffer, validate_frame +from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata +from lerobot.datasets.utils import DEFAULT_EPISODES_PATH, validate_episode_buffer, validate_frame from ray.runtime_env import RuntimeEnv +class AgiBotDatasetMetadata(LeRobotDatasetMetadata): + def _flush_metadata_buffer(self) -> None: + """Write all buffered episode metadata to parquet file.""" + if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0: + return + + combined_dict = {} + for episode_dict in self.metadata_buffer: + for key, value in episode_dict.items(): + if key not in combined_dict: + combined_dict[key] = [] + # Extract value and serialize numpy arrays + # because PyArrow's from_pydict function doesn't support numpy arrays + val = value[0] if isinstance(value, list) else value + combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val) + + first_ep = self.metadata_buffer[0] + chunk_idx = first_ep["meta/episodes/chunk_index"][0] + file_idx = first_ep["meta/episodes/file_index"][0] + + schema = None if not self.writer else self.writer.schema + table = pa.Table.from_pydict(combined_dict, schema=schema) + + if not self.writer: + path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)) + path.parent.mkdir(parents=True, exist_ok=True) + self.writer = pq.ParquetWriter(path, schema=table.schema, compression="snappy", use_dictionary=True) + + self.writer.write_table(table) + + self.latest_episode = self.metadata_buffer[-1] + self.metadata_buffer.clear() + + class AgiBotDataset(LeRobotDataset): + @classmethod + def create(cls, *args, **kwargs) -> "AgiBotDataset": + sig = inspect.signature(super().create) + bound = sig.bind_partial(*args, **kwargs) + bound.apply_defaults() + params = bound.arguments + + obj = super().create(*args, **kwargs) + + shutil.rmtree(params["root"], ignore_errors=True) + obj.meta: AgiBotDatasetMetadata = AgiBotDatasetMetadata.create( + repo_id=params["repo_id"], + fps=params["fps"], + robot_type=params.get("robot_type"), + features=params["features"], + root=params.get("root"), + use_videos=params.get("use_videos", True), + metadata_buffer_size=params.get("metadata_buffer_size", 10), + ) + return obj + def add_frame(self, frame: dict) -> None: """ This function only adds the frame to the episode_buffer. Apart from images — which are written in a @@ -191,11 +248,13 @@ def save_as_lerobot_dataset(agibot_world_config, task: tuple[Path, Path], save_d dataset.save_episode(videos=videos, action_config=action_config) except Exception as e: print(f"{json_file.stem}, episode_{eid}: there are some corrupted mp4s\nException details: {str(e)}") - dataset.episode_buffer = None + dataset.clear_episode_buffer(delete_images=False) continue - gc.collect() + print(f"process done for {json_file.stem}, episode_id {eid}, len {len(frames)}") + dataset.finalize() + def main( src_path: str, @@ -247,8 +306,6 @@ def main( with open("output.txt", "a") as f: f.write(f"{task}, exception details: {str(e)}\n") - ray.shutdown() - if __name__ == "__main__": parser = argparse.ArgumentParser()