Refactor AgiBot converter onto generic pipeline

Co-authored-by: Codex <codex@openai.com>
This commit is contained in:
Tavish
2026-06-22 13:12:58 +08:00
parent 2baee72741
commit 8afce7fa23
3 changed files with 414 additions and 148 deletions
+14 -4
View File
@@ -30,6 +30,7 @@ In this dataset, we have made several key improvements:
- **Preservation of Agibot's Original Information** 🧠: We have preserved as much of Agibot's original information as possible, with field names strictly adhering to the original dataset's naming conventions to ensure compatibility and consistency.
- **State and Action as Dictionaries** 🧾: The traditional one-dimensional state and action have been transformed into dictionaries, allowing for greater flexibility in designing custom states and actions, enabling modular and scalable handling.
- **Generic Conversion Pipeline**: Conversion now uses the shared `generic_converter` execution flow for local/Ray DataTrove execution, resumable logs, temporary per-task datasets, final aggregation, and optional Hub upload.
Dataset Structure of `meta/info.json`:
@@ -116,10 +117,11 @@ Dataset Structure of `meta/info.json`:
Follow instructions in [official repo](https://github.com/huggingface/lerobot?tab=readme-ov-file#installation).
2. Install others:
We use ray for parallel conversion, significantly speeding up data processing tasks by distributing the workload across multiple cores or nodes (if any).
We use DataTrove for conversion. Install the Ray extra if you want distributed execution across multiple cores or nodes.
```bash
pip install h5py
pip install -U "ray[default]"
pip install -U datatrove
pip install -U "datatrove[ray]" # optional, for --executor ray
```
## Get started
@@ -161,12 +163,20 @@ git clone https://github.com/Tavish9/any4lerobot.git
There are three types of end-effector: `gripper`, `dexhand`, and `tactile`. Specify the type before converting.
`--output-path` is the final aggregated LeRobot dataset root. Temporary per-task
datasets are written next to it under `<output-name>_temp` and removed after
aggregation.
`--episodes-per-task` controls AgiBot conversion granularity (script default: `10`).
Setting it to `1` creates one temporary dataset per raw episode for better Ray load balancing.
```bash
python convert.py \
python agibot_h5.py \
--src-path /path/to/AgiBotWorld-Beta \
--output-path /path/to/local \
--eef-type gripper \
--num-cpus-per-task 3
--cpus-per-task 3 \
--episodes-per-task 1
```
### Execute the script:
+330 -126
View File
@@ -1,22 +1,47 @@
import argparse
import importlib
import inspect
import shutil
import sys
import tempfile
from collections.abc import Sequence
from pathlib import Path
from typing import Any
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import ray
import torch
from agibot_utils.agibot_utils import get_task_info, load_local_dataset
from agibot_utils.config import AgiBotWorld_TASK_TYPE
from agibot_utils.lerobot_utils import compute_episode_stats, generate_features_from_config
from lerobot.datasets import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.dataset_writer import DatasetWriter
from lerobot.datasets.feature_utils import get_hf_features_from_features, validate_episode_buffer, validate_frame
from lerobot.datasets.feature_utils import (
get_hf_features_from_features,
validate_episode_buffer,
validate_frame,
)
from lerobot.datasets.utils import DEFAULT_EPISODES_PATH
from ray.runtime_env import RuntimeEnv
AGIBOT_DIR = Path(__file__).resolve().parent
REPO_ROOT = AGIBOT_DIR.parent
for import_path in (REPO_ROOT, AGIBOT_DIR):
import_path_str = str(import_path)
if import_path_str not in sys.path:
sys.path.insert(0, import_path_str)
from agibot_utils.agibot_utils import ( # noqa: E402
get_episode_ids,
get_task_id,
get_task_info,
has_episode_videos,
load_local_dataset,
)
from agibot_utils.config import AgiBotWorld_TASK_TYPE # noqa: E402
from agibot_utils.lerobot_utils import ( # noqa: E402
compute_episode_stats,
generate_features_from_config,
)
from generic_converter import BaseAdapter, ConversionTask, run_converter # noqa: E402
class AgiBotDatasetMetadata(LeRobotDatasetMetadata):
@@ -33,7 +58,9 @@ class AgiBotDatasetMetadata(LeRobotDatasetMetadata):
# Extract value and serialize numpy arrays
# because PyArrow's from_pydict function doesn't support numpy arrays
val = value[0] if isinstance(value, list) else value
combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val)
combined_dict[key].append(
val.tolist() if isinstance(val, np.ndarray) else val
)
first_ep = self._metadata_buffer[0]
chunk_idx = first_ep["meta/episodes/chunk_index"][0]
@@ -43,9 +70,16 @@ class AgiBotDatasetMetadata(LeRobotDatasetMetadata):
table = pa.Table.from_pydict(combined_dict, schema=schema)
if not self._pq_writer:
path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx))
path = Path(
self.root
/ DEFAULT_EPISODES_PATH.format(
chunk_index=chunk_idx, file_index=file_idx
)
)
path.parent.mkdir(parents=True, exist_ok=True)
self._pq_writer = pq.ParquetWriter(path, schema=table.schema, compression="snappy", use_dictionary=True)
self._pq_writer = pq.ParquetWriter(
path, schema=table.schema, compression="snappy", use_dictionary=True
)
self._pq_writer.write_table(table)
@@ -75,7 +109,9 @@ class AgiBotDatasetWriter(DatasetWriter):
frame[name] = frame[name].numpy()
features = {
key: value for key, value in self._meta.features.items() if key in self.hf_features
key: value
for key, value in self._meta.features.items()
if key in self.hf_features
} # remove video keys
validate_frame(frame, features)
@@ -101,12 +137,20 @@ class AgiBotDatasetWriter(DatasetWriter):
self.episode_buffer["size"] += 1
def save_episode(
self, videos: dict, action_config: list, episode_data: dict | None = None, parallel_encoding: bool = True
self,
videos: dict,
action_config: list,
episode_data: dict | None = None,
parallel_encoding: bool = True,
) -> None:
"""Save the current episode in self.episode_buffer to disk."""
episode_buffer = episode_data if episode_data is not None else self.episode_buffer
episode_buffer = (
episode_data if episode_data is not None else self.episode_buffer
)
validate_episode_buffer(episode_buffer, self._meta.total_episodes, self._meta.features)
validate_episode_buffer(
episode_buffer, self._meta.total_episodes, self._meta.features
)
# size and task are special cases that won't be added to hf_dataset
episode_length = episode_buffer.pop("size")
@@ -114,19 +158,25 @@ class AgiBotDatasetWriter(DatasetWriter):
episode_tasks = list(set(tasks))
episode_index = episode_buffer["episode_index"]
episode_buffer["index"] = np.arange(self._meta.total_frames, self._meta.total_frames + episode_length)
episode_buffer["index"] = np.arange(
self._meta.total_frames, self._meta.total_frames + episode_length
)
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
# Update tasks and task indices with new tasks if any
self._meta.save_episode_tasks(episode_tasks)
# Given tasks in natural language, find their corresponding task indices
episode_buffer["task_index"] = np.array([self._meta.get_task_index(task) for task in tasks])
episode_buffer["task_index"] = np.array(
[self._meta.get_task_index(task) for task in tasks]
)
for key, ft in self._meta.features.items():
# index, episode_index, task_index are already processed above, and image and video
# are processed separately by storing image path and frame info as meta data
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["video"]:
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in [
"video"
]:
continue
episode_buffer[key] = np.stack(episode_buffer[key]).squeeze()
@@ -145,7 +195,9 @@ class AgiBotDatasetWriter(DatasetWriter):
ep_metadata.update(self._save_episode_video(video_key, episode_index))
ep_metadata.update({"action_config": action_config})
self._meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
self._meta.save_episode(
episode_index, episode_length, episode_tasks, ep_stats, ep_metadata
)
if has_video_keys and use_batched_encoding:
self._episodes_since_last_encoding += 1
@@ -158,13 +210,18 @@ class AgiBotDatasetWriter(DatasetWriter):
if not episode_data:
self.clear_episode_buffer(delete_images=len(self._meta.image_keys) > 0)
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
def _encode_temporary_episode_video(
self, video_key: str, episode_index: int
) -> Path:
"""
Use ffmpeg to convert frames stored as png into mp4 videos.
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
since video encoding with ffmpeg is already using multithreading.
"""
temp_path = Path(tempfile.mkdtemp(dir=self._root)) / f"{video_key}_{episode_index:03d}.mp4"
temp_path = (
Path(tempfile.mkdtemp(dir=self._root))
/ f"{video_key}_{episode_index:03d}.mp4"
)
shutil.copy(self.current_videos[video_key], temp_path)
return temp_path
@@ -188,153 +245,300 @@ class AgiBotDataset(LeRobotDataset):
root=params["root"],
use_videos=params["use_videos"],
metadata_buffer_size=params["metadata_buffer_size"],
video_files_size_in_mb=params["video_files_size_in_mb"],
data_files_size_in_mb=params["data_files_size_in_mb"],
)
obj.writer: AgiBotDatasetWriter = AgiBotDatasetWriter(
meta=obj.meta,
root=obj.root,
vcodec=obj._vcodec,
encoder_threads=obj._encoder_threads,
batch_encoding_size=obj._batch_encoding_size,
camera_encoder=obj.writer._camera_encoder,
encoder_threads=obj.writer._encoder_threads,
batch_encoding_size=obj.writer._batch_encoding_size,
streaming_encoder=obj.writer._streaming_encoder,
)
return obj
def save_episode(
self, videos: dict, action_config: list, episode_data: dict | None = None, parallel_encoding: bool = True
self,
videos: dict,
action_config: list,
episode_data: dict | None = None,
parallel_encoding: bool = True,
) -> None:
self._require_writer("save_episode")
self.writer.save_episode(videos, action_config, episode_data, parallel_encoding)
def get_all_tasks(src_path: Path, output_path: Path):
json_files = src_path.glob("task_info/*.json")
for json_file in json_files:
local_dir = output_path / "agibotworld" / json_file.stem
yield (json_file, local_dir.resolve())
class AgiBotAdapter(BaseAdapter):
dataset_type = "agibot"
fps = 30
robot_type = "a2d"
tags = ("agibot-world", "a2d")
# Build the adapter: run the base initializer, validate the chunk size,
# normalize the source path, and derive the feature schema for `eef_type`.
# Raises ValueError when `episodes_per_task` is smaller than 1.
def __init__(
self,
src_path: Path,
output_path: Path,
eef_type: str,
task_ids: Sequence[str],
save_depth: bool,
episodes_per_task: int,
):
super().__init__(output_path)
# Reject non-positive chunk sizes; the message names the CLI flag.
if episodes_per_task < 1:
raise ValueError("--episodes-per-task must be >= 1")
# Expand "~" and make the source root absolute.
self.src_path = src_path.expanduser().resolve()
self.eef_type = eef_type
# Stored as a set so membership tests are cheap; an empty set means no
# explicit task filter was requested.
self.task_ids = set(task_ids)
self.save_depth = save_depth
self.episodes_per_task = episodes_per_task
# Per-end-effector configuration and the task ids registered for it.
self.agibot_world_config = AgiBotWorld_TASK_TYPE[eef_type]["task_config"]
self.type_task_ids = set(AgiBotWorld_TASK_TYPE[eef_type]["task_ids"])
self.features = generate_features_from_config(self.agibot_world_config)
# Without depth export, drop the head depth feature (no error if absent).
if not save_depth:
self.features.pop("observation.images.head_depth", None)
def save_as_lerobot_dataset(agibot_world_config, task: tuple[Path, Path], save_depth):
json_file, local_dir = task
print(f"processing {json_file.stem}, saving to {local_dir}")
src_path = json_file.parent.parent
task_info = get_task_info(json_file)
task_name = task_info[0]["task_name"]
task_init_scene = task_info[0]["init_scene_text"]
task_instruction = f"{task_name} | {task_init_scene}"
task_id = json_file.stem.split("_")[-1]
task_info = {episode["episode_id"]: episode for episode in task_info}
def load_tasks(self) -> list[ConversionTask]:
tasks = []
for json_file in sorted(self.src_path.glob("task_info/*.json")):
if not self._include_task(json_file.stem):
continue
_, _, _, task_info = self._load_task_context(json_file)
episode_ids = sorted(task_info)
features = generate_features_from_config(agibot_world_config)
for chunk_index, chunk_episode_ids in enumerate(
self._chunk_episode_ids(episode_ids)
):
task_name = self._format_task_name(
json_file.stem, chunk_index, chunk_episode_ids
)
tasks.append(
ConversionTask(
input_path=json_file.resolve(),
output_path=(
self.temp_output_path / "agibotworld" / task_name
).resolve(),
local_repo_id=task_name,
metadata={"episode_ids": tuple(chunk_episode_ids)},
)
)
return tasks
if local_dir.exists():
shutil.rmtree(local_dir)
if not save_depth:
features.pop("observation.images.head_depth")
dataset: AgiBotDataset = AgiBotDataset.create(
repo_id=json_file.stem,
root=local_dir,
fps=30,
robot_type="a2d",
features=features,
)
all_subdir = [f.as_posix() for f in src_path.glob(f"observations/{task_id}/*") if f.is_dir()]
all_subdir_eids = sorted([int(Path(path).name) for path in all_subdir])
for eid in all_subdir_eids:
if eid not in task_info:
print(f"{json_file.stem}, episode_{eid} not in task_info.json, skipping...")
continue
action_config = task_info[eid]["label_info"]["action_config"]
raw_dataset = load_local_dataset(
eid,
src_path=src_path,
task_id=task_id,
save_depth=save_depth,
AgiBotWorld_CONFIG=agibot_world_config,
def load_subset(self, task: ConversionTask):
json_file = task.input_path
print(f"processing {json_file.stem}, saving to {task.output_path}")
src_path, task_id, task_instruction, task_info = self._load_task_context(
json_file
)
_, frames, videos = raw_dataset
if not all([video_path.exists() for video_path in videos.values()]):
print(f"{json_file.stem}, episode_{eid}: some of the videos does not exist, skipping...")
continue
for frame_data in frames:
frame_data["task"] = task_instruction
task_episode_ids = task.metadata.get("episode_ids")
if task_episode_ids is None:
task_episode_ids = get_episode_ids(src_path, task_id)
for eid in task_episode_ids:
if not self._is_convertible_episode(
json_file.stem, src_path, task_id, eid, task_info
):
continue
action_config = task_info[eid]["label_info"]["action_config"]
raw_dataset = load_local_dataset(
eid,
src_path=src_path,
task_id=task_id,
save_depth=self.save_depth,
AgiBotWorld_CONFIG=self.agibot_world_config,
)
if raw_dataset is None:
continue
_, frames, videos = raw_dataset
for frame_data in frames:
frame_data["task"] = task_instruction
yield {
"episode_id": eid,
"frames": frames,
"videos": videos,
"action_config": action_config,
}
def create_dataset(self, task: ConversionTask) -> AgiBotDataset:
    """Create the AgiBot dataset that will receive the episodes of ``task``.

    The repo id and root directory come from the task; fps, robot type and
    feature schema come from the adapter.
    """
    creation_kwargs = {
        "repo_id": task.local_repo_id,
        "root": task.output_path,
        "fps": self.fps,
        "robot_type": self.robot_type,
        "features": self.features,
    }
    return AgiBotDataset.create(**creation_kwargs)
def save_episode(
self,
dataset: AgiBotDataset,
episode_data: dict[str, Any],
task: ConversionTask,
) -> bool:
for frame_data in episode_data["frames"]:
dataset.add_frame(frame_data)
try:
dataset.save_episode(videos=videos, action_config=action_config)
dataset.save_episode(
videos=episode_data["videos"],
action_config=episode_data["action_config"],
)
except Exception as e:
print(f"{json_file.stem}, episode_{eid}: there are some corrupted mp4s\nException details: {str(e)}")
print(
f"{task.input_path.stem}, episode_{episode_data['episode_id']}: "
f"there are some corrupted mp4s\nException details: {str(e)}"
)
dataset.clear_episode_buffer(delete_images=False)
continue
return False
return True
print(f"process done for {json_file.stem}, episode_id {eid}, len {len(frames)}")
def get_episode_length(self, episode_data: dict[str, Any]) -> int:
    """Return the number of frames stored under ``episode_data["frames"]``."""
    frames = episode_data["frames"]
    return len(frames)
dataset.finalize()
def _chunk_episode_ids(self, episode_ids: list[int]):
for start in range(0, len(episode_ids), self.episodes_per_task):
yield episode_ids[start : start + self.episodes_per_task]
def _format_task_name(
self, task_name: str, chunk_index: int, episode_ids: Sequence[int]
) -> str:
if len(episode_ids) == 1:
return f"{task_name}_episode_{episode_ids[0]}"
return (
f"{task_name}_chunk_{chunk_index:06d}_episodes_"
f"{episode_ids[0]}_{episode_ids[-1]}"
)
def _load_task_context(
    self, json_file: Path
) -> tuple[Path, str, str, dict[int, dict[str, Any]]]:
    """Load everything needed to convert the task described by ``json_file``.

    Returns a 4-tuple of:
      * the source root (two directory levels above ``json_file``),
      * the task id derived from the file name,
      * the instruction string ``"<task_name> | <init_scene_text>"`` taken
        from the first episode entry,
      * a mapping from ``episode_id`` to that episode's entry.
    """
    source_root = json_file.parent.parent
    task_id = get_task_id(json_file)
    episodes = get_task_info(json_file)

    first_episode = episodes[0]
    name = first_episode["task_name"]
    init_scene = first_episode["init_scene_text"]
    instruction = f"{name} | {init_scene}"

    episodes_by_id = {}
    for entry in episodes:
        episodes_by_id[entry["episode_id"]] = entry
    return source_root, task_id, instruction, episodes_by_id
def _is_convertible_episode(
self,
task_name: str,
src_path: Path,
task_id: str,
episode_id: int,
task_info: dict[int, dict[str, Any]],
) -> bool:
if episode_id not in task_info:
print(
f"{task_name}, episode_{episode_id} not in task_info.json, skipping..."
)
return False
if not has_episode_videos(
src_path, task_id, episode_id, self.agibot_world_config
):
print(
f"{task_name}, episode_{episode_id}: "
"some of the videos does not exist, skipping..."
)
return False
return True
def _include_task(self, task_id: str) -> bool:
if self.task_ids and task_id not in self.task_ids:
return False
if self.eef_type == "gripper":
remaining_ids = set(AgiBotWorld_TASK_TYPE["dexhand"]["task_ids"])
remaining_ids.update(AgiBotWorld_TASK_TYPE["tactile"]["task_ids"])
return task_id not in remaining_ids
return task_id in self.type_task_ids
def main(
src_path: str,
output_path: str,
src_path: Path,
output_path: Path,
eef_type: str,
task_ids: list,
task_ids: list[str],
executor: str,
cpus_per_task: int,
tasks_per_job: int,
workers: int,
save_depth: bool,
episodes_per_task: int,
resume_dir: Path | None = None,
debug: bool = False,
repo_id: str | None = None,
push_to_hub: bool = False,
):
tasks = get_all_tasks(src_path, output_path)
agibot_world_config, type_task_ids = (
AgiBotWorld_TASK_TYPE[eef_type]["task_config"],
AgiBotWorld_TASK_TYPE[eef_type]["task_ids"],
adapter = AgiBotAdapter(
src_path=src_path,
output_path=output_path,
eef_type=eef_type,
task_ids=task_ids,
save_depth=save_depth,
episodes_per_task=episodes_per_task,
)
if eef_type == "gripper":
remaining_ids = AgiBotWorld_TASK_TYPE["dexhand"]["task_ids"] + AgiBotWorld_TASK_TYPE["tactile"]["task_ids"]
tasks = filter(lambda task: task[0].stem not in remaining_ids, tasks)
else:
tasks = filter(lambda task: task[0].stem in type_task_ids, tasks)
if task_ids:
tasks = filter(lambda task: task[0].stem in task_ids, tasks)
if debug:
save_as_lerobot_dataset(agibot_world_config, next(tasks), save_depth)
else:
runtime_env = RuntimeEnv(
env_vars={"HDF5_USE_FILE_LOCKING": "FALSE", "HF_DATASETS_DISABLE_PROGRESS_BARS": "TRUE"}
)
ray.init(runtime_env=runtime_env)
resources = ray.available_resources()
cpus = int(resources["CPU"])
print(f"Available CPUs: {cpus}, num_cpus_per_task: {cpus_per_task}")
remote_task = ray.remote(save_as_lerobot_dataset).options(num_cpus=cpus_per_task)
futures = []
for task in tasks:
futures.append((task[0].stem, remote_task.remote(agibot_world_config, task, save_depth)))
for task, future in futures:
try:
ray.get(future)
except Exception as e:
print(f"Exception occurred for {task}")
with open("output.txt", "a") as f:
f.write(f"{task}, exception details: {str(e)}\n")
run_converter(
adapter=adapter,
executor=executor,
cpus_per_task=cpus_per_task,
tasks_per_job=tasks_per_job,
workers=workers,
resume_dir=resume_dir,
debug=debug,
local_repo_id=repo_id,
hub_repo_id=repo_id,
push_to_hub=push_to_hub,
extra_tags=[eef_type],
)
if __name__ == "__main__":
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--src-path", type=Path, required=True)
parser.add_argument("--output-path", type=Path, required=True)
parser.add_argument("--eef-type", type=str, choices=["gripper", "dexhand", "tactile"], default="gripper")
parser.add_argument("--task-ids", type=str, nargs="+", help="task_327 task_351 ...", default=[])
parser.add_argument("--cpus-per-task", type=int, default=3)
parser.add_argument(
"--eef-type",
type=str,
choices=["gripper", "dexhand", "tactile"],
default="gripper",
)
parser.add_argument(
"--task-ids", type=str, nargs="+", help="task_327 task_351 ...", default=[]
)
parser.add_argument("--executor", type=str, choices=["local", "ray"], default="ray")
parser.add_argument("--cpus-per-task", type=int, default=1)
parser.add_argument(
"--tasks-per-job",
type=int,
default=1,
help="number of concurrent tasks per job, only used for ray",
)
parser.add_argument(
"--episodes-per-task",
type=int,
default=10,
help="number of AgiBot episodes grouped into one conversion task",
)
parser.add_argument(
"--workers", type=int, default=-1, help="number of concurrent jobs to run"
)
parser.add_argument("--resume-dir", type=Path, help="logs directory to resume")
parser.add_argument("--save-depth", action="store_true")
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
parser.add_argument("--repo-id", type=str, help="required when push-to-hub is True")
parser.add_argument("--push-to-hub", action="store_true", help="upload to hub")
return parser.parse_args()
main(**vars(args))
def cli():
    """Command-line entry point: parse arguments and run the converter."""
    cli_args = parse_args()
    # ``main`` is looked up through the package-qualified module instead of
    # being called directly.
    # NOTE(review): presumably so it resolves from an importable module rather
    # than ``__main__`` when the script is executed directly — confirm.
    converter_module = importlib.import_module("agibot2lerobot.agibot_h5")
    converter_module.main(**vars(cli_args))
if __name__ == "__main__":
cli()
+70 -18
View File
@@ -6,22 +6,67 @@ import numpy as np
from PIL import Image
def get_task_info(task_json_path: str) -> dict:
def get_task_info(task_json_path: str | Path) -> list[dict]:
with open(task_json_path, "r") as f:
task_info: list = json.load(f)
task_info.sort(key=lambda episode: episode["episode_id"])
return task_info
def get_task_id(task_json_path: str | Path) -> str:
return Path(task_json_path).stem.split("_")[-1]
def get_episode_ids(src_path: str | Path, task_id: str | int) -> list[int]:
observations_dir = Path(src_path) / "observations" / str(task_id)
return sorted(
int(path.name)
for path in observations_dir.glob("*")
if path.is_dir() and path.name.isdigit()
)
def get_episode_videos(
src_path: str | Path,
task_id: str | int,
episode_id: int,
agibot_world_config: dict,
) -> dict[str, Path]:
ob_dir = Path(src_path) / f"observations/{task_id}/{episode_id}"
return {
f"observation.images.{key}": ob_dir / "videos" / f"{key}_color.mp4"
if "sensor" not in key
else ob_dir / "tactile" / f"{key}.mp4" # HACK: handle tactile videos
for key in agibot_world_config["images"]
if "depth" not in key
}
def has_episode_videos(
    src_path: str | Path,
    task_id: str | int,
    episode_id: int,
    agibot_world_config: dict,
) -> bool:
    """Return True only if every video expected for the episode exists on disk."""
    expected = get_episode_videos(src_path, task_id, episode_id, agibot_world_config)
    for video_path in expected.values():
        if not video_path.exists():
            return False
    return True
def load_depths(root_dir: str, camera_name: str):
cam_path = Path(root_dir)
all_imgs = sorted(list(cam_path.glob(f"{camera_name}*")))
return [np.array(Image.open(f)).astype(np.float32)[:, :, None] / 1000 for f in all_imgs]
return [
np.array(Image.open(f)).astype(np.float32)[:, :, None] / 1000 for f in all_imgs
]
def load_local_dataset(
episode_id: int, src_path: str, task_id: int, save_depth: bool, AgiBotWorld_CONFIG: dict
) -> tuple[list, dict]:
episode_id: int,
src_path: str | Path,
task_id: str | int,
save_depth: bool,
AgiBotWorld_CONFIG: dict,
) -> tuple[int, list[dict], dict[str, Path]] | None:
"""Load local dataset and return a dict with observations and actions"""
ob_dir = Path(src_path) / f"observations/{task_id}/{episode_id}"
proprio_dir = Path(src_path) / f"proprio_stats/{task_id}/{episode_id}"
@@ -30,9 +75,13 @@ def load_local_dataset(
action = {}
with h5py.File(proprio_dir / "proprio_stats.h5", "r") as f:
for key in AgiBotWorld_CONFIG["states"]:
state[f"observation.states.{key}"] = np.array(f["state/" + key.replace(".", "/")], dtype=np.float32)
state[f"observation.states.{key}"] = np.array(
f["state/" + key.replace(".", "/")], dtype=np.float32
)
for key in AgiBotWorld_CONFIG["actions"]:
action[f"actions.{key}"] = np.array(f["action/" + key.replace(".", "/")], dtype=np.float32)
action[f"actions.{key}"] = np.array(
f["action/" + key.replace(".", "/")], dtype=np.float32
)
# HACK: agibot team forgot to pad or filter some of the values
num_frames = len(next(iter(state.values())))
@@ -42,7 +91,10 @@ def load_local_dataset(
elif len(action_value) < num_frames:
state_key = action_key.replace("actions", "state").replace(".", "/")
new_action_value = np.array(f[state_key], dtype=np.float32).copy()
action_index_key = "/".join(list(action_key.replace("actions", "action").split(".")[:-1]) + ["index"])
action_index_key = "/".join(
list(action_key.replace("actions", "action").split(".")[:-1])
+ ["index"]
)
action_index = np.array(f[action_index_key])
# agibot lost end index, replace it with joint
if not action_index.size:
@@ -52,11 +104,13 @@ def load_local_dataset(
action[action_key] = new_action_value
elif len(action_value) > num_frames:
print("corrupt data, skipping")
return episode_id, [], {"dummy_video": Path("/path/to/no_exist")}
return None
if save_depth:
depth_imgs = load_depths(ob_dir / "depth", "head_depth")
assert num_frames == len(depth_imgs), "Number of images and states are not equal"
assert num_frames == len(depth_imgs), (
"Number of images and states are not equal"
)
state_key_prefix_len = len("observation.states.")
action_key_prefix_len = len("actions.")
@@ -68,7 +122,9 @@ def load_local_dataset(
if value.size
else np.zeros(
AgiBotWorld_CONFIG["states"][key[state_key_prefix_len:]]["shape"],
dtype=AgiBotWorld_CONFIG["states"][key[state_key_prefix_len:]]["dtype"],
dtype=AgiBotWorld_CONFIG["states"][key[state_key_prefix_len:]][
"dtype"
],
)
for key, value in state.items()
},
@@ -77,7 +133,9 @@ def load_local_dataset(
if value.size
else np.zeros(
AgiBotWorld_CONFIG["actions"][key[action_key_prefix_len:]]["shape"],
dtype=AgiBotWorld_CONFIG["actions"][key[action_key_prefix_len:]]["dtype"],
dtype=AgiBotWorld_CONFIG["actions"][key[action_key_prefix_len:]][
"dtype"
],
)
for key, value in action.items()
},
@@ -85,11 +143,5 @@ def load_local_dataset(
for i in range(num_frames)
]
videos = {
f"observation.images.{key}": ob_dir / "videos" / f"{key}_color.mp4"
if "sensor" not in key
else ob_dir / "tactile" / f"{key}.mp4" # HACK: handle tactile videos
for key in AgiBotWorld_CONFIG["images"]
if "depth" not in key
}
videos = get_episode_videos(src_path, task_id, episode_id, AgiBotWorld_CONFIG)
return episode_id, frames, videos