mirror of
https://github.com/Tavish9/any4lerobot.git
synced 2026-05-25 10:39:44 +00:00
save action_config in each episode
This commit is contained in:
@@ -6,14 +6,11 @@ import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def get_task_instruction(task_json_path: str) -> dict:
|
||||
"""Get task language instruction"""
|
||||
def get_task_info(task_json_path: str) -> dict:
|
||||
with open(task_json_path, "r") as f:
|
||||
task_info = json.load(f)
|
||||
task_name = task_info[0]["task_name"]
|
||||
task_init_scene = task_info[0]["init_scene_text"]
|
||||
task_instruction = f"{task_name}.{task_init_scene}"
|
||||
return task_instruction
|
||||
task_info: list = json.load(f)
|
||||
task_info.sort(key=lambda episode: episode["episode_id"])
|
||||
return task_info
|
||||
|
||||
|
||||
def load_depths(root_dir: str, camera_name: str):
|
||||
@@ -23,7 +20,7 @@ def load_depths(root_dir: str, camera_name: str):
|
||||
|
||||
|
||||
def load_local_dataset(
|
||||
episode_id: int, src_path: str, task_id: int, task_name: str, save_depth: bool, AgiBotWorld_CONFIG: dict
|
||||
episode_id: int, src_path: str, task_id: int, task_instruction: str, save_depth: bool, AgiBotWorld_CONFIG: dict
|
||||
) -> tuple[list, dict]:
|
||||
"""Load local dataset and return a dict with observations and actions"""
|
||||
ob_dir = Path(src_path) / f"observations/{task_id}/{episode_id}"
|
||||
@@ -79,7 +76,7 @@ def load_local_dataset(
|
||||
)
|
||||
for key, value in action.items()
|
||||
},
|
||||
"task": task_name,
|
||||
"task": task_instruction,
|
||||
}
|
||||
for i in range(num_frames)
|
||||
]
|
||||
@@ -91,4 +88,4 @@ def load_local_dataset(
|
||||
for key in AgiBotWorld_CONFIG["images"]
|
||||
if "depth" not in key
|
||||
}
|
||||
return frames, videos
|
||||
return episode_id, frames, videos
|
||||
|
||||
Reference in New Issue
Block a user