mirror of
https://github.com/Tavish9/any4lerobot.git
synced 2026-05-25 10:39:44 +00:00
🐛 fix agibot2lerobot bug when action_config is [] (#91)
Co-authored-by: Brucexkm <24210860082@m.fudan.edu.cn>
This commit is contained in:
@@ -1,21 +1,78 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import gc
|
import inspect
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pyarrow as pa
|
||||||
|
import pyarrow.parquet as pq
|
||||||
import ray
|
import ray
|
||||||
import torch
|
import torch
|
||||||
from agibot_utils.agibot_utils import get_task_info, load_local_dataset
|
from agibot_utils.agibot_utils import get_task_info, load_local_dataset
|
||||||
from agibot_utils.config import AgiBotWorld_TASK_TYPE
|
from agibot_utils.config import AgiBotWorld_TASK_TYPE
|
||||||
from agibot_utils.lerobot_utils import compute_episode_stats, generate_features_from_config
|
from agibot_utils.lerobot_utils import compute_episode_stats, generate_features_from_config
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||||
from lerobot.datasets.utils import validate_episode_buffer, validate_frame
|
from lerobot.datasets.utils import DEFAULT_EPISODES_PATH, validate_episode_buffer, validate_frame
|
||||||
from ray.runtime_env import RuntimeEnv
|
from ray.runtime_env import RuntimeEnv
|
||||||
|
|
||||||
|
|
||||||
|
class AgiBotDatasetMetadata(LeRobotDatasetMetadata):
|
||||||
|
def _flush_metadata_buffer(self) -> None:
|
||||||
|
"""Write all buffered episode metadata to parquet file."""
|
||||||
|
if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
combined_dict = {}
|
||||||
|
for episode_dict in self.metadata_buffer:
|
||||||
|
for key, value in episode_dict.items():
|
||||||
|
if key not in combined_dict:
|
||||||
|
combined_dict[key] = []
|
||||||
|
# Extract value and serialize numpy arrays
|
||||||
|
# because PyArrow's from_pydict function doesn't support numpy arrays
|
||||||
|
val = value[0] if isinstance(value, list) else value
|
||||||
|
combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val)
|
||||||
|
|
||||||
|
first_ep = self.metadata_buffer[0]
|
||||||
|
chunk_idx = first_ep["meta/episodes/chunk_index"][0]
|
||||||
|
file_idx = first_ep["meta/episodes/file_index"][0]
|
||||||
|
|
||||||
|
schema = None if not self.writer else self.writer.schema
|
||||||
|
table = pa.Table.from_pydict(combined_dict, schema=schema)
|
||||||
|
|
||||||
|
if not self.writer:
|
||||||
|
path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx))
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.writer = pq.ParquetWriter(path, schema=table.schema, compression="snappy", use_dictionary=True)
|
||||||
|
|
||||||
|
self.writer.write_table(table)
|
||||||
|
|
||||||
|
self.latest_episode = self.metadata_buffer[-1]
|
||||||
|
self.metadata_buffer.clear()
|
||||||
|
|
||||||
|
|
||||||
class AgiBotDataset(LeRobotDataset):
|
class AgiBotDataset(LeRobotDataset):
|
||||||
|
@classmethod
|
||||||
|
def create(cls, *args, **kwargs) -> "AgiBotDataset":
|
||||||
|
sig = inspect.signature(super().create)
|
||||||
|
bound = sig.bind_partial(*args, **kwargs)
|
||||||
|
bound.apply_defaults()
|
||||||
|
params = bound.arguments
|
||||||
|
|
||||||
|
obj = super().create(*args, **kwargs)
|
||||||
|
|
||||||
|
shutil.rmtree(params["root"], ignore_errors=True)
|
||||||
|
obj.meta: AgiBotDatasetMetadata = AgiBotDatasetMetadata.create(
|
||||||
|
repo_id=params["repo_id"],
|
||||||
|
fps=params["fps"],
|
||||||
|
robot_type=params.get("robot_type"),
|
||||||
|
features=params["features"],
|
||||||
|
root=params.get("root"),
|
||||||
|
use_videos=params.get("use_videos", True),
|
||||||
|
metadata_buffer_size=params.get("metadata_buffer_size", 10),
|
||||||
|
)
|
||||||
|
return obj
|
||||||
|
|
||||||
def add_frame(self, frame: dict) -> None:
|
def add_frame(self, frame: dict) -> None:
|
||||||
"""
|
"""
|
||||||
This function only adds the frame to the episode_buffer. Apart from images — which are written in a
|
This function only adds the frame to the episode_buffer. Apart from images — which are written in a
|
||||||
@@ -191,11 +248,13 @@ def save_as_lerobot_dataset(agibot_world_config, task: tuple[Path, Path], save_d
|
|||||||
dataset.save_episode(videos=videos, action_config=action_config)
|
dataset.save_episode(videos=videos, action_config=action_config)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"{json_file.stem}, episode_{eid}: there are some corrupted mp4s\nException details: {str(e)}")
|
print(f"{json_file.stem}, episode_{eid}: there are some corrupted mp4s\nException details: {str(e)}")
|
||||||
dataset.episode_buffer = None
|
dataset.clear_episode_buffer(delete_images=False)
|
||||||
continue
|
continue
|
||||||
gc.collect()
|
|
||||||
print(f"process done for {json_file.stem}, episode_id {eid}, len {len(frames)}")
|
print(f"process done for {json_file.stem}, episode_id {eid}, len {len(frames)}")
|
||||||
|
|
||||||
|
dataset.finalize()
|
||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
src_path: str,
|
src_path: str,
|
||||||
@@ -247,8 +306,6 @@ def main(
|
|||||||
with open("output.txt", "a") as f:
|
with open("output.txt", "a") as f:
|
||||||
f.write(f"{task}, exception details: {str(e)}\n")
|
f.write(f"{task}, exception details: {str(e)}\n")
|
||||||
|
|
||||||
ray.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|||||||
Reference in New Issue
Block a user