mirror of
https://github.com/Tavish9/any4lerobot.git
synced 2026-05-22 09:29:44 +00:00
✨ add version convert collections (#63)
* v20 to v21 * v21 to v20 * v21 to v30 * v16 to v20 * update dataset version convert readme * update readme
This commit is contained in:
@@ -0,0 +1,20 @@
|
||||
# LeRobot Dataset v20 to v21
|
||||
|
||||
## Get started
|
||||
|
||||
1. Install v2.1 lerobot
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
git checkout d602e8169cbad9e93a4a3b3ee1dd8b332af7ebf8
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
2. Run the converter:
|
||||
```bash
|
||||
python convert_dataset_v20_to_v21.py \
|
||||
--repo-id=your_id \
|
||||
--root=your_local_dir \
|
||||
--delete-old-stats \
|
||||
--push-to-hub \
|
||||
--num-workers=8
|
||||
```
|
||||
@@ -0,0 +1,88 @@
|
||||
import argparse
|
||||
|
||||
from convert_stats import check_aggregate_stats, convert_stats
|
||||
from huggingface_hub import HfApi
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
|
||||
from lerobot.datasets.v21.convert_dataset_v20_to_v21 import V20, V21
|
||||
|
||||
|
||||
def convert_dataset(
|
||||
repo_id: str,
|
||||
root: str | None = None,
|
||||
push_to_hub: bool = False,
|
||||
delete_old_stats: bool = False,
|
||||
branch: str | None = None,
|
||||
num_workers: int = 4,
|
||||
):
|
||||
if root is not None:
|
||||
dataset = LeRobotDataset(repo_id, root, revision=V20)
|
||||
else:
|
||||
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
|
||||
|
||||
if (dataset.root / EPISODES_STATS_PATH).is_file():
|
||||
(dataset.root / EPISODES_STATS_PATH).unlink()
|
||||
|
||||
convert_stats(dataset, num_workers=num_workers)
|
||||
ref_stats = load_stats(dataset.root)
|
||||
check_aggregate_stats(dataset, ref_stats)
|
||||
|
||||
dataset.meta.info["codebase_version"] = V21
|
||||
write_info(dataset.meta.info, dataset.root)
|
||||
|
||||
if push_to_hub:
|
||||
dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
|
||||
|
||||
# delete old stats.json file
|
||||
if delete_old_stats and (dataset.root / STATS_PATH).is_file:
|
||||
(dataset.root / STATS_PATH).unlink()
|
||||
|
||||
hub_api = HfApi()
|
||||
if delete_old_stats and hub_api.file_exists(
|
||||
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
|
||||
):
|
||||
hub_api.delete_file(path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset")
|
||||
if push_to_hub:
|
||||
hub_api.create_tag(repo_id, tag=V21, revision=branch, repo_type="dataset")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
|
||||
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the local dataset root directory. If not provided, the script will use the dataset from local.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
action="store_true",
|
||||
help="Push the dataset to the hub after conversion. Defaults to False.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--delete-old-stats",
|
||||
action="store_true",
|
||||
help="Delete the old stats.json file after conversion. Defaults to False.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--branch",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Repo branch to push your dataset. Defaults to the main branch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of workers for parallelizing stats compute. Defaults to 4.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_dataset(**vars(args))
|
||||
@@ -0,0 +1,81 @@
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
|
||||
import numpy as np
|
||||
from lerobot.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import write_episode_stats
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
|
||||
ep_len = dataset.meta.episodes[episode_index]["length"]
|
||||
sampled_indices = sample_indices(ep_len)
|
||||
query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
|
||||
video_frames = dataset._query_videos(query_timestamps, episode_index)
|
||||
return video_frames[ft_key].numpy()
|
||||
|
||||
|
||||
def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
|
||||
ep_start_idx = dataset.episode_data_index["from"][ep_idx]
|
||||
ep_end_idx = dataset.episode_data_index["to"][ep_idx]
|
||||
ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
|
||||
|
||||
ep_stats = {}
|
||||
for key, ft in dataset.features.items():
|
||||
if ft["dtype"] == "video":
|
||||
# We sample only for videos
|
||||
ep_ft_data = sample_episode_video_frames(dataset, ep_idx, key)
|
||||
else:
|
||||
ep_ft_data = np.array(ep_data[key])
|
||||
|
||||
axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
|
||||
keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
|
||||
ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
|
||||
|
||||
if ft["dtype"] in ["image", "video"]: # remove batch dim
|
||||
ep_stats[key] = {k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()}
|
||||
|
||||
return ep_stats, ep_idx
|
||||
|
||||
|
||||
def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
|
||||
assert dataset.episodes is None
|
||||
print("Computing episodes stats")
|
||||
total_episodes = dataset.meta.total_episodes
|
||||
if num_workers > 0:
|
||||
with ProcessPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = {
|
||||
executor.submit(convert_episode_stats, dataset, ep_idx): ep_idx for ep_idx in range(total_episodes)
|
||||
}
|
||||
for future in tqdm(as_completed(futures), total=total_episodes):
|
||||
ep_stats, ep_idx = future.result()
|
||||
dataset.meta.episodes_stats[ep_idx] = ep_stats
|
||||
else:
|
||||
for ep_idx in tqdm(range(total_episodes)):
|
||||
ep_stats, _ = convert_episode_stats(dataset, ep_idx)
|
||||
dataset.meta.episodes_stats[ep_idx] = ep_stats
|
||||
|
||||
for ep_idx in tqdm(range(total_episodes)):
|
||||
write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
|
||||
|
||||
|
||||
def check_aggregate_stats(
|
||||
dataset: LeRobotDataset,
|
||||
reference_stats: dict[str, dict[str, np.ndarray]],
|
||||
video_rtol_atol: tuple[float] = (1e-2, 1e-2),
|
||||
default_rtol_atol: tuple[float] = (5e-6, 6e-5),
|
||||
):
|
||||
"""Verifies that the aggregated stats from episodes_stats are close to reference stats."""
|
||||
agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values()))
|
||||
for key, ft in dataset.features.items():
|
||||
# These values might need some fine-tuning
|
||||
if ft["dtype"] == "video":
|
||||
# to account for image sub-sampling
|
||||
rtol, atol = video_rtol_atol
|
||||
else:
|
||||
rtol, atol = default_rtol_atol
|
||||
|
||||
for stat, val in agg_stats[key].items():
|
||||
if key in reference_stats and stat in reference_stats[key]:
|
||||
err_msg = f"feature='{key}' stats='{stat}'"
|
||||
np.testing.assert_allclose(val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg)
|
||||
Reference in New Issue
Block a user