Mirror of https://github.com/Tavish9/any4lerobot.git
Synced 2026-05-16 06:29:45 +00:00
237 lines · 9.8 KiB · Python
import argparse
|
|
import logging
|
|
import os
|
|
from functools import partial
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import tensorflow_datasets as tfds
|
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
|
from tensorflow_datasets.core.file_adapters import FileFormat
|
|
from tensorflow_datasets.core.utils.lazy_imports_utils import apache_beam as beam
|
|
from tensorflow_datasets.rlds import rlds_base
|
|
|
|
# Process-wide configuration applied at import time.
# NOTE(review): NO_GCE_CHECK presumably stops Google client libraries from probing
# the GCE metadata server; CUDA_VISIBLE_DEVICES="" hides all GPUs. Both are set
# after tensorflow_datasets is imported, so they only take effect for anything
# that initialises later — confirm this ordering is sufficient.
os.environ.update(NO_GCE_CHECK="true", CUDA_VISIBLE_DEVICES="")

# Stop TFDS from consulting Google Cloud Storage (private TFDS flag).
tfds.core.utils.gcs_utils._is_gcs_disabled = True

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
|
|
|
|
|
|
def generate_config_from_features(features, encoding_format, **kwargs):
    """Build the keyword arguments for ``rlds_base.DatasetConfig`` from LeRobot features.

    Args:
        features: Mapping from LeRobot feature keys (e.g. ``"observation.images.cam"``,
            ``"observation.state"``, ``"action"``) to dicts carrying ``"shape"``,
            ``"dtype"`` and ``"names"`` entries.
        encoding_format: Image encoding passed to ``tfds.features.Image``
            (the CLI offers ``"jpeg"`` or ``"png"``).
        **kwargs: Optional ``citation``, ``homepage``, ``overall_description`` and
            ``description`` strings (default ``""``); other keys are ignored.

    Returns:
        A dict with ``observation_info``, ``action_info``, ``step_metadata_info``,
        ``citation``, ``homepage``, ``overall_description`` and ``description``.
        ``action_info`` is a single ``tfds.features.Tensor`` when there is exactly
        one action feature, otherwise a dict of tensors keyed by flattened name.

    Raises:
        ValueError: If ``features`` contains no key with ``"action"`` in it.
    """

    def _flat_key(key):
        # Everything after the second dot joined with "_", falling back to the last
        # dotted component; e.g. "observation.states.joint" -> "joint",
        # "observation.state" -> "state", "actions.ee" -> "ee", "action" -> "action".
        return "_".join(key.split(".")[2:]) or key.split(".")[-1]

    # Substring match for compatibility with both "actions.<action_key>" and "action".
    action_keys = [k for k in features if "action" in k]
    if not action_keys:
        raise ValueError(f"No action feature found among dataset features: {list(features)}")

    action_info = {
        _flat_key(k): tfds.features.Tensor(
            shape=features[k]["shape"], dtype=np.dtype(features[k]["dtype"]), doc=features[k]["names"]
        )
        for k in action_keys
    }
    # A single action feature is stored as a bare tensor rather than a one-entry dict.
    if len(action_info) == 1:
        action_info = next(iter(action_info.values()))

    return dict(
        observation_info={
            # Colour images: metadata shape is used as-is for the uint8 image feature.
            **{
                k.split(".")[-1]: tfds.features.Image(
                    shape=v["shape"], dtype=np.uint8, encoding_format=encoding_format, doc=v["names"]
                )
                for k, v in features.items()
                if "observation.image" in k and "depth" not in k
            },
            # Depth images: the trailing (channel) dimension is dropped, stored as float32.
            **{
                k.split(".")[-1]: tfds.features.Tensor(shape=v["shape"][:-1], dtype=np.float32, doc=v["names"])
                for k, v in features.items()
                if "observation.image" in k and "depth" in k
            },
            # States: compatible with "observation.states.<state_key>" and "observation.state".
            **{
                _flat_key(k): tfds.features.Tensor(shape=v["shape"], dtype=np.dtype(v["dtype"]), doc=v["names"])
                for k, v in features.items()
                if "observation.state" in k
            },
        },
        action_info=action_info,
        step_metadata_info={
            "language_instruction": tfds.features.Text(),
        },
        citation=kwargs.get("citation", ""),
        homepage=kwargs.get("homepage", ""),
        overall_description=kwargs.get("overall_description", ""),
        description=kwargs.get("description", ""),
    )
|
|
|
|
|
|
def parse_step(data_item):
    """Split one LeRobot frame into RLDS observation, action and instruction parts.

    Args:
        data_item: One item of a ``LeRobotDataset``: a mapping from feature keys to
            per-frame values, plus a ``"task"`` entry.

    Returns:
        ``(observation_info, action_info, task)``. ``observation_info`` maps short
        names to HWC uint8 images, depth maps and state vectors. ``action_info`` is
        the bare value when there is a single action entry, otherwise a dict keyed
        by flattened action name. ``task`` is ``data_item["task"]`` unchanged.

    Raises:
        ValueError: If ``data_item`` has no key containing ``"action"``.
    """

    def _flat_key(key):
        # Same naming rule as generate_config_from_features: everything after the
        # second dot joined with "_", falling back to the last dotted component.
        return "_".join(key.split(".")[2:]) or key.split(".")[-1]

    def _to_hwc_uint8(image):
        # LeRobot images are (C, H, W) and in range [0, 1]. Round before casting:
        # a float->uint8 cast truncates, so round-off like 254.99998 would become
        # 254. Clip so out-of-range values cannot wrap around.
        scaled = np.rint(np.asarray(image, dtype=np.float32) * 255)
        return np.clip(scaled, 0, 255).astype(np.uint8).transpose(1, 2, 0)

    observation_info = {
        **{
            k.split(".")[-1]: _to_hwc_uint8(v)
            for k, v in data_item.items()
            if "observation.image" in k and "depth" not in k
        },
        **{
            # lerobot depth is (1, H, W) and in range [0, 1]
            k.split(".")[-1]: v.float().squeeze()
            for k, v in data_item.items()
            if "observation.image" in k and "depth" in k
        },
        **{_flat_key(k): v for k, v in data_item.items() if "observation.state" in k},
    }

    action_info = {_flat_key(k): v for k, v in data_item.items() if "action" in k}
    if not action_info:
        raise ValueError(f"No action entry found in data item keys: {list(data_item)}")
    # A single action entry is returned as the bare value, mirroring the feature config.
    if len(action_info) == 1:
        action_info = next(iter(action_info.values()))

    return observation_info, action_info, data_item["task"]
|
|
|
|
|
|
class DatasetBuilder(tfds.core.GeneratorBasedBuilder, skip_registration=True):
    """TFDS builder that re-packages a local LeRobot dataset as RLDS episodes.

    Args:
        raw_dir: Root directory of the source LeRobot dataset.
        name: Builder / dataset name.
        dataset_config: Keyword arguments for ``rlds_base.DatasetConfig`` other
            than ``name`` (e.g. the dict from ``generate_config_from_features``).
        enable_beam: When true, ``_generate_examples`` returns an Apache Beam
            transform over episode indices instead of a Python generator.
        file_format: Forwarded to the TFDS builder constructor.
        **kwargs: Forwarded to the TFDS builder constructor; must include ``version``.
    """

    def __init__(self, raw_dir, name, dataset_config, enable_beam, *, file_format=None, **kwargs):
        self.name = name
        self.VERSION = kwargs["version"]
        self.raw_dir = raw_dir
        self.dataset_config = dataset_config
        self.enable_beam = enable_beam
        # NOTE(review): overrides the module name TFDS sees for this instance,
        # presumably because the builder lives in a script rather than an
        # importable package — confirm before removing.
        self.__module__ = "lerobot2rlds"
        super().__init__(file_format=file_format, **kwargs)

    def _info(self) -> tfds.core.DatasetInfo:
        """Returns the dataset metadata."""
        return rlds_base.build_info(
            rlds_base.DatasetConfig(
                name=self.name,
                **self.dataset_config,
            ),
            self,
        )

    def _split_generators(self, dl_manager: tfds.download.DownloadManager):
        """Returns SplitGenerators: a single ``train`` split."""
        # This builder reads local files only, so the download directory is not
        # needed (``_download_dir`` is a private TFDS attribute).
        dl_manager._download_dir.rmtree(missing_ok=True)
        return {
            "train": self._generate_examples(),
        }

    def _generate_examples(self):
        """Yields examples.

        With ``enable_beam``: returns a Beam transform emitting one
        ``(episode_index, {"steps": [...]})`` pair per episode. Otherwise: returns
        a generator of ``(str(episode_index), {"steps": [...]})`` pairs produced
        by a single sequential pass over the dataset.
        """

        def _generate_examples_beam(episode_index, raw_dir):
            # Kept free of references to enclosing locals, since Beam may need to
            # pickle it for its workers.
            dataset = LeRobotDataset("", raw_dir, episodes=[episode_index])
            logging.info(f"processing episode {episode_index}")
            # Loop-invariant: frame index of this episode's final step.
            last_frame_index = dataset.meta.episodes[episode_index]["length"] - 1
            episode = []
            for data_item in dataset:
                observation_info, action_info, language_instruction = parse_step(data_item)
                frame_index = data_item["frame_index"].item()
                episode.append(
                    {
                        "observation": observation_info,
                        "action": action_info,
                        "language_instruction": language_instruction,
                        "is_first": frame_index == 0,
                        "is_last": frame_index == last_frame_index,
                        "is_terminal": frame_index == last_frame_index,
                    }
                )
            return episode_index, {"steps": episode}

        def _generate_examples_regular():
            dataset = LeRobotDataset("", self.raw_dir)
            episode = []
            # None until the first frame is seen, so datasets whose first episode
            # index is not 0 do not trigger a flush of an empty episode.
            current_episode_index = None

            for data_item in dataset:
                # int() normalises the index (presumably a 0-d tensor, like
                # frame_index which is read via .item()) so keys are "0", "1", ...
                # rather than tensor reprs.
                episode_index = int(data_item["episode_index"])
                if episode and episode_index != current_episode_index:
                    episode[-1]["is_last"] = True
                    episode[-1]["is_terminal"] = True
                    yield f"{current_episode_index}", {"steps": episode}
                    # Start a fresh list instead of clearing the one just yielded,
                    # so the consumer's copy is never mutated.
                    episode = []
                current_episode_index = episode_index

                observation_info, action_info, language_instruction = parse_step(data_item)
                episode.append(
                    {
                        "observation": observation_info,
                        "action": action_info,
                        "language_instruction": language_instruction,
                        "is_first": data_item["frame_index"].item() == 0,
                        "is_last": False,
                        "is_terminal": False,
                    }
                )

            # Flush the final episode; an empty dataset yields nothing.
            if episode:
                episode[-1]["is_last"] = True
                episode[-1]["is_terminal"] = True
                yield f"{current_episode_index}", {"steps": episode}

        if self.enable_beam:
            metadata = LeRobotDatasetMetadata("", self.raw_dir)
            return beam.Create(list(metadata.episodes.keys())) | beam.Map(
                partial(_generate_examples_beam, raw_dir=self.raw_dir)
            )
        # NOTE: we should return a generator, not yield — a `yield` in this method
        # would turn it into a generator and break the Beam branch's return value.
        return _generate_examples_regular()
|
|
|
|
|
|
def main(src_dir, output_dir, task_name, version, encoding_format, enable_beam, **kwargs):
    """Convert the LeRobot dataset at ``src_dir`` into a TFDS/RLDS dataset.

    Args:
        src_dir: Root directory of the local LeRobot dataset.
        output_dir: TFDS ``data_dir`` the converted dataset is written to.
        task_name: Name of the generated TFDS dataset.
        version: Dataset version string (``"x.y.z"``).
        encoding_format: Image encoding for TFDS image features.
        enable_beam: Generate examples through an Apache Beam DirectRunner.
        **kwargs: Descriptive fields (``citation``, ``homepage``,
            ``overall_description``, ``description``) forwarded to
            ``generate_config_from_features``. When ``enable_beam`` is set, also
            ``beam_run_mode`` (default ``"multi_processing"``) and
            ``beam_num_workers`` (default 5), matching the CLI defaults.
    """
    raw_dataset_meta = LeRobotDatasetMetadata("", root=src_dir)

    dataset_config = generate_config_from_features(raw_dataset_meta.features, encoding_format, **kwargs)

    dataset_builder = DatasetBuilder(
        raw_dir=src_dir,
        name=task_name,
        data_dir=output_dir,
        version=version,
        dataset_config=dataset_config,
        enable_beam=enable_beam,
        file_format=FileFormat.TFRECORD,
    )

    beam_options = None
    beam_runner = None
    if enable_beam:
        logging.warning("beam processing is enabled. Some episodes might be lost, a bug with apache beam.")
        logging.warning("disable beam processing if your dataset is small or you want to save all episodes.")
        # Imported lazily so apache_beam's runner modules are only loaded when used.
        from apache_beam.options.pipeline_options import PipelineOptions
        from apache_beam.runners import create_runner

        # Defaults mirror the CLI so programmatic callers need not pass them.
        beam_run_mode = kwargs.get("beam_run_mode", "multi_processing")
        beam_num_workers = kwargs.get("beam_num_workers", 5)

        if "threading" in beam_run_mode:
            logging.warning("multi_threading might have issues when sharding and saving.")
            logging.warning("recommend using multi_processing instead.")

        beam_options = PipelineOptions(
            direct_running_mode=beam_run_mode,
            direct_num_workers=beam_num_workers,
        )
        beam_runner = create_runner("DirectRunner")

    dataset_builder.download_and_prepare(
        download_config=tfds.download.DownloadConfig(
            try_download_gcs=False,
            verify_ssl=False,
            beam_options=beam_options,
            beam_runner=beam_runner,
        ),
    )
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # main() uses --src-dir as the metadata root and --task-name as the builder
    # name; require them up front instead of passing None downstream.
    parser.add_argument("--src-dir", type=Path, required=True, help="Path to the local lerobot dataset.")
    parser.add_argument("--output-dir", type=Path, help="Path to the output directory.")
    parser.add_argument("--task-name", type=str, required=True, help="Task name.")
    parser.add_argument("--enable-beam", action="store_true", help="Enable beam processing.")
    parser.add_argument("--beam-run-mode", choices=["multi_threading", "multi_processing"], default="multi_processing")
    parser.add_argument("--beam-num-workers", type=int, default=5)
    parser.add_argument("--encoding-format", type=str, choices=["jpeg", "png"], default="jpeg")
    parser.add_argument("--version", type=str, help="x.y.z", default="0.1.0")
    parser.add_argument("--citation", type=str, help="Citation.", default="")
    parser.add_argument("--homepage", type=str, help="Homepage.", default="")
    parser.add_argument("--overall-description", type=str, help="Overall description.", default="")
    parser.add_argument("--description", type=str, help="Description.", default="")
    args = parser.parse_args()

    main(**vars(args))
|