mirror of
https://github.com/Tavish9/any4lerobot.git
synced 2026-05-24 02:09:40 +00:00
make script compatible with LeRobotDataset v2.1
This commit is contained in:
+7
-26
@@ -37,7 +37,9 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow_datasets as tfds
|
||||
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
|
||||
from huggingface_hub import HfApi
|
||||
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
from oxe_utils.configs import OXE_DATASET_CONFIGS, ActionEncoding, StateEncoding
|
||||
from oxe_utils.transforms import OXE_STANDARDIZATION_TRANSFORMS
|
||||
@@ -147,15 +149,10 @@ def save_as_lerobot_dataset(lerobot_dataset: LeRobotDataset, raw_dataset: tf.dat
|
||||
**image_dict,
|
||||
"observation.state": traj["proprio"][i],
|
||||
"action": traj["action"][i],
|
||||
"task": traj["task"][0].decode(),
|
||||
}
|
||||
)
|
||||
lerobot_dataset.save_episode(task=traj["task"][0].decode())
|
||||
|
||||
lerobot_dataset.consolidate(
|
||||
run_compute_stats=True,
|
||||
keep_image_files=kwargs["keep_images"],
|
||||
stat_kwargs={"batch_size": kwargs["batch_size"], "num_workers": kwargs["num_workers"]},
|
||||
)
|
||||
lerobot_dataset.save_episode(keep_images=kwargs.get("keep_images", False))
|
||||
|
||||
|
||||
def create_lerobot_dataset(
|
||||
@@ -166,8 +163,6 @@ def create_lerobot_dataset(
|
||||
fps: int = None,
|
||||
robot_type: str = None,
|
||||
use_videos: bool = True,
|
||||
batch_size: int = 32,
|
||||
num_workers: int = 8,
|
||||
image_writer_process: int = 5,
|
||||
image_writer_threads: int = 10,
|
||||
keep_images: bool = True,
|
||||
@@ -183,7 +178,7 @@ def create_lerobot_dataset(
|
||||
data_dir = raw_dir.parent
|
||||
|
||||
if local_dir is None:
|
||||
local_dir = Path(LEROBOT_HOME)
|
||||
local_dir = Path(HF_LEROBOT_HOME)
|
||||
local_dir /= f"{dataset_name}_{version}_lerobot"
|
||||
if local_dir.exists():
|
||||
shutil.rmtree(local_dir)
|
||||
@@ -221,9 +216,7 @@ def create_lerobot_dataset(
|
||||
image_writer_processes=image_writer_process,
|
||||
)
|
||||
|
||||
save_as_lerobot_dataset(
|
||||
lerobot_dataset, raw_dataset, keep_images=keep_images, batch_size=batch_size, num_workers=num_workers
|
||||
)
|
||||
save_as_lerobot_dataset(lerobot_dataset, raw_dataset, keep_images=keep_images)
|
||||
|
||||
if push_to_hub:
|
||||
assert repo_id is not None
|
||||
@@ -282,18 +275,6 @@ def main():
|
||||
action="store_true",
|
||||
help="Convert each episode of the raw dataset to an mp4 video. This option allows 60 times lower disk space consumption and 25 faster loading time during training.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Batch size loaded by DataLoader for computing the dataset statistics.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Number of processes of Dataloader for computing the dataset statistics.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image-writer-process",
|
||||
type=int,
|
||||
|
||||
Reference in New Issue
Block a user