make script compatible with LeRobotDataset v2.1

This commit is contained in:
Tavish
2025-03-09 10:29:24 +08:00
parent c87238fea0
commit f09ad64230
3 changed files with 22 additions and 48 deletions
+7 -26
View File
@@ -37,7 +37,9 @@ from pathlib import Path
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
from huggingface_hub import HfApi
from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from oxe_utils.configs import OXE_DATASET_CONFIGS, ActionEncoding, StateEncoding
from oxe_utils.transforms import OXE_STANDARDIZATION_TRANSFORMS
@@ -147,15 +149,10 @@ def save_as_lerobot_dataset(lerobot_dataset: LeRobotDataset, raw_dataset: tf.dat
**image_dict,
"observation.state": traj["proprio"][i],
"action": traj["action"][i],
"task": traj["task"][0].decode(),
}
)
lerobot_dataset.save_episode(task=traj["task"][0].decode())
lerobot_dataset.consolidate(
run_compute_stats=True,
keep_image_files=kwargs["keep_images"],
stat_kwargs={"batch_size": kwargs["batch_size"], "num_workers": kwargs["num_workers"]},
)
lerobot_dataset.save_episode(keep_images=kwargs.get("keep_images", False))
def create_lerobot_dataset(
@@ -166,8 +163,6 @@ def create_lerobot_dataset(
fps: int = None,
robot_type: str = None,
use_videos: bool = True,
batch_size: int = 32,
num_workers: int = 8,
image_writer_process: int = 5,
image_writer_threads: int = 10,
keep_images: bool = True,
@@ -183,7 +178,7 @@ def create_lerobot_dataset(
data_dir = raw_dir.parent
if local_dir is None:
local_dir = Path(LEROBOT_HOME)
local_dir = Path(HF_LEROBOT_HOME)
local_dir /= f"{dataset_name}_{version}_lerobot"
if local_dir.exists():
shutil.rmtree(local_dir)
@@ -221,9 +216,7 @@ def create_lerobot_dataset(
image_writer_processes=image_writer_process,
)
save_as_lerobot_dataset(
lerobot_dataset, raw_dataset, keep_images=keep_images, batch_size=batch_size, num_workers=num_workers
)
save_as_lerobot_dataset(lerobot_dataset, raw_dataset, keep_images=keep_images)
if push_to_hub:
assert repo_id is not None
@@ -282,18 +275,6 @@ def main():
action="store_true",
help="Convert each episode of the raw dataset to an mp4 video. This option allows 60 times lower disk space consumption and 25 faster loading time during training.",
)
parser.add_argument(
"--batch-size",
type=int,
default=32,
help="Batch size loaded by DataLoader for computing the dataset statistics.",
)
parser.add_argument(
"--num-workers",
type=int,
default=8,
help="Number of processes of Dataloader for computing the dataset statistics.",
)
parser.add_argument(
"--image-writer-process",
type=int,