Mirror of https://github.com/Tavish9/any4lerobot.git (synced 2026-05-11 12:09:41 +00:00)
Commit ad1381915c: update agibot2lerobot * update libero2lerobot * update robomind2lerobot * fix robomind2lerobot
341 lines
14 KiB
Python
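# Converts AgiBot World raw episodes (task_info JSON plus pre-encoded mp4
# videos) into the LeRobot dataset format, optionally fanning tasks out
# across a Ray cluster. The agibot_utils imports are repo-local helpers.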
import argparse
import inspect
import shutil
import tempfile
from pathlib import Path

import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import ray
import torch
from agibot_utils.agibot_utils import get_task_info, load_local_dataset
from agibot_utils.config import AgiBotWorld_TASK_TYPE
from agibot_utils.lerobot_utils import compute_episode_stats, generate_features_from_config
from lerobot.datasets import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.dataset_writer import DatasetWriter
from lerobot.datasets.feature_utils import get_hf_features_from_features, validate_episode_buffer, validate_frame
from lerobot.datasets.utils import DEFAULT_EPISODES_PATH
from ray.runtime_env import RuntimeEnv


class AgiBotDatasetMetadata(LeRobotDatasetMetadata):
    def _flush_metadata_buffer(self) -> None:
        """Write all buffered episode metadata to a parquet file."""
        if not hasattr(self, "_metadata_buffer") or len(self._metadata_buffer) == 0:
            return

        combined_dict = {}
        for episode_dict in self._metadata_buffer:
            for key, value in episode_dict.items():
                if key not in combined_dict:
                    combined_dict[key] = []
                # Extract the value and serialize numpy arrays,
                # because PyArrow's from_pydict doesn't support numpy arrays.
                val = value[0] if isinstance(value, list) else value
                combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val)

        first_ep = self._metadata_buffer[0]
        chunk_idx = first_ep["meta/episodes/chunk_index"][0]
        file_idx = first_ep["meta/episodes/file_index"][0]

        schema = None if not self._pq_writer else self._pq_writer.schema
        table = pa.Table.from_pydict(combined_dict, schema=schema)

        if not self._pq_writer:
            path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx))
            path.parent.mkdir(parents=True, exist_ok=True)
            self._pq_writer = pq.ParquetWriter(path, schema=table.schema, compression="snappy", use_dictionary=True)

        self._pq_writer.write_table(table)

        self.latest_episode = self._metadata_buffer[-1]
        self._metadata_buffer.clear()
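

# Rough sketch of one buffered metadata entry as `_flush_metadata_buffer` sees
# it (keys other than the two index keys read above are assumptions about what
# LeRobotDatasetMetadata buffers per episode):
#
#   {
#       "episode_index": [3],
#       "meta/episodes/chunk_index": [0],
#       "meta/episodes/file_index": [0],
#       "length": [512],
#   }
#
# Entries are merged column-wise, converted with pa.Table.from_pydict, and
# appended through a single snappy-compressed ParquetWriter, so many episodes
# accumulate in one file.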


class AgiBotDatasetWriter(DatasetWriter):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.hf_features = get_hf_features_from_features(self._meta.features)

    def add_frame(self, frame: dict) -> None:
        """
        Add a single frame to the current episode buffer.

        Apart from images written to a temporary directory, nothing is written to disk
        until ``save_episode()`` is called.

        The caller must provide all user-defined features plus ``"task"``, and must
        not provide ``"timestamp"`` or ``"frame_index"``; those are computed
        automatically.
        """
        # Convert torch tensors to numpy if needed
        for name in frame:
            if isinstance(frame[name], torch.Tensor):
                frame[name] = frame[name].numpy()

        features = {
            key: value for key, value in self._meta.features.items() if key in self.hf_features
        }  # remove video keys
        validate_frame(frame, features)

        if self.episode_buffer is None:
            self.episode_buffer = self._create_episode_buffer()

        # Automatically add frame_index and timestamp to the episode buffer
        frame_index = self.episode_buffer["size"]
        timestamp = frame_index / self._meta.fps
        self.episode_buffer["frame_index"].append(frame_index)
        self.episode_buffer["timestamp"].append(timestamp)
        self.episode_buffer["task"].append(frame.pop("task"))

        # Add frame features to the episode buffer
        for key, value in frame.items():
            if key not in self._meta.features:
                raise ValueError(
                    f"An element of the frame is not in the features. '{key}' not in '{self._meta.features.keys()}'."
                )

            self.episode_buffer[key].append(value)

        self.episode_buffer["size"] += 1
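
    # Illustrative call (feature names are hypothetical, not from this repo):
    #
    #   writer.add_frame({
    #       "observation.state": np.zeros(14, dtype=np.float32),
    #       "action": np.zeros(14, dtype=np.float32),
    #       "task": "pick up the cup | kitchen scene",
    #   })
    #
    # "frame_index" and "timestamp" are appended automatically, and video
    # features are excluded from validation since videos arrive pre-encoded.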

    def save_episode(
        self, videos: dict, action_config: list, episode_data: dict | None = None, parallel_encoding: bool = True
    ) -> None:
        """Save the current episode in self.episode_buffer to disk."""
        episode_buffer = episode_data if episode_data is not None else self.episode_buffer

        validate_episode_buffer(episode_buffer, self._meta.total_episodes, self._meta.features)

        # size and task are special cases that won't be added to hf_dataset
        episode_length = episode_buffer.pop("size")
        tasks = episode_buffer.pop("task")
        episode_tasks = list(set(tasks))
        episode_index = episode_buffer["episode_index"]

        episode_buffer["index"] = np.arange(self._meta.total_frames, self._meta.total_frames + episode_length)
        episode_buffer["episode_index"] = np.full((episode_length,), episode_index)

        # Update tasks and task indices with new tasks if any
        self._meta.save_episode_tasks(episode_tasks)

        # Given tasks in natural language, find their corresponding task indices
        episode_buffer["task_index"] = np.array([self._meta.get_task_index(task) for task in tasks])

        for key, ft in self._meta.features.items():
            # index, episode_index, and task_index are already processed above; image and video
            # features are handled separately by storing image paths and frame info as metadata
            if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["video"]:
                continue
            episode_buffer[key] = np.stack(episode_buffer[key]).squeeze()

        for key in self._meta.video_keys:
            episode_buffer[key] = str(videos[key])

        ep_stats = compute_episode_stats(episode_buffer, self._meta.features)

        ep_metadata = self._save_episode_data(episode_buffer)
        has_video_keys = len(self._meta.video_keys) > 0
        use_batched_encoding = self._batch_encoding_size > 1

        self.current_videos = videos
        if has_video_keys and not use_batched_encoding:
            for video_key in self._meta.video_keys:
                ep_metadata.update(self._save_episode_video(video_key, episode_index))

        ep_metadata.update({"action_config": action_config})
        self._meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)

        if has_video_keys and use_batched_encoding:
            self._episodes_since_last_encoding += 1
            if self._episodes_since_last_encoding == self._batch_encoding_size:
                start_ep = self._meta.total_episodes - self._batch_encoding_size
                end_ep = self._meta.total_episodes
                self._batch_save_episode_video(start_ep, end_ep)
                self._episodes_since_last_encoding = 0

        if not episode_data:
            self.clear_episode_buffer(delete_images=len(self._meta.image_keys) > 0)
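
    # Batched-encoding bookkeeping in save_episode above: with
    # _batch_encoding_size = N (an attribute assumed to come from
    # DatasetWriter), videos are flushed once every N saved episodes,
    # covering the half-open episode range [total_episodes - N, total_episodes).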

    def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
        """
        Unlike the parent writer, which uses ffmpeg (via the blocking `encode_video_frames`)
        to convert png frames into mp4 videos, AgiBot World episodes ship with pre-encoded
        mp4s, so this override simply copies the episode's source video to a temporary path.
        """
        temp_path = Path(tempfile.mkdtemp(dir=self._root)) / f"{video_key}_{episode_index:03d}.mp4"
        shutil.copy(self.current_videos[video_key], temp_path)
        return temp_path


class AgiBotDataset(LeRobotDataset):
    @classmethod
    def create(cls, *args, **kwargs) -> "AgiBotDataset":
        # Recover the full argument mapping so the same parameters can be
        # forwarded to the AgiBot-specific metadata class below.
        sig = inspect.signature(super().create)
        bound = sig.bind_partial(*args, **kwargs)
        bound.apply_defaults()
        params = bound.arguments

        obj = super().create(*args, **kwargs)

        # Discard what the parent create() wrote to disk, then rebuild the
        # metadata and writer with the buffered AgiBot-specific variants.
        shutil.rmtree(params["root"], ignore_errors=True)
        obj.meta: AgiBotDatasetMetadata = AgiBotDatasetMetadata.create(
            repo_id=params["repo_id"],
            fps=params["fps"],
            robot_type=params["robot_type"],
            features=params["features"],
            root=params["root"],
            use_videos=params["use_videos"],
            metadata_buffer_size=params["metadata_buffer_size"],
        )
        obj.writer: AgiBotDatasetWriter = AgiBotDatasetWriter(
            meta=obj.meta,
            root=obj.root,
            vcodec=obj._vcodec,
            encoder_threads=obj._encoder_threads,
            batch_encoding_size=obj._batch_encoding_size,
        )
        return obj

    def save_episode(
        self, videos: dict, action_config: list, episode_data: dict | None = None, parallel_encoding: bool = True
    ) -> None:
        self._require_writer("save_episode")
        self.writer.save_episode(videos, action_config, episode_data, parallel_encoding)


def get_all_tasks(src_path: Path, output_path: Path):
    json_files = src_path.glob("task_info/*.json")
    for json_file in json_files:
        local_dir = output_path / "agibotworld" / json_file.stem
        yield (json_file, local_dir.resolve())
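
# Example of one yielded (json_file, local_dir) pair (paths are illustrative):
#   (Path("/data/agibot/task_info/task_327.json"),
#    Path("/out/agibotworld/task_327"))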


def save_as_lerobot_dataset(agibot_world_config, task: tuple[Path, Path], save_depth):
    json_file, local_dir = task
    print(f"processing {json_file.stem}, saving to {local_dir}")
    src_path = json_file.parent.parent
    task_info = get_task_info(json_file)
    task_name = task_info[0]["task_name"]
    task_init_scene = task_info[0]["init_scene_text"]
    task_instruction = f"{task_name} | {task_init_scene}"
    task_id = json_file.stem.split("_")[-1]
    # Re-key the episode list by episode_id for O(1) lookup below
    task_info = {episode["episode_id"]: episode for episode in task_info}

    features = generate_features_from_config(agibot_world_config)

    if local_dir.exists():
        shutil.rmtree(local_dir)

    if not save_depth:
        features.pop("observation.images.head_depth")

    dataset: AgiBotDataset = AgiBotDataset.create(
        repo_id=json_file.stem,
        root=local_dir,
        fps=30,
        robot_type="a2d",
        features=features,
    )

    all_subdir = [f.as_posix() for f in src_path.glob(f"observations/{task_id}/*") if f.is_dir()]

    all_subdir_eids = sorted([int(Path(path).name) for path in all_subdir])

    for eid in all_subdir_eids:
        if eid not in task_info:
            print(f"{json_file.stem}, episode_{eid} not in task_info.json, skipping...")
            continue
        action_config = task_info[eid]["label_info"]["action_config"]
        raw_dataset = load_local_dataset(
            eid,
            src_path=src_path,
            task_id=task_id,
            save_depth=save_depth,
            AgiBotWorld_CONFIG=agibot_world_config,
        )
        _, frames, videos = raw_dataset
        if not all(video_path.exists() for video_path in videos.values()):
            print(f"{json_file.stem}, episode_{eid}: some of the videos do not exist, skipping...")
            continue

        for frame_data in frames:
            frame_data["task"] = task_instruction
            dataset.add_frame(frame_data)
        try:
            dataset.save_episode(videos=videos, action_config=action_config)
        except Exception as e:
            print(f"{json_file.stem}, episode_{eid}: there are some corrupted mp4s\nException details: {str(e)}")
            dataset.clear_episode_buffer(delete_images=False)
            continue

        print(f"process done for {json_file.stem}, episode_id {eid}, len {len(frames)}")

    dataset.finalize()


def main(
    src_path: Path,
    output_path: Path,
    eef_type: str,
    task_ids: list,
    cpus_per_task: int,
    save_depth: bool,
    debug: bool = False,
):
    tasks = get_all_tasks(src_path, output_path)

    agibot_world_config, type_task_ids = (
        AgiBotWorld_TASK_TYPE[eef_type]["task_config"],
        AgiBotWorld_TASK_TYPE[eef_type]["task_ids"],
    )

    if eef_type == "gripper":
        remaining_ids = AgiBotWorld_TASK_TYPE["dexhand"]["task_ids"] + AgiBotWorld_TASK_TYPE["tactile"]["task_ids"]
        tasks = filter(lambda task: task[0].stem not in remaining_ids, tasks)
    else:
        tasks = filter(lambda task: task[0].stem in type_task_ids, tasks)

    if task_ids:
        tasks = filter(lambda task: task[0].stem in task_ids, tasks)

    if debug:
        save_as_lerobot_dataset(agibot_world_config, next(tasks), save_depth)
    else:
        runtime_env = RuntimeEnv(
            env_vars={"HDF5_USE_FILE_LOCKING": "FALSE", "HF_DATASETS_DISABLE_PROGRESS_BARS": "TRUE"}
        )
        ray.init(runtime_env=runtime_env)
        resources = ray.available_resources()
        cpus = int(resources["CPU"])

        print(f"Available CPUs: {cpus}, num_cpus_per_task: {cpus_per_task}")

        remote_task = ray.remote(save_as_lerobot_dataset).options(num_cpus=cpus_per_task)
        futures = []
        for task in tasks:
            futures.append((task[0].stem, remote_task.remote(agibot_world_config, task, save_depth)))

        for task, future in futures:
            try:
                ray.get(future)
            except Exception as e:
                print(f"Exception occurred for {task}")
                with open("output.txt", "a") as f:
                    f.write(f"{task}, exception details: {str(e)}\n")
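
# Example invocation (script name and paths are illustrative):
#   python agibot2lerobot.py \
#       --src-path /data/agibot-world --output-path /data/lerobot \
#       --eef-type gripper --task-ids task_327 task_351 \
#       --cpus-per-task 3 --save-depth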


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--src-path", type=Path, required=True)
    parser.add_argument("--output-path", type=Path, required=True)
    parser.add_argument("--eef-type", type=str, choices=["gripper", "dexhand", "tactile"], default="gripper")
    parser.add_argument("--task-ids", type=str, nargs="+", help="task_327 task_351 ...", default=[])
    parser.add_argument("--cpus-per-task", type=int, default=3)
    parser.add_argument("--save-depth", action="store_true")
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    main(**vars(args))