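"""Convert a local LeRobot dataset into RLDS (TFDS) episode format.

Feature metadata from the LeRobot dataset is mapped onto an RLDS step structure
(observation / action / language_instruction plus is_first / is_last /
is_terminal flags) and written as TFRecord shards. Large datasets can
optionally be processed in parallel with Apache Beam.
"""
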
import argparse
import logging
import os
from functools import partial
from pathlib import Path

import numpy as np
import tensorflow_datasets as tfds
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from tensorflow_datasets.core.file_adapters import FileFormat
from tensorflow_datasets.core.utils.lazy_imports_utils import apache_beam as beam
from tensorflow_datasets.rlds import rlds_base
os.environ["NO_GCE_CHECK"] = "true"
os.environ["CUDA_VISIBLE_DEVICES"] = ""
tfds.core.utils.gcs_utils._is_gcs_disabled = True
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
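

# Key-name conventions assumed for the LeRobot feature dict (inferred from the
# filters in the functions below, not from a formal schema):
#   observation.images.<camera>                     RGB frames, (C, H, W) floats in [0, 1]
#   observation.images.<name containing "depth">    depth maps, (1, H, W) floats in [0, 1]
#   observation.state / observation.states.<name>   proprioceptive state vectors
#   action / actions.<name>                         action vectors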
def generate_config_from_features(features, encoding_format, **kwargs):
    """Build the rlds_base.DatasetConfig fields from LeRobot feature metadata."""
    action_info = {
        # for compatibility with both actions.<action_key> and a plain action key
        "_".join(k.split(".")[2:]) or k.split(".")[-1]: tfds.features.Tensor(
            shape=v["shape"], dtype=np.dtype(v["dtype"]), doc=v["names"]
        )
        for k, v in features.items()
        if "action" in k
    }
    # A single action feature is stored directly instead of as a one-entry dict.
    action_info = action_info if len(action_info) > 1 else action_info.popitem()[1]
    return dict(
        observation_info={
            **{
                k.split(".")[-1]: tfds.features.Image(
                    shape=v["shape"], dtype=np.uint8, encoding_format=encoding_format, doc=v["names"]
                )
                for k, v in features.items()
                if "observation.image" in k and "depth" not in k
            },
            **{
                k.split(".")[-1]: tfds.features.Tensor(shape=v["shape"][:-1], dtype=np.float32, doc=v["names"])
                for k, v in features.items()
                if "observation.image" in k and "depth" in k
            },
            **{
                # for compatibility with both observation.states.<state_key> and observation.state
                "_".join(k.split(".")[2:]) or k.split(".")[-1]: tfds.features.Tensor(
                    shape=v["shape"], dtype=np.dtype(v["dtype"]), doc=v["names"]
                )
                for k, v in features.items()
                if "observation.state" in k
            },
        },
        action_info=action_info,
        step_metadata_info={
            "language_instruction": tfds.features.Text(),
        },
        citation=kwargs.get("citation", ""),
        homepage=kwargs.get("homepage", ""),
        overall_description=kwargs.get("overall_description", ""),
        description=kwargs.get("description", ""),
    )


def parse_step(data_item):
    """Convert one LeRobot frame into (observation, action, language_instruction)."""
    observation_info = {
        **{
            # LeRobot RGB images are (C, H, W) floats in [0, 1]; RLDS expects (H, W, C) uint8.
            k.split(".")[-1]: np.array(v * 255, dtype=np.uint8).transpose(1, 2, 0)
            for k, v in data_item.items()
            if "observation.image" in k and "depth" not in k
        },
        **{
            # LeRobot depth maps are (1, H, W) floats in [0, 1]; squeeze away the channel axis.
            k.split(".")[-1]: v.float().squeeze()
            for k, v in data_item.items()
            if "observation.image" in k and "depth" in k
        },
        **{
            "_".join(k.split(".")[2:]) or k.split(".")[-1]: v
            for k, v in data_item.items()
            if "observation.state" in k
        },
    }
    action_info = {
        "_".join(k.split(".")[2:]) or k.split(".")[-1]: v for k, v in data_item.items() if "action" in k
    }
    # Mirror the config: a single action feature is stored directly, not wrapped in a dict.
    action_info = action_info if len(action_info) > 1 else action_info.popitem()[1]
    return observation_info, action_info, data_item["task"]


class DatasetBuilder(tfds.core.GeneratorBasedBuilder, skip_registration=True):
    def __init__(self, raw_dir, name, dataset_config, enable_beam, *, file_format=None, **kwargs):
        # Set the attributes TFDS normally derives from the builder class
        # (name, VERSION, defining module) before calling the base constructor.
        self.name = name
        self.VERSION = kwargs["version"]
        self.raw_dir = raw_dir
        self.dataset_config = dataset_config
        self.enable_beam = enable_beam
        self.__module__ = "lerobot2rlds"
        super().__init__(file_format=file_format, **kwargs)

    def _info(self) -> tfds.core.DatasetInfo:
        """Returns the dataset metadata."""
        return rlds_base.build_info(
            rlds_base.DatasetConfig(
                name=self.name,
                **self.dataset_config,
            ),
            self,
        )

    def _split_generators(self, dl_manager: tfds.download.DownloadManager):
        """Returns SplitGenerators."""
        # Nothing is downloaded; remove the download dir TFDS creates regardless
        # (relies on a private attribute of the download manager).
        dl_manager._download_dir.rmtree(missing_ok=True)
        return {
            "train": self._generate_examples(),
        }

    def _generate_examples(self):
        """Yields examples."""

        def _generate_examples_beam(episode_index, raw_dir):
            dataset = LeRobotDataset("", raw_dir, episodes=[episode_index])
            logging.info(f"processing episode {episode_index}")
            num_frames = dataset.meta.episodes[episode_index]["length"]
            episode = []
            for data_item in dataset:
                observation_info, action_info, language_instruction = parse_step(data_item)
                frame_index = data_item["frame_index"].item()
                episode.append(
                    {
                        "observation": observation_info,
                        "action": action_info,
                        "language_instruction": language_instruction,
                        "is_first": frame_index == 0,
                        "is_last": frame_index == num_frames - 1,
                        "is_terminal": frame_index == num_frames - 1,
                    }
                )
            return episode_index, {"steps": episode}

        def _generate_examples_regular():
            dataset = LeRobotDataset("", self.raw_dir)
            episode = []
            current_episode_index = 0
            for data_item in dataset:
                if data_item["episode_index"].item() != current_episode_index:
                    episode[-1]["is_last"] = True
                    episode[-1]["is_terminal"] = True
                    yield f"{current_episode_index}", {"steps": episode}
                    current_episode_index = data_item["episode_index"].item()
                    # Start a fresh list rather than clearing in place: the episode
                    # just yielded must not be mutated after TFDS has received it.
                    episode = []
                observation_info, action_info, language_instruction = parse_step(data_item)
                episode.append(
                    {
                        "observation": observation_info,
                        "action": action_info,
                        "language_instruction": language_instruction,
                        "is_first": data_item["frame_index"].item() == 0,
                        "is_last": False,
                        "is_terminal": False,
                    }
                )
            episode[-1]["is_last"] = True
            episode[-1]["is_terminal"] = True
            yield f"{current_episode_index}", {"steps": episode}

        if self.enable_beam:
            metadata = LeRobotDatasetMetadata("", self.raw_dir)
            return beam.Create(list(metadata.episodes.keys())) | beam.Map(
                partial(_generate_examples_beam, raw_dir=self.raw_dir)
            )
        else:
            # NOTE: return the generator object; if this method used yield itself,
            # the beam branch above would never execute.
            return _generate_examples_regular()


def main(src_dir, output_dir, task_name, version, encoding_format, enable_beam, **kwargs):
    raw_dataset_meta = LeRobotDatasetMetadata("", root=src_dir)
    dataset_config = generate_config_from_features(raw_dataset_meta.features, encoding_format, **kwargs)
    dataset_builder = DatasetBuilder(
        raw_dir=src_dir,
        name=task_name,
        data_dir=output_dir,
        version=version,
        dataset_config=dataset_config,
        enable_beam=enable_beam,
        file_format=FileFormat.TFRECORD,
    )
    if enable_beam:
        logging.warning("Beam processing is enabled; a known Apache Beam issue may drop some episodes.")
        logging.warning("Disable Beam if the dataset is small or every episode must be preserved.")
        from apache_beam.options.pipeline_options import PipelineOptions
        from apache_beam.runners import create_runner

        if "threading" in kwargs["beam_run_mode"]:
            logging.warning("multi_threading can misbehave while sharding and saving; prefer multi_processing.")
        beam_options = PipelineOptions(
            direct_running_mode=kwargs["beam_run_mode"],
            direct_num_workers=kwargs["beam_num_workers"],
        )
        beam_runner = create_runner("DirectRunner")
    else:
        beam_options = None
        beam_runner = None
    dataset_builder.download_and_prepare(
        download_config=tfds.download.DownloadConfig(
            try_download_gcs=False,
            verify_ssl=False,
            beam_options=beam_options,
            beam_runner=beam_runner,
        ),
    )
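

# Example invocation (paths and the task name below are placeholders):
#   python lerobot2rlds.py \
#       --src-dir /data/lerobot/my_dataset \
#       --output-dir /data/rlds \
#       --task-name my_task \
#       --enable-beam --beam-run-mode multi_processing --beam-num-workers 8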
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--src-dir", type=Path, required=True, help="Path to the local lerobot dataset.")
    parser.add_argument("--output-dir", type=Path, required=True, help="Path to the output directory.")
    parser.add_argument("--task-name", type=str, required=True, help="Task name.")
    parser.add_argument("--enable-beam", action="store_true", help="Enable beam processing.")
    parser.add_argument("--beam-run-mode", choices=["multi_threading", "multi_processing"], default="multi_processing")
    parser.add_argument("--beam-num-workers", type=int, default=5)
    parser.add_argument("--encoding-format", type=str, choices=["jpeg", "png"], default="jpeg")
    parser.add_argument("--version", type=str, help="Version string, x.y.z.", default="0.1.0")
    parser.add_argument("--citation", type=str, help="Citation.", default="")
    parser.add_argument("--homepage", type=str, help="Homepage.", default="")
    parser.add_argument("--overall-description", type=str, help="Overall description.", default="")
    parser.add_argument("--description", type=str, help="Description.", default="")
    args = parser.parse_args()
    main(**vars(args))
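
# Reading the converted dataset back (a minimal sketch; the directory and task
# name are placeholders matching the example invocation above):
#   import tensorflow_datasets as tfds
#   builder = tfds.builder_from_directory("/data/rlds/my_task/0.1.0")
#   ds = builder.as_dataset(split="train")
#   for episode in ds.take(1):
#       for step in episode["steps"]:
#           print(step["language_instruction"], step["action"])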