mirror of
https://github.com/Tavish9/any4lerobot.git
synced 2026-05-27 03:29:41 +00:00
🐛 Fix bug in robomind2lerobot with empty action_config (#92)
* fix robomind2lerobot bug when action_config is [] * update readme
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import argparse
|
||||
import gc
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import shutil
|
||||
@@ -7,11 +7,12 @@ from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
import ray
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.lerobot_dataset import VALID_VIDEO_CODECS, LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import flatten_dict, validate_episode_buffer, write_info, write_stats
|
||||
from lerobot.datasets.video_utils import get_safe_default_codec
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import DEFAULT_EPISODES_PATH, flatten_dict, validate_episode_buffer, write_info, write_stats
|
||||
from ray.runtime_env import RuntimeEnv
|
||||
from robomind_uitls.configs import ROBOMIND_CONFIG
|
||||
from robomind_uitls.lerobot_uitls import compute_episode_stats, generate_features_from_config
|
||||
@@ -21,6 +22,38 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(
|
||||
|
||||
|
||||
class RoboMINDDatasetMetadata(LeRobotDatasetMetadata):
|
||||
def _flush_metadata_buffer(self) -> None:
    """Persist every buffered episode-metadata record to the episodes parquet file.

    Does nothing when the instance has no ``metadata_buffer`` attribute or the
    buffer is empty. Otherwise the buffered per-episode dicts are merged into
    one column-oriented dict, converted to a PyArrow table and appended through
    ``self.writer``. The writer is opened on the first flush; on later flushes
    the table is built against the schema of the already-open writer.
    After a successful write the last buffered record is stored in
    ``self.latest_episode`` and the buffer is emptied.
    """
    if not hasattr(self, "metadata_buffer"):
        return
    pending = self.metadata_buffer
    if len(pending) == 0:
        return

    # Pivot the list of per-episode dicts into {column name: [one cell per episode]}.
    columns = {}
    for record in pending:
        for name, raw in record.items():
            # List values contribute their first element only.
            cell = raw[0] if isinstance(raw, list) else raw
            # PyArrow's from_pydict does not accept numpy arrays, so turn them
            # into plain Python lists first.
            if isinstance(cell, np.ndarray):
                cell = cell.tolist()
            columns.setdefault(name, []).append(cell)

    # The destination file is identified by the first buffered episode.
    head = pending[0]
    chunk_idx = head["meta/episodes/chunk_index"][0]
    file_idx = head["meta/episodes/file_index"][0]

    # Reuse the schema of an already-open writer; otherwise let PyArrow infer it.
    schema = self.writer.schema if self.writer else None
    table = pa.Table.from_pydict(columns, schema=schema)

    if not self.writer:
        # First flush: create the parent directory and open the parquet writer.
        path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx))
        path.parent.mkdir(parents=True, exist_ok=True)
        self.writer = pq.ParquetWriter(path, schema=table.schema, compression="snappy", use_dictionary=True)

    self.writer.write_table(table)

    # Remember the most recently written record, then reset the buffer in place.
    self.latest_episode = pending[-1]
    pending.clear()
|
||||
|
||||
def save_episode(
|
||||
self,
|
||||
split,
|
||||
@@ -57,62 +90,24 @@ class RoboMINDDatasetMetadata(LeRobotDatasetMetadata):
|
||||
|
||||
class RoboMINDDataset(LeRobotDataset):
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
repo_id: str,
|
||||
fps: int,
|
||||
features: dict,
|
||||
root: str | Path | None = None,
|
||||
robot_type: str | None = None,
|
||||
use_videos: bool = True,
|
||||
tolerance_s: float = 1e-4,
|
||||
image_writer_processes: int = 0,
|
||||
image_writer_threads: int = 0,
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
) -> "LeRobotDataset":
|
||||
"""Create a LeRobot Dataset from scratch in order to record data."""
|
||||
if vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
obj = cls.__new__(cls)
|
||||
obj.meta = RoboMINDDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
fps=fps,
|
||||
robot_type=robot_type,
|
||||
features=features,
|
||||
root=root,
|
||||
use_videos=use_videos,
|
||||
def create(cls, *args, **kwargs) -> "RoboMINDDataset":
|
||||
sig = inspect.signature(super().create)
|
||||
bound = sig.bind_partial(*args, **kwargs)
|
||||
bound.apply_defaults()
|
||||
params = bound.arguments
|
||||
|
||||
obj = super().create(*args, **kwargs)
|
||||
|
||||
shutil.rmtree(params["root"], ignore_errors=True)
|
||||
obj.meta: RoboMINDDatasetMetadata = RoboMINDDatasetMetadata.create(
|
||||
repo_id=params["repo_id"],
|
||||
fps=params["fps"],
|
||||
robot_type=params["robot_type"],
|
||||
features=params["features"],
|
||||
root=params["root"],
|
||||
use_videos=params["use_videos"],
|
||||
metadata_buffer_size=params["metadata_buffer_size"],
|
||||
)
|
||||
obj.repo_id = obj.meta.repo_id
|
||||
obj.root = obj.meta.root
|
||||
obj.revision = None
|
||||
obj.tolerance_s = tolerance_s
|
||||
obj.image_writer = None
|
||||
obj.batch_encoding_size = batch_encoding_size
|
||||
obj.episodes_since_last_encoding = 0
|
||||
obj.vcodec = vcodec
|
||||
|
||||
if image_writer_processes or image_writer_threads:
|
||||
obj.start_image_writer(image_writer_processes, image_writer_threads)
|
||||
|
||||
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
|
||||
obj.episode_buffer = obj.create_episode_buffer()
|
||||
|
||||
obj.episodes = None
|
||||
obj.hf_dataset = obj.create_hf_dataset()
|
||||
obj.image_transforms = None
|
||||
obj.delta_timestamps = None
|
||||
obj.delta_indices = None
|
||||
obj._absolute_to_relative_idx = None
|
||||
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||
obj.writer = None
|
||||
obj.latest_episode = None
|
||||
obj._current_file_start_frame = None
|
||||
# Initialize tracking for incremental recording
|
||||
obj._lazy_loading = False
|
||||
obj._recorded_frames = 0
|
||||
obj._writer_closed_for_reading = False
|
||||
return obj
|
||||
|
||||
def save_episode(self, split, action_config: dict, episode_data: dict | None = None) -> None:
|
||||
@@ -255,11 +250,11 @@ def save_as_lerobot_dataset(task: tuple[dict, Path, str], src_path, benchmark, e
|
||||
return
|
||||
else:
|
||||
logging.warning(f"Skipped {episode_path}: len of dataset:{len(raw_dataset)} or {str(err)}")
|
||||
gc.collect()
|
||||
|
||||
dataset.finalize()
|
||||
|
||||
if dataset.meta.total_episodes == 0:
|
||||
shutil.rmtree(local_dir)
|
||||
del dataset
|
||||
|
||||
|
||||
def main(
|
||||
@@ -298,7 +293,6 @@ def main(
|
||||
logging.error(f"Exception occurred for {task_path['train']}")
|
||||
with open("output.txt", "a") as f:
|
||||
f.write(f"{task_path['train']}, exception details: {str(e)}\n")
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user