🐛 fix agibot2lerobot bug when action_config is [] (#91)

Co-authored-by: Brucexkm <24210860082@m.fudan.edu.cn>
This commit is contained in:
Qizhi Chen
2026-03-19 12:02:11 +08:00
committed by GitHub
parent 8dd339f545
commit 8cc8f342a4
+64 -7
View File
@@ -1,21 +1,78 @@
import argparse
import gc
import inspect
import shutil
import tempfile
from pathlib import Path
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import ray
import torch
from agibot_utils.agibot_utils import get_task_info, load_local_dataset
from agibot_utils.config import AgiBotWorld_TASK_TYPE
from agibot_utils.lerobot_utils import compute_episode_stats, generate_features_from_config
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import validate_episode_buffer, validate_frame
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import DEFAULT_EPISODES_PATH, validate_episode_buffer, validate_frame
from ray.runtime_env import RuntimeEnv
class AgiBotDatasetMetadata(LeRobotDatasetMetadata):
    def _flush_metadata_buffer(self) -> None:
        """Flush buffered per-episode metadata to a parquet file.

        Merges every dict in ``self.metadata_buffer`` column-wise, converts
        numpy arrays to plain Python lists (``pa.Table.from_pydict`` does not
        accept numpy arrays), appends the table through a lazily created
        ``ParquetWriter``, then records the last episode and empties the
        buffer.  No-op when the buffer is absent or empty.
        """
        if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0:
            return

        buffered = self.metadata_buffer

        # Merge the buffered episode dicts into one column-oriented dict.
        columns: dict = {}
        for record in buffered:
            for key, value in record.items():
                # Values arrive either bare or wrapped in a one-element list;
                # numpy arrays must be serialized before handing to PyArrow.
                item = value[0] if isinstance(value, list) else value
                if isinstance(item, np.ndarray):
                    item = item.tolist()
                columns.setdefault(key, []).append(item)

        # The first buffered episode determines which chunk/file receives the data.
        head = buffered[0]
        chunk_idx = head["meta/episodes/chunk_index"][0]
        file_idx = head["meta/episodes/file_index"][0]

        table = pa.Table.from_pydict(
            columns, schema=self.writer.schema if self.writer else None
        )
        if not self.writer:
            # Open the parquet writer lazily on the first flush.
            target = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx))
            target.parent.mkdir(parents=True, exist_ok=True)
            self.writer = pq.ParquetWriter(target, schema=table.schema, compression="snappy", use_dictionary=True)
        self.writer.write_table(table)

        self.latest_episode = buffered[-1]
        self.metadata_buffer.clear()
class AgiBotDataset(LeRobotDataset):
@classmethod
def create(cls, *args, **kwargs) -> "AgiBotDataset":
sig = inspect.signature(super().create)
bound = sig.bind_partial(*args, **kwargs)
bound.apply_defaults()
params = bound.arguments
obj = super().create(*args, **kwargs)
shutil.rmtree(params["root"], ignore_errors=True)
obj.meta: AgiBotDatasetMetadata = AgiBotDatasetMetadata.create(
repo_id=params["repo_id"],
fps=params["fps"],
robot_type=params.get("robot_type"),
features=params["features"],
root=params.get("root"),
use_videos=params.get("use_videos", True),
metadata_buffer_size=params.get("metadata_buffer_size", 10),
)
return obj
def add_frame(self, frame: dict) -> None:
"""
This function only adds the frame to the episode_buffer. Apart from images — which are written in a
@@ -191,11 +248,13 @@ def save_as_lerobot_dataset(agibot_world_config, task: tuple[Path, Path], save_d
dataset.save_episode(videos=videos, action_config=action_config)
except Exception as e:
print(f"{json_file.stem}, episode_{eid}: there are some corrupted mp4s\nException details: {str(e)}")
dataset.episode_buffer = None
dataset.clear_episode_buffer(delete_images=False)
continue
gc.collect()
print(f"process done for {json_file.stem}, episode_id {eid}, len {len(frames)}")
dataset.finalize()
def main(
src_path: str,
@@ -247,8 +306,6 @@ def main(
with open("output.txt", "a") as f:
f.write(f"{task}, exception details: {str(e)}\n")
ray.shutdown()
if __name__ == "__main__":
parser = argparse.ArgumentParser()