mirror of
https://github.com/Tavish9/any4lerobot.git
synced 2026-05-20 16:39:42 +00:00
save action_config in each episode
This commit is contained in:
+111
-32
@@ -6,48 +6,113 @@ from concurrent.futures import (
|
|||||||
as_completed,
|
as_completed,
|
||||||
)
|
)
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import ray
|
import ray
|
||||||
import torch
|
import torch
|
||||||
from agibot_utils.agibot_utils import get_task_instruction, load_local_dataset
|
from agibot_utils.agibot_utils import get_task_info, load_local_dataset
|
||||||
from agibot_utils.config import AgiBotWorld_TASK_TYPE
|
from agibot_utils.config import AgiBotWorld_TASK_TYPE
|
||||||
from agibot_utils.lerobot_utils import compute_episode_stats, generate_features_from_config
|
from agibot_utils.lerobot_utils import compute_episode_stats, generate_features_from_config
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.compute_stats import aggregate_stats
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
check_timestamps_sync,
|
check_timestamps_sync,
|
||||||
get_episode_data_index,
|
get_episode_data_index,
|
||||||
validate_episode_buffer,
|
validate_episode_buffer,
|
||||||
validate_frame,
|
validate_frame,
|
||||||
|
write_episode,
|
||||||
|
write_episode_stats,
|
||||||
|
write_info,
|
||||||
)
|
)
|
||||||
|
from lerobot.common.datasets.video_utils import get_safe_default_codec
|
||||||
|
from lerobot.common.robot_devices.robots.utils import Robot
|
||||||
from ray.runtime_env import RuntimeEnv
|
from ray.runtime_env import RuntimeEnv
|
||||||
|
|
||||||
|
|
||||||
class AgiBotDataset(LeRobotDataset):
|
class AgiBotDatasetMetadata(LeRobotDatasetMetadata):
|
||||||
def __init__(
|
def save_episode(
|
||||||
self,
|
self,
|
||||||
|
episode_index: int,
|
||||||
|
episode_length: int,
|
||||||
|
episode_tasks: list[str],
|
||||||
|
episode_stats: dict[str, dict],
|
||||||
|
action_config: list[dict],
|
||||||
|
) -> None:
|
||||||
|
self.info["total_episodes"] += 1
|
||||||
|
self.info["total_frames"] += episode_length
|
||||||
|
|
||||||
|
chunk = self.get_episode_chunk(episode_index)
|
||||||
|
if chunk >= self.total_chunks:
|
||||||
|
self.info["total_chunks"] += 1
|
||||||
|
|
||||||
|
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
||||||
|
self.info["total_videos"] += len(self.video_keys)
|
||||||
|
if len(self.video_keys) > 0:
|
||||||
|
self.update_video_info()
|
||||||
|
|
||||||
|
write_info(self.info, self.root)
|
||||||
|
|
||||||
|
episode_dict = {
|
||||||
|
"episode_index": episode_index,
|
||||||
|
"tasks": episode_tasks,
|
||||||
|
"length": episode_length,
|
||||||
|
"action_config": action_config,
|
||||||
|
}
|
||||||
|
self.episodes[episode_index] = episode_dict
|
||||||
|
write_episode(episode_dict, self.root)
|
||||||
|
|
||||||
|
self.episodes_stats[episode_index] = episode_stats
|
||||||
|
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
|
||||||
|
write_episode_stats(episode_index, episode_stats, self.root)
|
||||||
|
|
||||||
|
|
||||||
|
class AgiBotDataset(LeRobotDataset):
|
||||||
|
@classmethod
|
||||||
|
def create(
|
||||||
|
cls,
|
||||||
repo_id: str,
|
repo_id: str,
|
||||||
|
fps: int,
|
||||||
root: str | Path | None = None,
|
root: str | Path | None = None,
|
||||||
episodes: list[int] | None = None,
|
robot: Robot | None = None,
|
||||||
image_transforms: Callable | None = None,
|
robot_type: str | None = None,
|
||||||
delta_timestamps: dict[list[float]] | None = None,
|
features: dict | None = None,
|
||||||
|
use_videos: bool = True,
|
||||||
tolerance_s: float = 1e-4,
|
tolerance_s: float = 1e-4,
|
||||||
download_videos: bool = True,
|
image_writer_processes: int = 0,
|
||||||
local_files_only: bool = False,
|
image_writer_threads: int = 0,
|
||||||
video_backend: str | None = None,
|
video_backend: str | None = None,
|
||||||
):
|
) -> "LeRobotDataset":
|
||||||
super().__init__(
|
"""Create a LeRobot Dataset from scratch in order to record data."""
|
||||||
|
obj = cls.__new__(cls)
|
||||||
|
obj.meta = AgiBotDatasetMetadata.create(
|
||||||
repo_id=repo_id,
|
repo_id=repo_id,
|
||||||
|
fps=fps,
|
||||||
root=root,
|
root=root,
|
||||||
episodes=episodes,
|
robot=robot,
|
||||||
image_transforms=image_transforms,
|
robot_type=robot_type,
|
||||||
delta_timestamps=delta_timestamps,
|
features=features,
|
||||||
tolerance_s=tolerance_s,
|
use_videos=use_videos,
|
||||||
download_videos=download_videos,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
video_backend=video_backend,
|
|
||||||
)
|
)
|
||||||
|
obj.repo_id = obj.meta.repo_id
|
||||||
|
obj.root = obj.meta.root
|
||||||
|
obj.revision = None
|
||||||
|
obj.tolerance_s = tolerance_s
|
||||||
|
obj.image_writer = None
|
||||||
|
|
||||||
|
if image_writer_processes or image_writer_threads:
|
||||||
|
obj.start_image_writer(image_writer_processes, image_writer_threads)
|
||||||
|
|
||||||
|
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
|
||||||
|
obj.episode_buffer = obj.create_episode_buffer()
|
||||||
|
|
||||||
|
obj.episodes = None
|
||||||
|
obj.hf_dataset = obj.create_hf_dataset()
|
||||||
|
obj.image_transforms = None
|
||||||
|
obj.delta_timestamps = None
|
||||||
|
obj.delta_indices = None
|
||||||
|
obj.episode_data_index = None
|
||||||
|
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||||
|
return obj
|
||||||
|
|
||||||
def add_frame(self, frame: dict) -> None:
|
def add_frame(self, frame: dict) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -88,7 +153,7 @@ class AgiBotDataset(LeRobotDataset):
|
|||||||
|
|
||||||
self.episode_buffer["size"] += 1
|
self.episode_buffer["size"] += 1
|
||||||
|
|
||||||
def save_episode(self, episode_data: dict | None = None, videos: dict | None = None) -> None:
|
def save_episode(self, videos: dict, action_config: list, episode_data: dict | None = None) -> None:
|
||||||
"""
|
"""
|
||||||
This will save to disk the current episode in self.episode_buffer.
|
This will save to disk the current episode in self.episode_buffer.
|
||||||
|
|
||||||
@@ -138,7 +203,8 @@ class AgiBotDataset(LeRobotDataset):
|
|||||||
self._save_episode_table(episode_buffer, episode_index)
|
self._save_episode_table(episode_buffer, episode_index)
|
||||||
|
|
||||||
# `meta.save_episode` be executed after encoding the videos
|
# `meta.save_episode` be executed after encoding the videos
|
||||||
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
|
# add action_config to current episode
|
||||||
|
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, action_config)
|
||||||
|
|
||||||
ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
|
ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
|
||||||
ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
|
ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
|
||||||
@@ -165,8 +231,13 @@ def save_as_lerobot_dataset(agibot_world_config, task: tuple[Path, Path], num_th
|
|||||||
json_file, local_dir = task
|
json_file, local_dir = task
|
||||||
print(f"processing {json_file.stem}, saving to {local_dir}")
|
print(f"processing {json_file.stem}, saving to {local_dir}")
|
||||||
src_path = json_file.parent.parent
|
src_path = json_file.parent.parent
|
||||||
task_name = get_task_instruction(json_file)
|
task_info = get_task_info(json_file)
|
||||||
|
task_name = task_info[0]["task_name"]
|
||||||
|
task_init_scene = task_info[0]["init_scene_text"]
|
||||||
|
task_instruction = f"{task_name}.{task_init_scene}"
|
||||||
task_id = json_file.stem.split("_")[-1]
|
task_id = json_file.stem.split("_")[-1]
|
||||||
|
task_info = {episode["episode_id"]: episode for episode in task_info}
|
||||||
|
|
||||||
features = generate_features_from_config(agibot_world_config)
|
features = generate_features_from_config(agibot_world_config)
|
||||||
|
|
||||||
if local_dir.exists():
|
if local_dir.exists():
|
||||||
@@ -185,27 +256,31 @@ def save_as_lerobot_dataset(agibot_world_config, task: tuple[Path, Path], num_th
|
|||||||
|
|
||||||
all_subdir = [f.as_posix() for f in src_path.glob(f"observations/{task_id}/*") if f.is_dir()]
|
all_subdir = [f.as_posix() for f in src_path.glob(f"observations/{task_id}/*") if f.is_dir()]
|
||||||
|
|
||||||
all_subdir_eids = [int(Path(path).name) for path in all_subdir]
|
all_subdir_eids = sorted([int(Path(path).name) for path in all_subdir])
|
||||||
|
|
||||||
if debug or not save_depth:
|
if debug or not save_depth:
|
||||||
for eid in all_subdir_eids:
|
for eid in all_subdir_eids:
|
||||||
|
if eid not in task_info:
|
||||||
|
print(f"{json_file.stem}, episode_{eid} not in task_info.json, skipping...")
|
||||||
|
continue
|
||||||
try:
|
try:
|
||||||
|
action_config = task_info[eid]["label_info"]["action_config"]
|
||||||
raw_dataset = load_local_dataset(
|
raw_dataset = load_local_dataset(
|
||||||
eid,
|
eid,
|
||||||
src_path=src_path,
|
src_path=src_path,
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
task_name=task_name,
|
task_instruction=task_instruction,
|
||||||
save_depth=save_depth,
|
save_depth=save_depth,
|
||||||
AgiBotWorld_CONFIG=agibot_world_config,
|
AgiBotWorld_CONFIG=agibot_world_config,
|
||||||
)
|
)
|
||||||
frames, videos = raw_dataset
|
_, frames, videos = raw_dataset
|
||||||
if not all([video_path.exists() for video_path in videos.values()]):
|
if not all([video_path.exists() for video_path in videos.values()]):
|
||||||
print(f"{json_file.stem}, episode_{eid}: some of the videos does not exist, skipping")
|
print(f"{json_file.stem}, episode_{eid}: some of the videos does not exist, skipping...")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for frame_data in frames:
|
for frame_data in frames:
|
||||||
dataset.add_frame(frame_data)
|
dataset.add_frame(frame_data)
|
||||||
dataset.save_episode(videos=videos)
|
dataset.save_episode(videos=videos, action_config=action_config)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"{json_file.stem}, {eid}") from e
|
raise Exception(f"{json_file.stem}, {eid}") from e
|
||||||
gc.collect()
|
gc.collect()
|
||||||
@@ -213,24 +288,28 @@ def save_as_lerobot_dataset(agibot_world_config, task: tuple[Path, Path], num_th
|
|||||||
else:
|
else:
|
||||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||||
futures = []
|
futures = []
|
||||||
for episode_id in all_subdir_eids:
|
for eid in all_subdir_eids:
|
||||||
|
if eid not in task_info:
|
||||||
|
print(f"{json_file.stem}, episode_{eid} not in task_info.json, skipping...")
|
||||||
|
continue
|
||||||
futures.append(
|
futures.append(
|
||||||
executor.submit(
|
executor.submit(
|
||||||
load_local_dataset,
|
load_local_dataset,
|
||||||
episode_id,
|
eid,
|
||||||
src_path=src_path,
|
src_path=src_path,
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
task_name=task_name,
|
task_instruction=task_instruction,
|
||||||
save_depth=save_depth,
|
save_depth=save_depth,
|
||||||
AgiBotWorld_CONFIG=agibot_world_config,
|
AgiBotWorld_CONFIG=agibot_world_config,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
for raw_dataset in as_completed(futures):
|
for raw_dataset in as_completed(futures):
|
||||||
frames, videos = raw_dataset.result()
|
eid, frames, videos = raw_dataset.result()
|
||||||
|
action_config = task_info[eid]["label_info"]["action_config"]
|
||||||
for frame_data in frames:
|
for frame_data in frames:
|
||||||
dataset.add_frame(frame_data)
|
dataset.add_frame(frame_data)
|
||||||
dataset.save_episode(videos=videos)
|
dataset.save_episode(videos=videos, action_config=action_config)
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,14 +6,11 @@ import numpy as np
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
def get_task_instruction(task_json_path: str) -> dict:
|
def get_task_info(task_json_path: str) -> dict:
|
||||||
"""Get task language instruction"""
|
|
||||||
with open(task_json_path, "r") as f:
|
with open(task_json_path, "r") as f:
|
||||||
task_info = json.load(f)
|
task_info: list = json.load(f)
|
||||||
task_name = task_info[0]["task_name"]
|
task_info.sort(key=lambda episode: episode["episode_id"])
|
||||||
task_init_scene = task_info[0]["init_scene_text"]
|
return task_info
|
||||||
task_instruction = f"{task_name}.{task_init_scene}"
|
|
||||||
return task_instruction
|
|
||||||
|
|
||||||
|
|
||||||
def load_depths(root_dir: str, camera_name: str):
|
def load_depths(root_dir: str, camera_name: str):
|
||||||
@@ -23,7 +20,7 @@ def load_depths(root_dir: str, camera_name: str):
|
|||||||
|
|
||||||
|
|
||||||
def load_local_dataset(
|
def load_local_dataset(
|
||||||
episode_id: int, src_path: str, task_id: int, task_name: str, save_depth: bool, AgiBotWorld_CONFIG: dict
|
episode_id: int, src_path: str, task_id: int, task_instruction: str, save_depth: bool, AgiBotWorld_CONFIG: dict
|
||||||
) -> tuple[list, dict]:
|
) -> tuple[list, dict]:
|
||||||
"""Load local dataset and return a dict with observations and actions"""
|
"""Load local dataset and return a dict with observations and actions"""
|
||||||
ob_dir = Path(src_path) / f"observations/{task_id}/{episode_id}"
|
ob_dir = Path(src_path) / f"observations/{task_id}/{episode_id}"
|
||||||
@@ -79,7 +76,7 @@ def load_local_dataset(
|
|||||||
)
|
)
|
||||||
for key, value in action.items()
|
for key, value in action.items()
|
||||||
},
|
},
|
||||||
"task": task_name,
|
"task": task_instruction,
|
||||||
}
|
}
|
||||||
for i in range(num_frames)
|
for i in range(num_frames)
|
||||||
]
|
]
|
||||||
@@ -91,4 +88,4 @@ def load_local_dataset(
|
|||||||
for key in AgiBotWorld_CONFIG["images"]
|
for key in AgiBotWorld_CONFIG["images"]
|
||||||
if "depth" not in key
|
if "depth" not in key
|
||||||
}
|
}
|
||||||
return frames, videos
|
return episode_id, frames, videos
|
||||||
|
|||||||
Reference in New Issue
Block a user