add version convert collections (#63)

* v20 to v21

* v21 to v20

* v21 to v30

* v16 to v20

* update dataset version convert readme

* update readme
This commit is contained in:
Qizhi Chen
2025-10-04 16:49:19 +08:00
committed by GitHub
parent 245465975f
commit c5d1312a2b
12 changed files with 1392 additions and 295 deletions
+20
View File
@@ -0,0 +1,20 @@
# LeRobot Dataset v20 to v21
## Get started
1. Install v2.1 lerobot
```bash
git clone https://github.com/huggingface/lerobot.git
git checkout d602e8169cbad9e93a4a3b3ee1dd8b332af7ebf8
pip install -e .
```
2. Run the converter:
```bash
python convert_dataset_v20_to_v21.py \
--repo-id=your_id \
--root=your_local_dir \
--delete-old-stats \
--push-to-hub \
--num-workers=8
```
@@ -0,0 +1,88 @@
import argparse
from convert_stats import check_aggregate_stats, convert_stats
from huggingface_hub import HfApi
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
from lerobot.datasets.v21.convert_dataset_v20_to_v21 import V20, V21
def convert_dataset(
repo_id: str,
root: str | None = None,
push_to_hub: bool = False,
delete_old_stats: bool = False,
branch: str | None = None,
num_workers: int = 4,
):
if root is not None:
dataset = LeRobotDataset(repo_id, root, revision=V20)
else:
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
if (dataset.root / EPISODES_STATS_PATH).is_file():
(dataset.root / EPISODES_STATS_PATH).unlink()
convert_stats(dataset, num_workers=num_workers)
ref_stats = load_stats(dataset.root)
check_aggregate_stats(dataset, ref_stats)
dataset.meta.info["codebase_version"] = V21
write_info(dataset.meta.info, dataset.root)
if push_to_hub:
dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
# delete old stats.json file
if delete_old_stats and (dataset.root / STATS_PATH).is_file:
(dataset.root / STATS_PATH).unlink()
hub_api = HfApi()
if delete_old_stats and hub_api.file_exists(
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
):
hub_api.delete_file(path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset")
if push_to_hub:
hub_api.create_tag(repo_id, tag=V21, revision=branch, repo_type="dataset")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
)
parser.add_argument(
"--root",
type=str,
default=None,
help="Path to the local dataset root directory. If not provided, the script will use the dataset from local.",
)
parser.add_argument(
"--push-to-hub",
action="store_true",
help="Push the dataset to the hub after conversion. Defaults to False.",
)
parser.add_argument(
"--delete-old-stats",
action="store_true",
help="Delete the old stats.json file after conversion. Defaults to False.",
)
parser.add_argument(
"--branch",
type=str,
default=None,
help="Repo branch to push your dataset. Defaults to the main branch.",
)
parser.add_argument(
"--num-workers",
type=int,
default=4,
help="Number of workers for parallelizing stats compute. Defaults to 4.",
)
args = parser.parse_args()
convert_dataset(**vars(args))
@@ -0,0 +1,81 @@
from concurrent.futures import ProcessPoolExecutor, as_completed
import numpy as np
from lerobot.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import write_episode_stats
from tqdm import tqdm
def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
ep_len = dataset.meta.episodes[episode_index]["length"]
sampled_indices = sample_indices(ep_len)
query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
video_frames = dataset._query_videos(query_timestamps, episode_index)
return video_frames[ft_key].numpy()
def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
ep_start_idx = dataset.episode_data_index["from"][ep_idx]
ep_end_idx = dataset.episode_data_index["to"][ep_idx]
ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
ep_stats = {}
for key, ft in dataset.features.items():
if ft["dtype"] == "video":
# We sample only for videos
ep_ft_data = sample_episode_video_frames(dataset, ep_idx, key)
else:
ep_ft_data = np.array(ep_data[key])
axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
if ft["dtype"] in ["image", "video"]: # remove batch dim
ep_stats[key] = {k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()}
return ep_stats, ep_idx
def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
assert dataset.episodes is None
print("Computing episodes stats")
total_episodes = dataset.meta.total_episodes
if num_workers > 0:
with ProcessPoolExecutor(max_workers=num_workers) as executor:
futures = {
executor.submit(convert_episode_stats, dataset, ep_idx): ep_idx for ep_idx in range(total_episodes)
}
for future in tqdm(as_completed(futures), total=total_episodes):
ep_stats, ep_idx = future.result()
dataset.meta.episodes_stats[ep_idx] = ep_stats
else:
for ep_idx in tqdm(range(total_episodes)):
ep_stats, _ = convert_episode_stats(dataset, ep_idx)
dataset.meta.episodes_stats[ep_idx] = ep_stats
for ep_idx in tqdm(range(total_episodes)):
write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
def check_aggregate_stats(
dataset: LeRobotDataset,
reference_stats: dict[str, dict[str, np.ndarray]],
video_rtol_atol: tuple[float] = (1e-2, 1e-2),
default_rtol_atol: tuple[float] = (5e-6, 6e-5),
):
"""Verifies that the aggregated stats from episodes_stats are close to reference stats."""
agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values()))
for key, ft in dataset.features.items():
# These values might need some fine-tuning
if ft["dtype"] == "video":
# to account for image sub-sampling
rtol, atol = video_rtol_atol
else:
rtol, atol = default_rtol_atol
for stat, val in agg_stats[key].items():
if key in reference_stats and stat in reference_stats[key]:
err_msg = f"feature='{key}' stats='{stat}'"
np.testing.assert_allclose(val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg)