Mirror of https://github.com/Tavish9/any4lerobot.git, synced 2026-05-11 12:09:41 +00:00
🐛 fix agibot2lerobot bug when action_config is [] (#91)
Co-authored-by: Brucexkm <24210860082@m.fudan.edu.cn>
@@ -1,21 +1,78 @@
import argparse
import gc
import inspect
import shutil
import tempfile
from pathlib import Path

import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import ray
import torch
from agibot_utils.agibot_utils import get_task_info, load_local_dataset
from agibot_utils.config import AgiBotWorld_TASK_TYPE
from agibot_utils.lerobot_utils import compute_episode_stats, generate_features_from_config
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.datasets.utils import validate_episode_buffer, validate_frame
+from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
+from lerobot.datasets.utils import DEFAULT_EPISODES_PATH, validate_episode_buffer, validate_frame
from ray.runtime_env import RuntimeEnv


class AgiBotDatasetMetadata(LeRobotDatasetMetadata):
    def _flush_metadata_buffer(self) -> None:
        """Write all buffered episode metadata to a parquet file."""
        if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0:
            return

        combined_dict = {}
        for episode_dict in self.metadata_buffer:
            for key, value in episode_dict.items():
                if key not in combined_dict:
                    combined_dict[key] = []
                # Extract the value and serialize numpy arrays to plain lists,
                # because PyArrow's from_pydict does not accept numpy arrays.
                val = value[0] if isinstance(value, list) else value
                combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val)

        first_ep = self.metadata_buffer[0]
        chunk_idx = first_ep["meta/episodes/chunk_index"][0]
        file_idx = first_ep["meta/episodes/file_index"][0]

        schema = None if not self.writer else self.writer.schema
        table = pa.Table.from_pydict(combined_dict, schema=schema)

        if not self.writer:
            path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx))
            path.parent.mkdir(parents=True, exist_ok=True)
            self.writer = pq.ParquetWriter(path, schema=table.schema, compression="snappy", use_dictionary=True)

        self.writer.write_table(table)

        self.latest_episode = self.metadata_buffer[-1]
        self.metadata_buffer.clear()

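For readers unfamiliar with the pattern above, here is a minimal, self-contained sketch of the buffered flush: normalize numpy values to plain lists before pa.Table.from_pydict, then append row groups through a single ParquetWriter. The buffer contents and the output file name are made up for illustration; only the list conversion and the append-style writer usage mirror the method above.

import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq

# Hypothetical buffered rows, shaped like the metadata_buffer entries above.
buffer = [
    {"episode_index": [0], "stats/action/mean": [np.array([0.1, 0.2])]},
    {"episode_index": [1], "stats/action/mean": [np.array([0.3, 0.4])]},
]

columns: dict[str, list] = {}
for row in buffer:
    for key, value in row.items():
        val = value[0] if isinstance(value, list) else value
        # Convert numpy arrays to plain lists so pyarrow infers a list type.
        columns.setdefault(key, []).append(val.tolist() if isinstance(val, np.ndarray) else val)

table = pa.Table.from_pydict(columns)

# A ParquetWriter opened once can append one row group per flush.
with pq.ParquetWriter("episodes.parquet", schema=table.schema, compression="snappy") as writer:
    writer.write_table(table)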
class AgiBotDataset(LeRobotDataset):
    @classmethod
    def create(cls, *args, **kwargs) -> "AgiBotDataset":
        sig = inspect.signature(super().create)
        bound = sig.bind_partial(*args, **kwargs)
        bound.apply_defaults()
        params = bound.arguments

        obj = super().create(*args, **kwargs)

        shutil.rmtree(params["root"], ignore_errors=True)
        obj.meta: AgiBotDatasetMetadata = AgiBotDatasetMetadata.create(
            repo_id=params["repo_id"],
            fps=params["fps"],
            robot_type=params.get("robot_type"),
            features=params["features"],
            root=params.get("root"),
            use_videos=params.get("use_videos", True),
            metadata_buffer_size=params.get("metadata_buffer_size", 10),
        )
        return obj

    def add_frame(self, frame: dict) -> None:
        """
        This function only adds the frame to the episode_buffer. Apart from images — which are written in a
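The create override above uses inspect.signature(...).bind_partial(...) so it can read every argument by name regardless of whether the caller passed it positionally or as a keyword. A small sketch of just that binding step, with make_dataset as a hypothetical stand-in for the real create signature:

import inspect


def make_dataset(repo_id: str, fps: int = 30, root=None):
    """Hypothetical stand-in; only the argument-binding pattern matters."""


# Normalize positional and keyword arguments into one name -> value mapping.
sig = inspect.signature(make_dataset)
bound = sig.bind_partial("agibot/world", fps=25)
bound.apply_defaults()

print(bound.arguments)  # {'repo_id': 'agibot/world', 'fps': 25, 'root': None}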
@@ -191,11 +248,13 @@ def save_as_lerobot_dataset(agibot_world_config, task: tuple[Path, Path], save_d
            dataset.save_episode(videos=videos, action_config=action_config)
        except Exception as e:
            print(f"{json_file.stem}, episode_{eid}: there are some corrupted mp4s\nException details: {str(e)}")
-            dataset.episode_buffer = None
+            dataset.clear_episode_buffer(delete_images=False)
            continue
        gc.collect()

        print(f"process done for {json_file.stem}, episode_id {eid}, len {len(frames)}")

    dataset.finalize()


def main(
    src_path: str,
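This hunk appears to carry the fix named in the commit title: on a failed episode, the old code set dataset.episode_buffer = None, while the new code re-initializes the buffer via clear_episode_buffer(delete_images=False), which keeps already-extracted image files on disk. A hedged sketch of that recovery pattern, with a made-up convert_episodes driver and a hypothetical episode.frames attribute; only clear_episode_buffer(delete_images=False) is taken from the diff:

def convert_episodes(dataset, episodes):
    for eid, episode in enumerate(episodes):
        try:
            for frame in episode.frames:  # hypothetical attribute
                dataset.add_frame(frame)
            dataset.save_episode()
        except Exception as exc:
            print(f"episode_{eid} skipped: {exc}")
            # Re-initialize rather than None-ing out the buffer, so the next
            # iteration starts from a valid, empty buffer; keep extracted images.
            dataset.clear_episode_buffer(delete_images=False)
            continue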
@@ -247,8 +306,6 @@ def main(
            with open("output.txt", "a") as f:
                f.write(f"{task}, exception details: {str(e)}\n")

    ray.shutdown()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
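Not shown in this truncated diff, but suggested by the ray and RuntimeEnv imports and the ray.shutdown() call: conversion tasks are fanned out to Ray workers. A generic sketch of that driver shape, with entirely made-up task names and environment variables, not code from this file:

import ray
from ray.runtime_env import RuntimeEnv

# Start Ray with an illustrative runtime environment.
ray.init(runtime_env=RuntimeEnv(env_vars={"PYTHONWARNINGS": "ignore"}))


@ray.remote
def convert(task: str) -> str:
    # Placeholder for per-task work such as save_as_lerobot_dataset(...).
    return f"done: {task}"


results = ray.get([convert.remote(t) for t in ["task_a", "task_b"]])
print(results)  # ['done: task_a', 'done: task_b']

ray.shutdown()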