mirror of
https://github.com/Tavish9/any4lerobot.git
synced 2026-05-23 09:49:40 +00:00
make script compatible with lerobot (b536f47) (#38)
* bump openx2lerobot script
* bump agibot2lerobot script
* bump robomind2lerobot script
This commit is contained in:
@@ -5,7 +5,6 @@ import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import ray
|
||||
@@ -15,8 +14,6 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatas
|
||||
from lerobot.common.datasets.utils import (
|
||||
check_timestamps_sync,
|
||||
get_episode_data_index,
|
||||
get_hf_features_from_features,
|
||||
hf_transform_to_torch,
|
||||
validate_episode_buffer,
|
||||
validate_frame,
|
||||
write_episode,
|
||||
@@ -24,7 +21,6 @@ from lerobot.common.datasets.utils import (
|
||||
write_info,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import get_safe_default_codec
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from ray.runtime_env import RuntimeEnv
|
||||
from robomind_uitls.configs import ROBOMIND_CONFIG
|
||||
from robomind_uitls.lerobot_uitls import compute_episode_stats, generate_features_from_config
|
||||
@@ -81,10 +77,9 @@ class RoboMINDDataset(LeRobotDataset):
|
||||
cls,
|
||||
repo_id: str,
|
||||
fps: int,
|
||||
features: dict,
|
||||
root: str | Path | None = None,
|
||||
robot: Robot | None = None,
|
||||
robot_type: str | None = None,
|
||||
features: dict | None = None,
|
||||
use_videos: bool = True,
|
||||
tolerance_s: float = 1e-4,
|
||||
image_writer_processes: int = 0,
|
||||
@@ -96,10 +91,9 @@ class RoboMINDDataset(LeRobotDataset):
|
||||
obj.meta = RoboMINDDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
fps=fps,
|
||||
root=root,
|
||||
robot=robot,
|
||||
robot_type=robot_type,
|
||||
features=features,
|
||||
root=root,
|
||||
use_videos=use_videos,
|
||||
)
|
||||
obj.repo_id = obj.meta.repo_id
|
||||
@@ -123,16 +117,7 @@ class RoboMINDDataset(LeRobotDataset):
|
||||
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||
return obj
|
||||
|
||||
def create_hf_dataset(self) -> datasets.Dataset:
|
||||
features = get_hf_features_from_features(self.features)
|
||||
ft_dict = {col: [] for col in features}
|
||||
hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")
|
||||
|
||||
# TODO(aliberts): hf_dataset.set_format("torch")
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
def add_frame(self, frame: dict) -> None:
|
||||
def add_frame(self, frame: dict, task: str, timestamp: float | None = None) -> None:
|
||||
"""
|
||||
This function only adds the frame to the episode_buffer. Apart from images — which are written in a
|
||||
temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
|
||||
@@ -150,17 +135,14 @@ class RoboMINDDataset(LeRobotDataset):
|
||||
|
||||
# Automatically add frame_index and timestamp to episode buffer
|
||||
frame_index = self.episode_buffer["size"]
|
||||
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
|
||||
if timestamp is None:
|
||||
timestamp = frame_index / self.fps
|
||||
self.episode_buffer["frame_index"].append(frame_index)
|
||||
self.episode_buffer["timestamp"].append(timestamp)
|
||||
self.episode_buffer["task"].append(task)
|
||||
|
||||
# Add frame features to episode_buffer
|
||||
for key, value in frame.items():
|
||||
if key == "task":
|
||||
# Note: we associate the task in natural language to its task index during `save_episode`
|
||||
self.episode_buffer["task"].append(frame["task"])
|
||||
continue
|
||||
|
||||
if key not in self.features:
|
||||
raise ValueError(
|
||||
f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."
|
||||
@@ -276,11 +258,11 @@ def save_as_lerobot_dataset(task: tuple[dict, Path, str], src_path, benchmark, e
|
||||
# 1. not consistent image shape...
|
||||
# 2. franka and ur image is bgr...
|
||||
bgr2rgb = False
|
||||
if embodiment in ["franka_1rgb", "franka_3rgb", "franka_fr3_dual", "ur_1rgb"]:
|
||||
bgr2rgb = True
|
||||
|
||||
if "1_0" in benchmark:
|
||||
match embodiment:
|
||||
case "franka_1rgb" | "franka_3rgb" | "franka_fr3_dual" | "ur_1rgb":
|
||||
bgr2rgb = True
|
||||
|
||||
case "tienkung_gello_1rgb":
|
||||
if task_type in (
|
||||
"clean_table_2_241211",
|
||||
@@ -331,8 +313,7 @@ def save_as_lerobot_dataset(task: tuple[dict, Path, str], src_path, benchmark, e
|
||||
status, raw_dataset, err = load_local_dataset(episode_path, config, save_depth, bgr2rgb)
|
||||
if status and len(raw_dataset) >= 50:
|
||||
for frame_data in raw_dataset:
|
||||
frame_data.update({"task": task_instruction})
|
||||
dataset.add_frame(frame_data)
|
||||
dataset.add_frame(frame_data, task_instruction)
|
||||
dataset.save_episode(split, action_config.get(episode_path.parent.parent.name, {}))
|
||||
logging.info(f"process done for {path}, len {len(raw_dataset)}")
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user