mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
Dataset v3 (#1412)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Remi Cadene <re.cadene@gmail.com> Co-authored-by: Tavish <tavish9.chen@gmail.com> Co-authored-by: fracapuano <francesco.capuano@huggingface.co> Co-authored-by: CarolinePascal <caroline8.pascal@gmail.com>
This commit is contained in:
Vendored
+74
-60
@@ -14,15 +14,19 @@
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from huggingface_hub.utils import filter_repo_objects
|
||||
|
||||
from lerobot.datasets.utils import (
|
||||
EPISODES_PATH,
|
||||
EPISODES_STATS_PATH,
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_TASKS_PATH,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
INFO_PATH,
|
||||
STATS_PATH,
|
||||
TASKS_PATH,
|
||||
)
|
||||
from tests.fixtures.constants import LEROBOT_TEST_DIR
|
||||
|
||||
@@ -30,17 +34,16 @@ from tests.fixtures.constants import LEROBOT_TEST_DIR
|
||||
@pytest.fixture(scope="session")
|
||||
def mock_snapshot_download_factory(
|
||||
info_factory,
|
||||
info_path,
|
||||
create_info,
|
||||
stats_factory,
|
||||
stats_path,
|
||||
episodes_stats_factory,
|
||||
episodes_stats_path,
|
||||
create_stats,
|
||||
tasks_factory,
|
||||
tasks_path,
|
||||
create_tasks,
|
||||
episodes_factory,
|
||||
episode_path,
|
||||
single_episode_parquet_path,
|
||||
create_episodes,
|
||||
hf_dataset_factory,
|
||||
create_hf_dataset,
|
||||
create_videos,
|
||||
):
|
||||
"""
|
||||
This factory allows to patch snapshot_download such that when called, it will create expected files rather
|
||||
@@ -50,82 +53,93 @@ def mock_snapshot_download_factory(
|
||||
def _mock_snapshot_download_func(
|
||||
info: dict | None = None,
|
||||
stats: dict | None = None,
|
||||
episodes_stats: list[dict] | None = None,
|
||||
tasks: list[dict] | None = None,
|
||||
episodes: list[dict] | None = None,
|
||||
tasks: pd.DataFrame | None = None,
|
||||
episodes: datasets.Dataset | None = None,
|
||||
hf_dataset: datasets.Dataset | None = None,
|
||||
data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
||||
):
|
||||
if not info:
|
||||
info = info_factory()
|
||||
if not stats:
|
||||
if info is None:
|
||||
info = info_factory(data_files_size_in_mb=data_files_size_in_mb, chunks_size=chunks_size)
|
||||
if stats is None:
|
||||
stats = stats_factory(features=info["features"])
|
||||
if not episodes_stats:
|
||||
episodes_stats = episodes_stats_factory(
|
||||
features=info["features"], total_episodes=info["total_episodes"]
|
||||
)
|
||||
if not tasks:
|
||||
if tasks is None:
|
||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||
if not episodes:
|
||||
if episodes is None:
|
||||
episodes = episodes_factory(
|
||||
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
|
||||
features=info["features"],
|
||||
fps=info["fps"],
|
||||
total_episodes=info["total_episodes"],
|
||||
total_frames=info["total_frames"],
|
||||
tasks=tasks,
|
||||
)
|
||||
if not hf_dataset:
|
||||
if hf_dataset is None:
|
||||
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
|
||||
|
||||
def _extract_episode_index_from_path(fpath: str) -> int:
|
||||
path = Path(fpath)
|
||||
if path.suffix == ".parquet" and path.stem.startswith("episode_"):
|
||||
episode_index = int(path.stem[len("episode_") :]) # 'episode_000000' -> 0
|
||||
return episode_index
|
||||
else:
|
||||
return None
|
||||
|
||||
def _mock_snapshot_download(
|
||||
repo_id: str,
|
||||
repo_id: str, # TODO(rcadene): repo_id should be used no?
|
||||
local_dir: str | Path | None = None,
|
||||
allow_patterns: str | list[str] | None = None,
|
||||
ignore_patterns: str | list[str] | None = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
if not local_dir:
|
||||
if local_dir is None:
|
||||
local_dir = LEROBOT_TEST_DIR
|
||||
|
||||
# List all possible files
|
||||
all_files = []
|
||||
meta_files = [INFO_PATH, STATS_PATH, EPISODES_STATS_PATH, TASKS_PATH, EPISODES_PATH]
|
||||
all_files.extend(meta_files)
|
||||
all_files = [
|
||||
INFO_PATH,
|
||||
STATS_PATH,
|
||||
# TODO(rcadene): remove naive chunk 0 file 0 ?
|
||||
DEFAULT_TASKS_PATH.format(chunk_index=0, file_index=0),
|
||||
DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0),
|
||||
DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0),
|
||||
]
|
||||
|
||||
data_files = []
|
||||
for episode_dict in episodes.values():
|
||||
ep_idx = episode_dict["episode_index"]
|
||||
ep_chunk = ep_idx // info["chunks_size"]
|
||||
data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
||||
data_files.append(data_path)
|
||||
all_files.extend(data_files)
|
||||
video_keys = [key for key, feats in info["features"].items() if feats["dtype"] == "video"]
|
||||
for key in video_keys:
|
||||
all_files.append(DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0))
|
||||
|
||||
allowed_files = filter_repo_objects(
|
||||
all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
|
||||
)
|
||||
|
||||
# Create allowed files
|
||||
request_info = False
|
||||
request_tasks = False
|
||||
request_episodes = False
|
||||
request_stats = False
|
||||
request_data = False
|
||||
request_videos = False
|
||||
for rel_path in allowed_files:
|
||||
if rel_path.startswith("data/"):
|
||||
episode_index = _extract_episode_index_from_path(rel_path)
|
||||
if episode_index is not None:
|
||||
_ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info)
|
||||
if rel_path == INFO_PATH:
|
||||
_ = info_path(local_dir, info)
|
||||
elif rel_path == STATS_PATH:
|
||||
_ = stats_path(local_dir, stats)
|
||||
elif rel_path == EPISODES_STATS_PATH:
|
||||
_ = episodes_stats_path(local_dir, episodes_stats)
|
||||
elif rel_path == TASKS_PATH:
|
||||
_ = tasks_path(local_dir, tasks)
|
||||
elif rel_path == EPISODES_PATH:
|
||||
_ = episode_path(local_dir, episodes)
|
||||
if rel_path.startswith("meta/info.json"):
|
||||
request_info = True
|
||||
elif rel_path.startswith("meta/stats"):
|
||||
request_stats = True
|
||||
elif rel_path.startswith("meta/tasks"):
|
||||
request_tasks = True
|
||||
elif rel_path.startswith("meta/episodes"):
|
||||
request_episodes = True
|
||||
elif rel_path.startswith("data/"):
|
||||
request_data = True
|
||||
elif rel_path.startswith("videos/"):
|
||||
request_videos = True
|
||||
else:
|
||||
pass
|
||||
raise ValueError(f"{rel_path} not supported.")
|
||||
|
||||
if request_info:
|
||||
create_info(local_dir, info)
|
||||
if request_stats:
|
||||
create_stats(local_dir, stats)
|
||||
if request_tasks:
|
||||
create_tasks(local_dir, tasks)
|
||||
if request_episodes:
|
||||
create_episodes(local_dir, episodes)
|
||||
if request_data:
|
||||
create_hf_dataset(local_dir, hf_dataset, data_files_size_in_mb, chunks_size)
|
||||
if request_videos:
|
||||
create_videos(root=local_dir, info=info)
|
||||
|
||||
return str(local_dir)
|
||||
|
||||
return _mock_snapshot_download
|
||||
|
||||
Reference in New Issue
Block a user