mirror of
https://github.com/Tavish9/any4lerobot.git
synced 2026-05-27 19:49:41 +00:00
⬆️ upgrade lerobot dataset to v3.0 (#60)
* update openx * update libero * update agibot * update robomind * update readme * update return type
This commit is contained in:
+72
-201
@@ -1,10 +1,7 @@
|
||||
import argparse
|
||||
import gc
|
||||
import shutil
|
||||
from concurrent.futures import (
|
||||
ThreadPoolExecutor,
|
||||
as_completed,
|
||||
)
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@@ -13,105 +10,13 @@ import torch
|
||||
from agibot_utils.agibot_utils import get_task_info, load_local_dataset
|
||||
from agibot_utils.config import AgiBotWorld_TASK_TYPE
|
||||
from agibot_utils.lerobot_utils import compute_episode_stats, generate_features_from_config
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
check_timestamps_sync,
|
||||
get_episode_data_index,
|
||||
validate_episode_buffer,
|
||||
validate_frame,
|
||||
write_episode,
|
||||
write_episode_stats,
|
||||
write_info,
|
||||
)
|
||||
from lerobot.datasets.video_utils import get_safe_default_codec
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import validate_episode_buffer, validate_frame
|
||||
from ray.runtime_env import RuntimeEnv
|
||||
|
||||
|
||||
class AgiBotDatasetMetadata(LeRobotDatasetMetadata):
|
||||
def save_episode(
|
||||
self,
|
||||
episode_index: int,
|
||||
episode_length: int,
|
||||
episode_tasks: list[str],
|
||||
episode_stats: dict[str, dict],
|
||||
action_config: list[dict],
|
||||
) -> None:
|
||||
self.info["total_episodes"] += 1
|
||||
self.info["total_frames"] += episode_length
|
||||
|
||||
chunk = self.get_episode_chunk(episode_index)
|
||||
if chunk >= self.total_chunks:
|
||||
self.info["total_chunks"] += 1
|
||||
|
||||
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
||||
self.info["total_videos"] += len(self.video_keys)
|
||||
if len(self.video_keys) > 0:
|
||||
self.update_video_info()
|
||||
|
||||
write_info(self.info, self.root)
|
||||
|
||||
episode_dict = {
|
||||
"episode_index": episode_index,
|
||||
"tasks": episode_tasks,
|
||||
"length": episode_length,
|
||||
"action_config": action_config,
|
||||
}
|
||||
self.episodes[episode_index] = episode_dict
|
||||
write_episode(episode_dict, self.root)
|
||||
|
||||
self.episodes_stats[episode_index] = episode_stats
|
||||
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
|
||||
write_episode_stats(episode_index, episode_stats, self.root)
|
||||
|
||||
|
||||
class AgiBotDataset(LeRobotDataset):
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
repo_id: str,
|
||||
fps: int,
|
||||
features: dict,
|
||||
root: str | Path | None = None,
|
||||
robot_type: str | None = None,
|
||||
use_videos: bool = True,
|
||||
tolerance_s: float = 1e-4,
|
||||
image_writer_processes: int = 0,
|
||||
image_writer_threads: int = 0,
|
||||
video_backend: str | None = None,
|
||||
) -> "LeRobotDataset":
|
||||
"""Create a LeRobot Dataset from scratch in order to record data."""
|
||||
obj = cls.__new__(cls)
|
||||
obj.meta = AgiBotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
fps=fps,
|
||||
robot_type=robot_type,
|
||||
features=features,
|
||||
root=root,
|
||||
use_videos=use_videos,
|
||||
)
|
||||
obj.repo_id = obj.meta.repo_id
|
||||
obj.root = obj.meta.root
|
||||
obj.revision = None
|
||||
obj.tolerance_s = tolerance_s
|
||||
obj.image_writer = None
|
||||
|
||||
if image_writer_processes or image_writer_threads:
|
||||
obj.start_image_writer(image_writer_processes, image_writer_threads)
|
||||
|
||||
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
|
||||
obj.episode_buffer = obj.create_episode_buffer()
|
||||
|
||||
obj.episodes = None
|
||||
obj.hf_dataset = obj.create_hf_dataset()
|
||||
obj.image_transforms = None
|
||||
obj.delta_timestamps = None
|
||||
obj.delta_indices = None
|
||||
obj.episode_data_index = None
|
||||
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||
return obj
|
||||
|
||||
def add_frame(self, frame: dict, task: str, timestamp: float | None = None) -> None:
|
||||
def add_frame(self, frame: dict) -> None:
|
||||
"""
|
||||
This function only adds the frame to the episode_buffer. Apart from images — which are written in a
|
||||
temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
|
||||
@@ -130,11 +35,10 @@ class AgiBotDataset(LeRobotDataset):
|
||||
|
||||
# Automatically add frame_index and timestamp to episode buffer
|
||||
frame_index = self.episode_buffer["size"]
|
||||
if timestamp is None:
|
||||
timestamp = frame_index / self.fps
|
||||
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
|
||||
self.episode_buffer["frame_index"].append(frame_index)
|
||||
self.episode_buffer["timestamp"].append(timestamp)
|
||||
self.episode_buffer["task"].append(task)
|
||||
self.episode_buffer["task"].append(frame.pop("task")) # Remove task from frame after processing
|
||||
|
||||
# Add frame features to episode_buffer
|
||||
for key, value in frame.items():
|
||||
@@ -156,8 +60,7 @@ class AgiBotDataset(LeRobotDataset):
|
||||
save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
|
||||
None.
|
||||
"""
|
||||
if not episode_data:
|
||||
episode_buffer = self.episode_buffer
|
||||
episode_buffer = episode_data if episode_data is not None else self.episode_buffer
|
||||
|
||||
validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)
|
||||
|
||||
@@ -170,11 +73,8 @@ class AgiBotDataset(LeRobotDataset):
|
||||
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
|
||||
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
|
||||
|
||||
# Add new tasks to the tasks dictionary
|
||||
for task in episode_tasks:
|
||||
task_index = self.meta.get_task_index(task)
|
||||
if task_index is None:
|
||||
self.meta.add_task(task)
|
||||
# Update tasks and task indices with new tasks if any
|
||||
self.meta.save_episode_tasks(episode_tasks)
|
||||
|
||||
# Given tasks in natural language, find their corresponding task indices
|
||||
episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
|
||||
@@ -187,31 +87,46 @@ class AgiBotDataset(LeRobotDataset):
|
||||
episode_buffer[key] = np.stack(episode_buffer[key]).squeeze()
|
||||
|
||||
for key in self.meta.video_keys:
|
||||
video_path = self.root / self.meta.get_video_file_path(episode_index, key)
|
||||
episode_buffer[key] = str(video_path) # PosixPath -> str
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copyfile(videos[key], video_path)
|
||||
episode_buffer[key] = str(videos[key]) # PosixPath -> str
|
||||
|
||||
ep_stats = compute_episode_stats(episode_buffer, self.features)
|
||||
|
||||
self._save_episode_table(episode_buffer, episode_index)
|
||||
ep_metadata = self._save_episode_data(episode_buffer)
|
||||
has_video_keys = len(self.meta.video_keys) > 0
|
||||
use_batched_encoding = self.batch_encoding_size > 1
|
||||
|
||||
self.current_videos = videos
|
||||
if has_video_keys and not use_batched_encoding:
|
||||
for video_key in self.meta.video_keys:
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index))
|
||||
|
||||
# `meta.save_episode` be executed after encoding the videos
|
||||
# add action_config to current episode
|
||||
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, action_config)
|
||||
ep_metadata.update({"action_config": action_config})
|
||||
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
|
||||
|
||||
ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
|
||||
ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
|
||||
check_timestamps_sync(
|
||||
episode_buffer["timestamp"],
|
||||
episode_buffer["episode_index"],
|
||||
ep_data_index_np,
|
||||
self.fps,
|
||||
self.tolerance_s,
|
||||
)
|
||||
if has_video_keys and use_batched_encoding:
|
||||
# Check if we should trigger batch encoding
|
||||
self.episodes_since_last_encoding += 1
|
||||
if self.episodes_since_last_encoding == self.batch_encoding_size:
|
||||
start_ep = self.num_episodes - self.batch_encoding_size
|
||||
end_ep = self.num_episodes
|
||||
self._batch_save_episode_video(start_ep, end_ep)
|
||||
self.episodes_since_last_encoding = 0
|
||||
|
||||
if not episode_data: # Reset the buffer
|
||||
self.episode_buffer = self.create_episode_buffer()
|
||||
if not episode_data:
|
||||
# Reset episode buffer and clean up temporary images (if not already deleted during video encoding)
|
||||
self.clear_episode_buffer(delete_images=len(self.meta.image_keys) > 0)
|
||||
|
||||
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
|
||||
"""
|
||||
Use ffmpeg to convert frames stored as png into mp4 videos.
|
||||
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
since video encoding with ffmpeg is already using multithreading.
|
||||
"""
|
||||
temp_path = Path(tempfile.mkdtemp(dir=self.root)) / f"{video_key}_{episode_index:03d}.mp4"
|
||||
shutil.copy(self.current_videos[video_key], temp_path)
|
||||
return temp_path
|
||||
|
||||
|
||||
def get_all_tasks(src_path: Path, output_path: Path):
|
||||
@@ -221,7 +136,7 @@ def get_all_tasks(src_path: Path, output_path: Path):
|
||||
yield (json_file, local_dir.resolve())
|
||||
|
||||
|
||||
def save_as_lerobot_dataset(agibot_world_config, task: tuple[Path, Path], num_threads, save_depth, debug):
|
||||
def save_as_lerobot_dataset(agibot_world_config, task: tuple[Path, Path], save_depth):
|
||||
json_file, local_dir = task
|
||||
print(f"processing {json_file.stem}, saving to {local_dir}")
|
||||
src_path = json_file.parent.parent
|
||||
@@ -252,70 +167,34 @@ def save_as_lerobot_dataset(agibot_world_config, task: tuple[Path, Path], num_th
|
||||
|
||||
all_subdir_eids = sorted([int(Path(path).name) for path in all_subdir])
|
||||
|
||||
if debug or not save_depth:
|
||||
for eid in all_subdir_eids:
|
||||
if eid not in task_info:
|
||||
print(f"{json_file.stem}, episode_{eid} not in task_info.json, skipping...")
|
||||
continue
|
||||
action_config = task_info[eid]["label_info"]["action_config"]
|
||||
raw_dataset = load_local_dataset(
|
||||
eid,
|
||||
src_path=src_path,
|
||||
task_id=task_id,
|
||||
save_depth=save_depth,
|
||||
AgiBotWorld_CONFIG=agibot_world_config,
|
||||
)
|
||||
_, frames, videos = raw_dataset
|
||||
if not all([video_path.exists() for video_path in videos.values()]):
|
||||
print(f"{json_file.stem}, episode_{eid}: some of the videos does not exist, skipping...")
|
||||
continue
|
||||
for eid in all_subdir_eids:
|
||||
if eid not in task_info:
|
||||
print(f"{json_file.stem}, episode_{eid} not in task_info.json, skipping...")
|
||||
continue
|
||||
action_config = task_info[eid]["label_info"]["action_config"]
|
||||
raw_dataset = load_local_dataset(
|
||||
eid,
|
||||
src_path=src_path,
|
||||
task_id=task_id,
|
||||
save_depth=save_depth,
|
||||
AgiBotWorld_CONFIG=agibot_world_config,
|
||||
)
|
||||
_, frames, videos = raw_dataset
|
||||
if not all([video_path.exists() for video_path in videos.values()]):
|
||||
print(f"{json_file.stem}, episode_{eid}: some of the videos does not exist, skipping...")
|
||||
continue
|
||||
|
||||
for frame_data in frames:
|
||||
dataset.add_frame(frame_data, task_instruction)
|
||||
try:
|
||||
dataset.save_episode(videos=videos, action_config=action_config)
|
||||
except Exception as e:
|
||||
print(f"{json_file.stem}, episode_{eid}: there are some corrupted mp4s\nException details: {str(e)}")
|
||||
dataset.episode_buffer = None
|
||||
continue
|
||||
gc.collect()
|
||||
print(f"process done for {json_file.stem}, episode_id {eid}, len {len(frames)}")
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = []
|
||||
for eid in all_subdir_eids:
|
||||
if eid not in task_info:
|
||||
print(f"{json_file.stem}, episode_{eid} not in task_info.json, skipping...")
|
||||
continue
|
||||
futures.append(
|
||||
executor.submit(
|
||||
load_local_dataset,
|
||||
eid,
|
||||
src_path=src_path,
|
||||
task_id=task_id,
|
||||
save_depth=save_depth,
|
||||
AgiBotWorld_CONFIG=agibot_world_config,
|
||||
)
|
||||
)
|
||||
|
||||
for raw_dataset in as_completed(futures):
|
||||
eid, frames, videos = raw_dataset.result()
|
||||
if not all([video_path.exists() for video_path in videos.values()]):
|
||||
print(f"{json_file.stem}, episode_{eid}: some of the videos does not exist, skipping...")
|
||||
continue
|
||||
action_config = task_info[eid]["label_info"]["action_config"]
|
||||
for frame_data in frames:
|
||||
dataset.add_frame(frame_data, task_instruction)
|
||||
try:
|
||||
dataset.save_episode(videos=videos, action_config=action_config)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"{json_file.stem}, episode_{eid}: there are some corrupted mp4s\nException details: {str(e)}"
|
||||
)
|
||||
dataset.episode_buffer = None
|
||||
continue
|
||||
gc.collect()
|
||||
print(f"process done for {json_file.stem}, episode_id {eid}, len {len(frames)}")
|
||||
for frame_data in frames:
|
||||
frame_data["task"] = task_instruction
|
||||
dataset.add_frame(frame_data)
|
||||
try:
|
||||
dataset.save_episode(videos=videos, action_config=action_config)
|
||||
except Exception as e:
|
||||
print(f"{json_file.stem}, episode_{eid}: there are some corrupted mp4s\nException details: {str(e)}")
|
||||
dataset.episode_buffer = None
|
||||
continue
|
||||
gc.collect()
|
||||
print(f"process done for {json_file.stem}, episode_id {eid}, len {len(frames)}")
|
||||
|
||||
|
||||
def main(
|
||||
@@ -324,7 +203,6 @@ def main(
|
||||
eef_type: str,
|
||||
task_ids: list,
|
||||
cpus_per_task: int,
|
||||
num_threads_per_task: int,
|
||||
save_depth: bool,
|
||||
debug: bool = False,
|
||||
):
|
||||
@@ -345,14 +223,10 @@ def main(
|
||||
tasks = filter(lambda task: task[0].stem in task_ids, tasks)
|
||||
|
||||
if debug:
|
||||
save_as_lerobot_dataset(agibot_world_config, next(tasks), num_threads_per_task, save_depth, debug)
|
||||
save_as_lerobot_dataset(agibot_world_config, next(tasks), save_depth)
|
||||
else:
|
||||
runtime_env = RuntimeEnv(
|
||||
env_vars={
|
||||
"HDF5_USE_FILE_LOCKING": "FALSE",
|
||||
"HF_DATASETS_DISABLE_PROGRESS_BARS": "TRUE",
|
||||
"LD_PRELOAD": str(Path(__file__).resolve().parent / "libtcmalloc.so.4.5.3"),
|
||||
}
|
||||
env_vars={"HDF5_USE_FILE_LOCKING": "FALSE", "HF_DATASETS_DISABLE_PROGRESS_BARS": "TRUE"}
|
||||
)
|
||||
ray.init(runtime_env=runtime_env)
|
||||
resources = ray.available_resources()
|
||||
@@ -363,9 +237,7 @@ def main(
|
||||
remote_task = ray.remote(save_as_lerobot_dataset).options(num_cpus=cpus_per_task)
|
||||
futures = []
|
||||
for task in tasks:
|
||||
futures.append(
|
||||
(task[0].stem, remote_task.remote(agibot_world_config, task, num_threads_per_task, save_depth, debug))
|
||||
)
|
||||
futures.append((task[0].stem, remote_task.remote(agibot_world_config, task, save_depth)))
|
||||
|
||||
for task, future in futures:
|
||||
try:
|
||||
@@ -385,7 +257,6 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--eef-type", type=str, choices=["gripper", "dexhand", "tactile"], default="gripper")
|
||||
parser.add_argument("--task-ids", type=str, nargs="+", help="task_327 task_351 ...", default=[])
|
||||
parser.add_argument("--cpus-per-task", type=int, default=3)
|
||||
parser.add_argument("--num-threads-per-task", type=int, default=2)
|
||||
parser.add_argument("--save-depth", action="store_true")
|
||||
parser.add_argument("--debug", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
Binary file not shown.
Reference in New Issue
Block a user