mirror of
https://github.com/Tavish9/any4lerobot.git
synced 2026-06-23 14:27:00 +00:00
Refactor AgiBot converter onto generic pipeline
Co-authored-by: Codex <codex@openai.com>
This commit is contained in:
@@ -30,6 +30,7 @@ In this dataset, we have made several key improvements:
|
||||
|
||||
- **Preservation of Agibot’s Original Information** 🧠: We have preserved as much of Agibot’s original information as possible, with field names strictly adhering to the original dataset’s naming conventions to ensure compatibility and consistency.
|
||||
- **State and Action as Dictionaries** 🧾: The traditional one-dimensional state and action have been transformed into dictionaries, allowing for greater flexibility in designing custom states and actions, enabling modular and scalable handling.
|
||||
- **Generic Conversion Pipeline**: Conversion now uses the shared `generic_converter` execution flow for local/Ray DataTrove execution, resumable logs, temporary per-task datasets, final aggregation, and optional Hub upload.
|
||||
|
||||
Dataset Structure of `meta/info.json`:
|
||||
|
||||
@@ -116,10 +117,11 @@ Dataset Structure of `meta/info.json`:
|
||||
Follow instructions in [official repo](https://github.com/huggingface/lerobot?tab=readme-ov-file#installation).
|
||||
|
||||
2. Install others:
|
||||
We use ray for parallel conversion, significantly speeding up data processing tasks by distributing the workload across multiple cores or nodes (if any).
|
||||
We use DataTrove for conversion. Install the Ray extra if you want distributed execution across multiple cores or nodes.
|
||||
```bash
|
||||
pip install h5py
|
||||
pip install -U "ray[default]"
|
||||
pip install -U datatrove
|
||||
pip install -U "datatrove[ray]" # optional, for --executor ray
|
||||
```
|
||||
|
||||
## Get started
|
||||
@@ -161,12 +163,20 @@ git clone https://github.com/Tavish9/any4lerobot.git
|
||||
|
||||
There are three types of end-effector, `gripper`, `dexhand` and `tactile`, specify the type before converting
|
||||
|
||||
`--output-path` is the final aggregated LeRobot dataset root. Temporary per-task
|
||||
datasets are written next to it under `<output-name>_temp` and removed after
|
||||
aggregation.
|
||||
|
||||
`--episodes-per-task` controls AgiBot conversion granularity. The default `1`
|
||||
creates one temporary dataset per raw episode for better Ray load balancing.
|
||||
|
||||
```bash
|
||||
python convert.py \
|
||||
python agibot_h5.py \
|
||||
--src-path /path/to/AgiBotWorld-Beta \
|
||||
--output-path /path/to/local \
|
||||
--eef-type gripper \
|
||||
--num-cpus-per-task 3
|
||||
--cpus-per-task 3 \
|
||||
--episodes-per-task 1
|
||||
```
|
||||
|
||||
### Execute the script:
|
||||
|
||||
+330
-126
@@ -1,22 +1,47 @@
|
||||
import argparse
|
||||
import importlib
|
||||
import inspect
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
import ray
|
||||
import torch
|
||||
from agibot_utils.agibot_utils import get_task_info, load_local_dataset
|
||||
from agibot_utils.config import AgiBotWorld_TASK_TYPE
|
||||
from agibot_utils.lerobot_utils import compute_episode_stats, generate_features_from_config
|
||||
from lerobot.datasets import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.dataset_writer import DatasetWriter
|
||||
from lerobot.datasets.feature_utils import get_hf_features_from_features, validate_episode_buffer, validate_frame
|
||||
from lerobot.datasets.feature_utils import (
|
||||
get_hf_features_from_features,
|
||||
validate_episode_buffer,
|
||||
validate_frame,
|
||||
)
|
||||
from lerobot.datasets.utils import DEFAULT_EPISODES_PATH
|
||||
from ray.runtime_env import RuntimeEnv
|
||||
|
||||
AGIBOT_DIR = Path(__file__).resolve().parent
|
||||
REPO_ROOT = AGIBOT_DIR.parent
|
||||
for import_path in (REPO_ROOT, AGIBOT_DIR):
|
||||
import_path_str = str(import_path)
|
||||
if import_path_str not in sys.path:
|
||||
sys.path.insert(0, import_path_str)
|
||||
|
||||
from agibot_utils.agibot_utils import ( # noqa: E402
|
||||
get_episode_ids,
|
||||
get_task_id,
|
||||
get_task_info,
|
||||
has_episode_videos,
|
||||
load_local_dataset,
|
||||
)
|
||||
from agibot_utils.config import AgiBotWorld_TASK_TYPE # noqa: E402
|
||||
from agibot_utils.lerobot_utils import ( # noqa: E402
|
||||
compute_episode_stats,
|
||||
generate_features_from_config,
|
||||
)
|
||||
|
||||
from generic_converter import BaseAdapter, ConversionTask, run_converter # noqa: E402
|
||||
|
||||
|
||||
class AgiBotDatasetMetadata(LeRobotDatasetMetadata):
|
||||
@@ -33,7 +58,9 @@ class AgiBotDatasetMetadata(LeRobotDatasetMetadata):
|
||||
# Extract value and serialize numpy arrays
|
||||
# because PyArrow's from_pydict function doesn't support numpy arrays
|
||||
val = value[0] if isinstance(value, list) else value
|
||||
combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val)
|
||||
combined_dict[key].append(
|
||||
val.tolist() if isinstance(val, np.ndarray) else val
|
||||
)
|
||||
|
||||
first_ep = self._metadata_buffer[0]
|
||||
chunk_idx = first_ep["meta/episodes/chunk_index"][0]
|
||||
@@ -43,9 +70,16 @@ class AgiBotDatasetMetadata(LeRobotDatasetMetadata):
|
||||
table = pa.Table.from_pydict(combined_dict, schema=schema)
|
||||
|
||||
if not self._pq_writer:
|
||||
path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx))
|
||||
path = Path(
|
||||
self.root
|
||||
/ DEFAULT_EPISODES_PATH.format(
|
||||
chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._pq_writer = pq.ParquetWriter(path, schema=table.schema, compression="snappy", use_dictionary=True)
|
||||
self._pq_writer = pq.ParquetWriter(
|
||||
path, schema=table.schema, compression="snappy", use_dictionary=True
|
||||
)
|
||||
|
||||
self._pq_writer.write_table(table)
|
||||
|
||||
@@ -75,7 +109,9 @@ class AgiBotDatasetWriter(DatasetWriter):
|
||||
frame[name] = frame[name].numpy()
|
||||
|
||||
features = {
|
||||
key: value for key, value in self._meta.features.items() if key in self.hf_features
|
||||
key: value
|
||||
for key, value in self._meta.features.items()
|
||||
if key in self.hf_features
|
||||
} # remove video keys
|
||||
validate_frame(frame, features)
|
||||
|
||||
@@ -101,12 +137,20 @@ class AgiBotDatasetWriter(DatasetWriter):
|
||||
self.episode_buffer["size"] += 1
|
||||
|
||||
def save_episode(
|
||||
self, videos: dict, action_config: list, episode_data: dict | None = None, parallel_encoding: bool = True
|
||||
self,
|
||||
videos: dict,
|
||||
action_config: list,
|
||||
episode_data: dict | None = None,
|
||||
parallel_encoding: bool = True,
|
||||
) -> None:
|
||||
"""Save the current episode in self.episode_buffer to disk."""
|
||||
episode_buffer = episode_data if episode_data is not None else self.episode_buffer
|
||||
episode_buffer = (
|
||||
episode_data if episode_data is not None else self.episode_buffer
|
||||
)
|
||||
|
||||
validate_episode_buffer(episode_buffer, self._meta.total_episodes, self._meta.features)
|
||||
validate_episode_buffer(
|
||||
episode_buffer, self._meta.total_episodes, self._meta.features
|
||||
)
|
||||
|
||||
# size and task are special cases that won't be added to hf_dataset
|
||||
episode_length = episode_buffer.pop("size")
|
||||
@@ -114,19 +158,25 @@ class AgiBotDatasetWriter(DatasetWriter):
|
||||
episode_tasks = list(set(tasks))
|
||||
episode_index = episode_buffer["episode_index"]
|
||||
|
||||
episode_buffer["index"] = np.arange(self._meta.total_frames, self._meta.total_frames + episode_length)
|
||||
episode_buffer["index"] = np.arange(
|
||||
self._meta.total_frames, self._meta.total_frames + episode_length
|
||||
)
|
||||
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
|
||||
|
||||
# Update tasks and task indices with new tasks if any
|
||||
self._meta.save_episode_tasks(episode_tasks)
|
||||
|
||||
# Given tasks in natural language, find their corresponding task indices
|
||||
episode_buffer["task_index"] = np.array([self._meta.get_task_index(task) for task in tasks])
|
||||
episode_buffer["task_index"] = np.array(
|
||||
[self._meta.get_task_index(task) for task in tasks]
|
||||
)
|
||||
|
||||
for key, ft in self._meta.features.items():
|
||||
# index, episode_index, task_index are already processed above, and image and video
|
||||
# are processed separately by storing image path and frame info as meta data
|
||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["video"]:
|
||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in [
|
||||
"video"
|
||||
]:
|
||||
continue
|
||||
episode_buffer[key] = np.stack(episode_buffer[key]).squeeze()
|
||||
|
||||
@@ -145,7 +195,9 @@ class AgiBotDatasetWriter(DatasetWriter):
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index))
|
||||
|
||||
ep_metadata.update({"action_config": action_config})
|
||||
self._meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
|
||||
self._meta.save_episode(
|
||||
episode_index, episode_length, episode_tasks, ep_stats, ep_metadata
|
||||
)
|
||||
|
||||
if has_video_keys and use_batched_encoding:
|
||||
self._episodes_since_last_encoding += 1
|
||||
@@ -158,13 +210,18 @@ class AgiBotDatasetWriter(DatasetWriter):
|
||||
if not episode_data:
|
||||
self.clear_episode_buffer(delete_images=len(self._meta.image_keys) > 0)
|
||||
|
||||
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
|
||||
def _encode_temporary_episode_video(
|
||||
self, video_key: str, episode_index: int
|
||||
) -> Path:
|
||||
"""
|
||||
Use ffmpeg to convert frames stored as png into mp4 videos.
|
||||
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
since video encoding with ffmpeg is already using multithreading.
|
||||
"""
|
||||
temp_path = Path(tempfile.mkdtemp(dir=self._root)) / f"{video_key}_{episode_index:03d}.mp4"
|
||||
temp_path = (
|
||||
Path(tempfile.mkdtemp(dir=self._root))
|
||||
/ f"{video_key}_{episode_index:03d}.mp4"
|
||||
)
|
||||
shutil.copy(self.current_videos[video_key], temp_path)
|
||||
return temp_path
|
||||
|
||||
@@ -188,153 +245,300 @@ class AgiBotDataset(LeRobotDataset):
|
||||
root=params["root"],
|
||||
use_videos=params["use_videos"],
|
||||
metadata_buffer_size=params["metadata_buffer_size"],
|
||||
video_files_size_in_mb=params["video_files_size_in_mb"],
|
||||
data_files_size_in_mb=params["data_files_size_in_mb"],
|
||||
)
|
||||
obj.writer: AgiBotDatasetWriter = AgiBotDatasetWriter(
|
||||
meta=obj.meta,
|
||||
root=obj.root,
|
||||
vcodec=obj._vcodec,
|
||||
encoder_threads=obj._encoder_threads,
|
||||
batch_encoding_size=obj._batch_encoding_size,
|
||||
camera_encoder=obj.writer._camera_encoder,
|
||||
encoder_threads=obj.writer._encoder_threads,
|
||||
batch_encoding_size=obj.writer._batch_encoding_size,
|
||||
streaming_encoder=obj.writer._streaming_encoder,
|
||||
)
|
||||
return obj
|
||||
|
||||
def save_episode(
|
||||
self, videos: dict, action_config: list, episode_data: dict | None = None, parallel_encoding: bool = True
|
||||
self,
|
||||
videos: dict,
|
||||
action_config: list,
|
||||
episode_data: dict | None = None,
|
||||
parallel_encoding: bool = True,
|
||||
) -> None:
|
||||
self._require_writer("save_episode")
|
||||
self.writer.save_episode(videos, action_config, episode_data, parallel_encoding)
|
||||
|
||||
|
||||
def get_all_tasks(src_path: Path, output_path: Path):
|
||||
json_files = src_path.glob("task_info/*.json")
|
||||
for json_file in json_files:
|
||||
local_dir = output_path / "agibotworld" / json_file.stem
|
||||
yield (json_file, local_dir.resolve())
|
||||
class AgiBotAdapter(BaseAdapter):
|
||||
dataset_type = "agibot"
|
||||
fps = 30
|
||||
robot_type = "a2d"
|
||||
tags = ("agibot-world", "a2d")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
src_path: Path,
|
||||
output_path: Path,
|
||||
eef_type: str,
|
||||
task_ids: Sequence[str],
|
||||
save_depth: bool,
|
||||
episodes_per_task: int,
|
||||
):
|
||||
super().__init__(output_path)
|
||||
if episodes_per_task < 1:
|
||||
raise ValueError("--episodes-per-task must be >= 1")
|
||||
self.src_path = src_path.expanduser().resolve()
|
||||
self.eef_type = eef_type
|
||||
self.task_ids = set(task_ids)
|
||||
self.save_depth = save_depth
|
||||
self.episodes_per_task = episodes_per_task
|
||||
self.agibot_world_config = AgiBotWorld_TASK_TYPE[eef_type]["task_config"]
|
||||
self.type_task_ids = set(AgiBotWorld_TASK_TYPE[eef_type]["task_ids"])
|
||||
self.features = generate_features_from_config(self.agibot_world_config)
|
||||
if not save_depth:
|
||||
self.features.pop("observation.images.head_depth", None)
|
||||
|
||||
def save_as_lerobot_dataset(agibot_world_config, task: tuple[Path, Path], save_depth):
|
||||
json_file, local_dir = task
|
||||
print(f"processing {json_file.stem}, saving to {local_dir}")
|
||||
src_path = json_file.parent.parent
|
||||
task_info = get_task_info(json_file)
|
||||
task_name = task_info[0]["task_name"]
|
||||
task_init_scene = task_info[0]["init_scene_text"]
|
||||
task_instruction = f"{task_name} | {task_init_scene}"
|
||||
task_id = json_file.stem.split("_")[-1]
|
||||
task_info = {episode["episode_id"]: episode for episode in task_info}
|
||||
def load_tasks(self) -> list[ConversionTask]:
|
||||
tasks = []
|
||||
for json_file in sorted(self.src_path.glob("task_info/*.json")):
|
||||
if not self._include_task(json_file.stem):
|
||||
continue
|
||||
_, _, _, task_info = self._load_task_context(json_file)
|
||||
episode_ids = sorted(task_info)
|
||||
|
||||
features = generate_features_from_config(agibot_world_config)
|
||||
for chunk_index, chunk_episode_ids in enumerate(
|
||||
self._chunk_episode_ids(episode_ids)
|
||||
):
|
||||
task_name = self._format_task_name(
|
||||
json_file.stem, chunk_index, chunk_episode_ids
|
||||
)
|
||||
tasks.append(
|
||||
ConversionTask(
|
||||
input_path=json_file.resolve(),
|
||||
output_path=(
|
||||
self.temp_output_path / "agibotworld" / task_name
|
||||
).resolve(),
|
||||
local_repo_id=task_name,
|
||||
metadata={"episode_ids": tuple(chunk_episode_ids)},
|
||||
)
|
||||
)
|
||||
return tasks
|
||||
|
||||
if local_dir.exists():
|
||||
shutil.rmtree(local_dir)
|
||||
|
||||
if not save_depth:
|
||||
features.pop("observation.images.head_depth")
|
||||
|
||||
dataset: AgiBotDataset = AgiBotDataset.create(
|
||||
repo_id=json_file.stem,
|
||||
root=local_dir,
|
||||
fps=30,
|
||||
robot_type="a2d",
|
||||
features=features,
|
||||
)
|
||||
|
||||
all_subdir = [f.as_posix() for f in src_path.glob(f"observations/{task_id}/*") if f.is_dir()]
|
||||
|
||||
all_subdir_eids = sorted([int(Path(path).name) for path in all_subdir])
|
||||
|
||||
for eid in all_subdir_eids:
|
||||
if eid not in task_info:
|
||||
print(f"{json_file.stem}, episode_{eid} not in task_info.json, skipping...")
|
||||
continue
|
||||
action_config = task_info[eid]["label_info"]["action_config"]
|
||||
raw_dataset = load_local_dataset(
|
||||
eid,
|
||||
src_path=src_path,
|
||||
task_id=task_id,
|
||||
save_depth=save_depth,
|
||||
AgiBotWorld_CONFIG=agibot_world_config,
|
||||
def load_subset(self, task: ConversionTask):
|
||||
json_file = task.input_path
|
||||
print(f"processing {json_file.stem}, saving to {task.output_path}")
|
||||
src_path, task_id, task_instruction, task_info = self._load_task_context(
|
||||
json_file
|
||||
)
|
||||
_, frames, videos = raw_dataset
|
||||
if not all([video_path.exists() for video_path in videos.values()]):
|
||||
print(f"{json_file.stem}, episode_{eid}: some of the videos does not exist, skipping...")
|
||||
continue
|
||||
|
||||
for frame_data in frames:
|
||||
frame_data["task"] = task_instruction
|
||||
task_episode_ids = task.metadata.get("episode_ids")
|
||||
if task_episode_ids is None:
|
||||
task_episode_ids = get_episode_ids(src_path, task_id)
|
||||
|
||||
for eid in task_episode_ids:
|
||||
if not self._is_convertible_episode(
|
||||
json_file.stem, src_path, task_id, eid, task_info
|
||||
):
|
||||
continue
|
||||
action_config = task_info[eid]["label_info"]["action_config"]
|
||||
raw_dataset = load_local_dataset(
|
||||
eid,
|
||||
src_path=src_path,
|
||||
task_id=task_id,
|
||||
save_depth=self.save_depth,
|
||||
AgiBotWorld_CONFIG=self.agibot_world_config,
|
||||
)
|
||||
if raw_dataset is None:
|
||||
continue
|
||||
_, frames, videos = raw_dataset
|
||||
|
||||
for frame_data in frames:
|
||||
frame_data["task"] = task_instruction
|
||||
|
||||
yield {
|
||||
"episode_id": eid,
|
||||
"frames": frames,
|
||||
"videos": videos,
|
||||
"action_config": action_config,
|
||||
}
|
||||
|
||||
def create_dataset(self, task: ConversionTask) -> AgiBotDataset:
|
||||
return AgiBotDataset.create(
|
||||
repo_id=task.local_repo_id,
|
||||
root=task.output_path,
|
||||
fps=self.fps,
|
||||
robot_type=self.robot_type,
|
||||
features=self.features,
|
||||
)
|
||||
|
||||
def save_episode(
|
||||
self,
|
||||
dataset: AgiBotDataset,
|
||||
episode_data: dict[str, Any],
|
||||
task: ConversionTask,
|
||||
) -> bool:
|
||||
for frame_data in episode_data["frames"]:
|
||||
dataset.add_frame(frame_data)
|
||||
try:
|
||||
dataset.save_episode(videos=videos, action_config=action_config)
|
||||
dataset.save_episode(
|
||||
videos=episode_data["videos"],
|
||||
action_config=episode_data["action_config"],
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"{json_file.stem}, episode_{eid}: there are some corrupted mp4s\nException details: {str(e)}")
|
||||
print(
|
||||
f"{task.input_path.stem}, episode_{episode_data['episode_id']}: "
|
||||
f"there are some corrupted mp4s\nException details: {str(e)}"
|
||||
)
|
||||
dataset.clear_episode_buffer(delete_images=False)
|
||||
continue
|
||||
return False
|
||||
return True
|
||||
|
||||
print(f"process done for {json_file.stem}, episode_id {eid}, len {len(frames)}")
|
||||
def get_episode_length(self, episode_data: dict[str, Any]) -> int:
|
||||
return len(episode_data["frames"])
|
||||
|
||||
dataset.finalize()
|
||||
def _chunk_episode_ids(self, episode_ids: list[int]):
|
||||
for start in range(0, len(episode_ids), self.episodes_per_task):
|
||||
yield episode_ids[start : start + self.episodes_per_task]
|
||||
|
||||
def _format_task_name(
|
||||
self, task_name: str, chunk_index: int, episode_ids: Sequence[int]
|
||||
) -> str:
|
||||
if len(episode_ids) == 1:
|
||||
return f"{task_name}_episode_{episode_ids[0]}"
|
||||
return (
|
||||
f"{task_name}_chunk_{chunk_index:06d}_episodes_"
|
||||
f"{episode_ids[0]}_{episode_ids[-1]}"
|
||||
)
|
||||
|
||||
def _load_task_context(
|
||||
self, json_file: Path
|
||||
) -> tuple[Path, str, str, dict[int, dict[str, Any]]]:
|
||||
src_path = json_file.parent.parent
|
||||
task_id = get_task_id(json_file)
|
||||
task_info = get_task_info(json_file)
|
||||
task_name = task_info[0]["task_name"]
|
||||
task_init_scene = task_info[0]["init_scene_text"]
|
||||
task_instruction = f"{task_name} | {task_init_scene}"
|
||||
task_info_by_episode = {episode["episode_id"]: episode for episode in task_info}
|
||||
return src_path, task_id, task_instruction, task_info_by_episode
|
||||
|
||||
def _is_convertible_episode(
|
||||
self,
|
||||
task_name: str,
|
||||
src_path: Path,
|
||||
task_id: str,
|
||||
episode_id: int,
|
||||
task_info: dict[int, dict[str, Any]],
|
||||
) -> bool:
|
||||
if episode_id not in task_info:
|
||||
print(
|
||||
f"{task_name}, episode_{episode_id} not in task_info.json, skipping..."
|
||||
)
|
||||
return False
|
||||
if not has_episode_videos(
|
||||
src_path, task_id, episode_id, self.agibot_world_config
|
||||
):
|
||||
print(
|
||||
f"{task_name}, episode_{episode_id}: "
|
||||
"some of the videos does not exist, skipping..."
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
def _include_task(self, task_id: str) -> bool:
|
||||
if self.task_ids and task_id not in self.task_ids:
|
||||
return False
|
||||
if self.eef_type == "gripper":
|
||||
remaining_ids = set(AgiBotWorld_TASK_TYPE["dexhand"]["task_ids"])
|
||||
remaining_ids.update(AgiBotWorld_TASK_TYPE["tactile"]["task_ids"])
|
||||
return task_id not in remaining_ids
|
||||
return task_id in self.type_task_ids
|
||||
|
||||
|
||||
def main(
|
||||
src_path: str,
|
||||
output_path: str,
|
||||
src_path: Path,
|
||||
output_path: Path,
|
||||
eef_type: str,
|
||||
task_ids: list,
|
||||
task_ids: list[str],
|
||||
executor: str,
|
||||
cpus_per_task: int,
|
||||
tasks_per_job: int,
|
||||
workers: int,
|
||||
save_depth: bool,
|
||||
episodes_per_task: int,
|
||||
resume_dir: Path | None = None,
|
||||
debug: bool = False,
|
||||
repo_id: str | None = None,
|
||||
push_to_hub: bool = False,
|
||||
):
|
||||
tasks = get_all_tasks(src_path, output_path)
|
||||
|
||||
agibot_world_config, type_task_ids = (
|
||||
AgiBotWorld_TASK_TYPE[eef_type]["task_config"],
|
||||
AgiBotWorld_TASK_TYPE[eef_type]["task_ids"],
|
||||
adapter = AgiBotAdapter(
|
||||
src_path=src_path,
|
||||
output_path=output_path,
|
||||
eef_type=eef_type,
|
||||
task_ids=task_ids,
|
||||
save_depth=save_depth,
|
||||
episodes_per_task=episodes_per_task,
|
||||
)
|
||||
|
||||
if eef_type == "gripper":
|
||||
remaining_ids = AgiBotWorld_TASK_TYPE["dexhand"]["task_ids"] + AgiBotWorld_TASK_TYPE["tactile"]["task_ids"]
|
||||
tasks = filter(lambda task: task[0].stem not in remaining_ids, tasks)
|
||||
else:
|
||||
tasks = filter(lambda task: task[0].stem in type_task_ids, tasks)
|
||||
|
||||
if task_ids:
|
||||
tasks = filter(lambda task: task[0].stem in task_ids, tasks)
|
||||
|
||||
if debug:
|
||||
save_as_lerobot_dataset(agibot_world_config, next(tasks), save_depth)
|
||||
else:
|
||||
runtime_env = RuntimeEnv(
|
||||
env_vars={"HDF5_USE_FILE_LOCKING": "FALSE", "HF_DATASETS_DISABLE_PROGRESS_BARS": "TRUE"}
|
||||
)
|
||||
ray.init(runtime_env=runtime_env)
|
||||
resources = ray.available_resources()
|
||||
cpus = int(resources["CPU"])
|
||||
|
||||
print(f"Available CPUs: {cpus}, num_cpus_per_task: {cpus_per_task}")
|
||||
|
||||
remote_task = ray.remote(save_as_lerobot_dataset).options(num_cpus=cpus_per_task)
|
||||
futures = []
|
||||
for task in tasks:
|
||||
futures.append((task[0].stem, remote_task.remote(agibot_world_config, task, save_depth)))
|
||||
|
||||
for task, future in futures:
|
||||
try:
|
||||
ray.get(future)
|
||||
except Exception as e:
|
||||
print(f"Exception occurred for {task}")
|
||||
with open("output.txt", "a") as f:
|
||||
f.write(f"{task}, exception details: {str(e)}\n")
|
||||
run_converter(
|
||||
adapter=adapter,
|
||||
executor=executor,
|
||||
cpus_per_task=cpus_per_task,
|
||||
tasks_per_job=tasks_per_job,
|
||||
workers=workers,
|
||||
resume_dir=resume_dir,
|
||||
debug=debug,
|
||||
local_repo_id=repo_id,
|
||||
hub_repo_id=repo_id,
|
||||
push_to_hub=push_to_hub,
|
||||
extra_tags=[eef_type],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--src-path", type=Path, required=True)
|
||||
parser.add_argument("--output-path", type=Path, required=True)
|
||||
parser.add_argument("--eef-type", type=str, choices=["gripper", "dexhand", "tactile"], default="gripper")
|
||||
parser.add_argument("--task-ids", type=str, nargs="+", help="task_327 task_351 ...", default=[])
|
||||
parser.add_argument("--cpus-per-task", type=int, default=3)
|
||||
parser.add_argument(
|
||||
"--eef-type",
|
||||
type=str,
|
||||
choices=["gripper", "dexhand", "tactile"],
|
||||
default="gripper",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task-ids", type=str, nargs="+", help="task_327 task_351 ...", default=[]
|
||||
)
|
||||
parser.add_argument("--executor", type=str, choices=["local", "ray"], default="ray")
|
||||
parser.add_argument("--cpus-per-task", type=int, default=1)
|
||||
parser.add_argument(
|
||||
"--tasks-per-job",
|
||||
type=int,
|
||||
default=1,
|
||||
help="number of concurrent tasks per job, only used for ray",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--episodes-per-task",
|
||||
type=int,
|
||||
default=10,
|
||||
help="number of AgiBot episodes grouped into one conversion task",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--workers", type=int, default=-1, help="number of concurrent jobs to run"
|
||||
)
|
||||
parser.add_argument("--resume-dir", type=Path, help="logs directory to resume")
|
||||
parser.add_argument("--save-depth", action="store_true")
|
||||
parser.add_argument("--debug", action="store_true")
|
||||
args = parser.parse_args()
|
||||
parser.add_argument("--repo-id", type=str, help="required when push-to-hub is True")
|
||||
parser.add_argument("--push-to-hub", action="store_true", help="upload to hub")
|
||||
return parser.parse_args()
|
||||
|
||||
main(**vars(args))
|
||||
|
||||
def cli():
|
||||
args = parse_args()
|
||||
module_name = "agibot2lerobot.agibot_h5"
|
||||
module = importlib.import_module(module_name)
|
||||
module.main(**vars(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
|
||||
@@ -6,22 +6,67 @@ import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def get_task_info(task_json_path: str) -> dict:
|
||||
def get_task_info(task_json_path: str | Path) -> list[dict]:
|
||||
with open(task_json_path, "r") as f:
|
||||
task_info: list = json.load(f)
|
||||
task_info.sort(key=lambda episode: episode["episode_id"])
|
||||
return task_info
|
||||
|
||||
|
||||
def get_task_id(task_json_path: str | Path) -> str:
|
||||
return Path(task_json_path).stem.split("_")[-1]
|
||||
|
||||
|
||||
def get_episode_ids(src_path: str | Path, task_id: str | int) -> list[int]:
|
||||
observations_dir = Path(src_path) / "observations" / str(task_id)
|
||||
return sorted(
|
||||
int(path.name)
|
||||
for path in observations_dir.glob("*")
|
||||
if path.is_dir() and path.name.isdigit()
|
||||
)
|
||||
|
||||
|
||||
def get_episode_videos(
|
||||
src_path: str | Path,
|
||||
task_id: str | int,
|
||||
episode_id: int,
|
||||
agibot_world_config: dict,
|
||||
) -> dict[str, Path]:
|
||||
ob_dir = Path(src_path) / f"observations/{task_id}/{episode_id}"
|
||||
return {
|
||||
f"observation.images.{key}": ob_dir / "videos" / f"{key}_color.mp4"
|
||||
if "sensor" not in key
|
||||
else ob_dir / "tactile" / f"{key}.mp4" # HACK: handle tactile videos
|
||||
for key in agibot_world_config["images"]
|
||||
if "depth" not in key
|
||||
}
|
||||
|
||||
|
||||
def has_episode_videos(
|
||||
src_path: str | Path,
|
||||
task_id: str | int,
|
||||
episode_id: int,
|
||||
agibot_world_config: dict,
|
||||
) -> bool:
|
||||
videos = get_episode_videos(src_path, task_id, episode_id, agibot_world_config)
|
||||
return all(video_path.exists() for video_path in videos.values())
|
||||
|
||||
|
||||
def load_depths(root_dir: str, camera_name: str):
|
||||
cam_path = Path(root_dir)
|
||||
all_imgs = sorted(list(cam_path.glob(f"{camera_name}*")))
|
||||
return [np.array(Image.open(f)).astype(np.float32)[:, :, None] / 1000 for f in all_imgs]
|
||||
return [
|
||||
np.array(Image.open(f)).astype(np.float32)[:, :, None] / 1000 for f in all_imgs
|
||||
]
|
||||
|
||||
|
||||
def load_local_dataset(
|
||||
episode_id: int, src_path: str, task_id: int, save_depth: bool, AgiBotWorld_CONFIG: dict
|
||||
) -> tuple[list, dict]:
|
||||
episode_id: int,
|
||||
src_path: str | Path,
|
||||
task_id: str | int,
|
||||
save_depth: bool,
|
||||
AgiBotWorld_CONFIG: dict,
|
||||
) -> tuple[int, list[dict], dict[str, Path]] | None:
|
||||
"""Load local dataset and return a dict with observations and actions"""
|
||||
ob_dir = Path(src_path) / f"observations/{task_id}/{episode_id}"
|
||||
proprio_dir = Path(src_path) / f"proprio_stats/{task_id}/{episode_id}"
|
||||
@@ -30,9 +75,13 @@ def load_local_dataset(
|
||||
action = {}
|
||||
with h5py.File(proprio_dir / "proprio_stats.h5", "r") as f:
|
||||
for key in AgiBotWorld_CONFIG["states"]:
|
||||
state[f"observation.states.{key}"] = np.array(f["state/" + key.replace(".", "/")], dtype=np.float32)
|
||||
state[f"observation.states.{key}"] = np.array(
|
||||
f["state/" + key.replace(".", "/")], dtype=np.float32
|
||||
)
|
||||
for key in AgiBotWorld_CONFIG["actions"]:
|
||||
action[f"actions.{key}"] = np.array(f["action/" + key.replace(".", "/")], dtype=np.float32)
|
||||
action[f"actions.{key}"] = np.array(
|
||||
f["action/" + key.replace(".", "/")], dtype=np.float32
|
||||
)
|
||||
|
||||
# HACK: agibot team forgot to pad or filter some of the values
|
||||
num_frames = len(next(iter(state.values())))
|
||||
@@ -42,7 +91,10 @@ def load_local_dataset(
|
||||
elif len(action_value) < num_frames:
|
||||
state_key = action_key.replace("actions", "state").replace(".", "/")
|
||||
new_action_value = np.array(f[state_key], dtype=np.float32).copy()
|
||||
action_index_key = "/".join(list(action_key.replace("actions", "action").split(".")[:-1]) + ["index"])
|
||||
action_index_key = "/".join(
|
||||
list(action_key.replace("actions", "action").split(".")[:-1])
|
||||
+ ["index"]
|
||||
)
|
||||
action_index = np.array(f[action_index_key])
|
||||
# agibot lost end index, replace it with joint
|
||||
if not action_index.size:
|
||||
@@ -52,11 +104,13 @@ def load_local_dataset(
|
||||
action[action_key] = new_action_value
|
||||
elif len(action_value) > num_frames:
|
||||
print("corrupt data, skipping")
|
||||
return episode_id, [], {"dummy_video": Path("/path/to/no_exist")}
|
||||
return None
|
||||
|
||||
if save_depth:
|
||||
depth_imgs = load_depths(ob_dir / "depth", "head_depth")
|
||||
assert num_frames == len(depth_imgs), "Number of images and states are not equal"
|
||||
assert num_frames == len(depth_imgs), (
|
||||
"Number of images and states are not equal"
|
||||
)
|
||||
|
||||
state_key_prefix_len = len("observation.states.")
|
||||
action_key_prefix_len = len("actions.")
|
||||
@@ -68,7 +122,9 @@ def load_local_dataset(
|
||||
if value.size
|
||||
else np.zeros(
|
||||
AgiBotWorld_CONFIG["states"][key[state_key_prefix_len:]]["shape"],
|
||||
dtype=AgiBotWorld_CONFIG["states"][key[state_key_prefix_len:]]["dtype"],
|
||||
dtype=AgiBotWorld_CONFIG["states"][key[state_key_prefix_len:]][
|
||||
"dtype"
|
||||
],
|
||||
)
|
||||
for key, value in state.items()
|
||||
},
|
||||
@@ -77,7 +133,9 @@ def load_local_dataset(
|
||||
if value.size
|
||||
else np.zeros(
|
||||
AgiBotWorld_CONFIG["actions"][key[action_key_prefix_len:]]["shape"],
|
||||
dtype=AgiBotWorld_CONFIG["actions"][key[action_key_prefix_len:]]["dtype"],
|
||||
dtype=AgiBotWorld_CONFIG["actions"][key[action_key_prefix_len:]][
|
||||
"dtype"
|
||||
],
|
||||
)
|
||||
for key, value in action.items()
|
||||
},
|
||||
@@ -85,11 +143,5 @@ def load_local_dataset(
|
||||
for i in range(num_frames)
|
||||
]
|
||||
|
||||
videos = {
|
||||
f"observation.images.{key}": ob_dir / "videos" / f"{key}_color.mp4"
|
||||
if "sensor" not in key
|
||||
else ob_dir / "tactile" / f"{key}.mp4" # HACK: handle tactile videos
|
||||
for key in AgiBotWorld_CONFIG["images"]
|
||||
if "depth" not in key
|
||||
}
|
||||
videos = get_episode_videos(src_path, task_id, episode_id, AgiBotWorld_CONFIG)
|
||||
return episode_id, frames, videos
|
||||
|
||||
Reference in New Issue
Block a user