mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
refactor(datasets): replace untyped dict with typed DatasetInfo dataclass (#3472)
* refactor(datasets): replace untyped dict with typed DatasetInfo dataclass Introduce typed DatasetInfo dataclass to replace untyped dict representation of info.json. Changes: - Add DatasetInfo dataclass with explicit fields and validation - Implement __post_init__ for shape conversion (list ↔ tuple) - Add dict-style compatibility layer (__getitem__, __setitem__, .get()) - Add from_dict() and to_dict() for JSON serialization - Update io_utils to use load_info/write_info with DatasetInfo - Update dataset utilities and metadata to use attribute access - Remove aggregate.py dict-style field access - Add tests fixture support for DatasetInfo Benefits: - Type safety with IDE auto-completion - Validation at construction time - Explicit schema documentation * fix pre-commit * update docstring inside DatasetInfo.from_dict() * sorts the unknown to have deterministic output Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net> * refactoring the last few old fieds * fix crop dataset roi type mismatch * use consistantly int for data and video_files_size_in_mb --------- Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net> Co-authored-by: jjolla93 <jjolla93@gmail.com>
This commit is contained in:
@@ -113,7 +113,7 @@ def assert_metadata_consistency(aggr_ds, ds_0, ds_1):
|
||||
"""Test that metadata is correctly aggregated."""
|
||||
# Test basic info
|
||||
assert aggr_ds.fps == ds_0.fps == ds_1.fps, "FPS should be the same across all datasets"
|
||||
assert aggr_ds.meta.info["robot_type"] == ds_0.meta.info["robot_type"] == ds_1.meta.info["robot_type"], (
|
||||
assert aggr_ds.meta.info.robot_type == ds_0.meta.info.robot_type == ds_1.meta.info.robot_type, (
|
||||
"Robot type should be the same"
|
||||
)
|
||||
|
||||
@@ -153,8 +153,8 @@ def assert_video_frames_integrity(aggr_ds, ds_0, ds_1):
|
||||
|
||||
video_keys = list(
|
||||
filter(
|
||||
lambda key: aggr_ds.meta.info["features"][key]["dtype"] == "video",
|
||||
aggr_ds.meta.info["features"].keys(),
|
||||
lambda key: aggr_ds.meta.info.features[key]["dtype"] == "video",
|
||||
aggr_ds.meta.info.features.keys(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -161,7 +161,7 @@ def test_init_loads_existing_metadata(tmp_path, lerobot_dataset_metadata_factory
|
||||
|
||||
assert meta.total_episodes == 3
|
||||
assert meta.total_frames == 150
|
||||
assert meta.fps == info["fps"]
|
||||
assert meta.fps == info.fps
|
||||
|
||||
|
||||
# ── Property accessors ───────────────────────────────────────────────
|
||||
|
||||
@@ -80,18 +80,18 @@ def _write_dataset_tree(
|
||||
)
|
||||
tasks = tasks_factory(total_tasks=1)
|
||||
episodes = episodes_factory(
|
||||
features=info["features"],
|
||||
fps=info["fps"],
|
||||
features=info.features,
|
||||
fps=info.fps,
|
||||
total_episodes=1,
|
||||
total_frames=3,
|
||||
tasks=tasks,
|
||||
)
|
||||
stats = stats_factory(features=info["features"])
|
||||
stats = stats_factory(features=info.features)
|
||||
hf_dataset = hf_dataset_factory(
|
||||
features=info["features"],
|
||||
features=info.features,
|
||||
tasks=tasks,
|
||||
episodes=episodes,
|
||||
fps=info["fps"],
|
||||
fps=info.fps,
|
||||
)
|
||||
|
||||
create_info(root, info)
|
||||
|
||||
Vendored
+38
-40
@@ -28,7 +28,7 @@ from datasets import Dataset
|
||||
|
||||
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.feature_utils import get_hf_features_from_features
|
||||
from lerobot.datasets.io_utils import hf_transform_to_torch
|
||||
from lerobot.datasets.io_utils import flatten_dict, hf_transform_to_torch
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
@@ -36,10 +36,10 @@ from lerobot.datasets.utils import (
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
DatasetInfo,
|
||||
)
|
||||
from lerobot.datasets.video_utils import encode_video_frames
|
||||
from lerobot.utils.constants import DEFAULT_FEATURES
|
||||
from lerobot.utils.utils import flatten_dict
|
||||
from tests.fixtures.constants import (
|
||||
DEFAULT_FPS,
|
||||
DUMMY_CAMERA_FEATURES,
|
||||
@@ -157,33 +157,31 @@ def info_factory(features_factory):
|
||||
total_episodes: int = 0,
|
||||
total_frames: int = 0,
|
||||
total_tasks: int = 0,
|
||||
total_videos: int = 0,
|
||||
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
||||
data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
video_files_size_in_mb: float = DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
data_files_size_in_mb: int = DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
video_files_size_in_mb: int = DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
data_path: str = DEFAULT_DATA_PATH,
|
||||
video_path: str = DEFAULT_VIDEO_PATH,
|
||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||
use_videos: bool = True,
|
||||
) -> dict:
|
||||
) -> DatasetInfo:
|
||||
features = features_factory(motor_features, camera_features, use_videos)
|
||||
return {
|
||||
"codebase_version": codebase_version,
|
||||
"robot_type": robot_type,
|
||||
"total_episodes": total_episodes,
|
||||
"total_frames": total_frames,
|
||||
"total_tasks": total_tasks,
|
||||
"total_videos": total_videos,
|
||||
"chunks_size": chunks_size,
|
||||
"data_files_size_in_mb": data_files_size_in_mb,
|
||||
"video_files_size_in_mb": video_files_size_in_mb,
|
||||
"fps": fps,
|
||||
"splits": {},
|
||||
"data_path": data_path,
|
||||
"video_path": video_path if use_videos else None,
|
||||
"features": features,
|
||||
}
|
||||
return DatasetInfo(
|
||||
codebase_version=codebase_version,
|
||||
robot_type=robot_type,
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
total_tasks=total_tasks,
|
||||
chunks_size=chunks_size,
|
||||
data_files_size_in_mb=data_files_size_in_mb,
|
||||
video_files_size_in_mb=video_files_size_in_mb,
|
||||
fps=fps,
|
||||
splits={},
|
||||
data_path=data_path,
|
||||
video_path=video_path if use_videos else None,
|
||||
features=features,
|
||||
)
|
||||
|
||||
return _create_info
|
||||
|
||||
@@ -333,12 +331,12 @@ def create_videos(info_factory, img_array_factory):
|
||||
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
|
||||
)
|
||||
|
||||
video_feats = {key: feats for key, feats in info["features"].items() if feats["dtype"] == "video"}
|
||||
video_feats = {key: feats for key, feats in info.features.items() if feats["dtype"] == "video"}
|
||||
for key, ft in video_feats.items():
|
||||
# create and save images with identifiable content
|
||||
tmp_dir = root / "tmp_images"
|
||||
tmp_dir.mkdir(parents=True, exist_ok=True)
|
||||
for frame_index in range(info["total_frames"]):
|
||||
for frame_index in range(info.total_frames):
|
||||
content = f"{key}-{frame_index}"
|
||||
img = img_array_factory(height=ft["shape"][0], width=ft["shape"][1], content=content)
|
||||
pil_img = PIL.Image.fromarray(img)
|
||||
@@ -348,7 +346,7 @@ def create_videos(info_factory, img_array_factory):
|
||||
video_path = root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0)
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
# Use the global fps from info, not video-specific fps which might not exist
|
||||
encode_video_frames(tmp_dir, video_path, fps=info["fps"])
|
||||
encode_video_frames(tmp_dir, video_path, fps=info.fps)
|
||||
shutil.rmtree(tmp_dir)
|
||||
|
||||
return _create_video_directory
|
||||
@@ -433,16 +431,16 @@ def lerobot_dataset_metadata_factory(
|
||||
if info is None:
|
||||
info = info_factory()
|
||||
if stats is None:
|
||||
stats = stats_factory(features=info["features"])
|
||||
stats = stats_factory(features=info.features)
|
||||
if tasks is None:
|
||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||
tasks = tasks_factory(total_tasks=info.total_tasks)
|
||||
if episodes is None:
|
||||
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
|
||||
video_keys = [key for key, ft in info.features.items() if ft["dtype"] == "video"]
|
||||
episodes = episodes_factory(
|
||||
features=info["features"],
|
||||
fps=info["fps"],
|
||||
total_episodes=info["total_episodes"],
|
||||
total_frames=info["total_frames"],
|
||||
features=info.features,
|
||||
fps=info.fps,
|
||||
total_episodes=info.total_episodes,
|
||||
total_frames=info.total_frames,
|
||||
video_keys=video_keys,
|
||||
tasks=tasks,
|
||||
)
|
||||
@@ -503,23 +501,23 @@ def lerobot_dataset_factory(
|
||||
chunks_size=chunks_size,
|
||||
)
|
||||
if stats is None:
|
||||
stats = stats_factory(features=info["features"])
|
||||
stats = stats_factory(features=info.features)
|
||||
if tasks is None:
|
||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||
tasks = tasks_factory(total_tasks=info.total_tasks)
|
||||
if episodes_metadata is None:
|
||||
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
|
||||
video_keys = [key for key, ft in info.features.items() if ft["dtype"] == "video"]
|
||||
episodes_metadata = episodes_factory(
|
||||
features=info["features"],
|
||||
fps=info["fps"],
|
||||
total_episodes=info["total_episodes"],
|
||||
total_frames=info["total_frames"],
|
||||
features=info.features,
|
||||
fps=info.fps,
|
||||
total_episodes=info.total_episodes,
|
||||
total_frames=info.total_frames,
|
||||
video_keys=video_keys,
|
||||
tasks=tasks,
|
||||
multi_task=multi_task,
|
||||
)
|
||||
if hf_dataset is None:
|
||||
hf_dataset = hf_dataset_factory(
|
||||
features=info["features"], tasks=tasks, episodes=episodes_metadata, fps=info["fps"]
|
||||
features=info.features, tasks=tasks, episodes=episodes_metadata, fps=info.fps
|
||||
)
|
||||
|
||||
# Write data on disk
|
||||
|
||||
Vendored
+8
-8
@@ -62,19 +62,19 @@ def mock_snapshot_download_factory(
|
||||
if info is None:
|
||||
info = info_factory(data_files_size_in_mb=data_files_size_in_mb, chunks_size=chunks_size)
|
||||
if stats is None:
|
||||
stats = stats_factory(features=info["features"])
|
||||
stats = stats_factory(features=info.features)
|
||||
if tasks is None:
|
||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||
tasks = tasks_factory(total_tasks=info.total_tasks)
|
||||
if episodes is None:
|
||||
episodes = episodes_factory(
|
||||
features=info["features"],
|
||||
fps=info["fps"],
|
||||
total_episodes=info["total_episodes"],
|
||||
total_frames=info["total_frames"],
|
||||
features=info.features,
|
||||
fps=info.fps,
|
||||
total_episodes=info.total_episodes,
|
||||
total_frames=info.total_frames,
|
||||
tasks=tasks,
|
||||
)
|
||||
if hf_dataset is None:
|
||||
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
|
||||
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info.fps)
|
||||
|
||||
def _mock_snapshot_download(
|
||||
repo_id: str, # TODO(rcadene): repo_id should be used no?
|
||||
@@ -97,7 +97,7 @@ def mock_snapshot_download_factory(
|
||||
DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0),
|
||||
]
|
||||
|
||||
video_keys = [key for key, feats in info["features"].items() if feats["dtype"] == "video"]
|
||||
video_keys = [key for key, feats in info.features.items() if feats["dtype"] == "video"]
|
||||
for key in video_keys:
|
||||
all_files.append(DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user