🐛 fix inconsistent image channel

This commit is contained in:
Tavish
2025-07-04 15:19:04 +08:00
parent 9fbb3d1f65
commit 1727054b46
+22 -35
View File
@@ -144,9 +144,7 @@ class RoboMINDDataset(LeRobotDataset):
# Add frame features to episode_buffer
for key, value in frame.items():
if key not in self.features:
raise ValueError(
f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."
)
raise ValueError(f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'.")
if self.features[key]["dtype"] in ["video"]:
img_path = self._get_image_file_path(
@@ -161,9 +159,7 @@ class RoboMINDDataset(LeRobotDataset):
self.episode_buffer["size"] += 1
def save_episode(
self, split, action_config: dict, episode_data: dict | None = None, keep_images: bool = False
) -> None:
def save_episode(self, split, action_config: dict, episode_data: dict | None = None, keep_images: bool = False) -> None:
"""
This will save to disk the current episode in self.episode_buffer.
@@ -254,35 +250,13 @@ def save_as_lerobot_dataset(task: tuple[dict, Path, str], src_path, benchmark, e
task_type, splits, local_dir, task_instruction = task
config = ROBOMIND_CONFIG[embodiment]
# HACK:
# 1. not consistent image shape...
# 2. franka and ur image is bgr...
features = generate_features_from_config(config)
# [HACK]: franka and ur image is bgr...
bgr2rgb = False
if embodiment in ["franka_1rgb", "franka_3rgb", "franka_fr3_dual", "ur_1rgb"]:
bgr2rgb = True
if "1_0" in benchmark:
match embodiment:
case "tienkung_gello_1rgb":
if task_type in (
"clean_table_2_241211",
"clean_table_3_241210",
"clean_table_3_241211",
"place_paper_cup_dustbin_241212",
"place_plate_table_241211",
"place_plate_table_241211_12",
"place_plate_table_241212",
):
for value in config["images"].values():
value["shape"] = (720, 1280) + (value["shape"][2],)
case "tienkung_xsens_1rgb":
if task_type == "switch_manipulation":
for value in config["images"].values():
value["shape"] = (720, 1280) + (value["shape"][2],)
features = generate_features_from_config(config)
if local_dir.exists():
shutil.rmtree(local_dir)
@@ -312,14 +286,27 @@ def save_as_lerobot_dataset(task: tuple[dict, Path, str], src_path, benchmark, e
for episode_path in path.glob("**/trajectory.hdf5"):
status, raw_dataset, err = load_local_dataset(episode_path, config, save_depth, bgr2rgb)
if status and len(raw_dataset) >= 50:
for frame_data in raw_dataset:
dataset.add_frame(frame_data, task_instruction)
dataset.save_episode(split, action_config.get(episode_path.parent.parent.name, {}))
logging.info(f"process done for {path}, len {len(raw_dataset)}")
try:
for frame_data in raw_dataset:
dataset.add_frame(frame_data, task_instruction)
dataset.save_episode(split, action_config.get(episode_path.parent.parent.name, {}))
logging.info(f"process done for {path}, len {len(raw_dataset)}")
except Exception:
# [HACK]: not consistent image shape...
if config["images"]["camera_top"]["shape"] == (720, 1280, 3):
config["images"]["camera_top"]["shape"] = (480, 640, 3)
config["images"]["camera_top_depth"]["shape"] = (480, 640, 1)
else:
config["images"]["camera_top"]["shape"] = (720, 1280, 3)
config["images"]["camera_top_depth"]["shape"] = (720, 1280, 1)
save_as_lerobot_dataset(task, src_path, benchmark, embodiment, save_depth)
return
else:
logging.warning(f"Skipped {episode_path}: len of dataset:{len(raw_dataset)} or {str(err)}")
gc.collect()
if dataset.meta.total_episodes == 0:
shutil.rmtree(local_dir)
del dataset