mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 01:30:14 +00:00
refactor(datasets): replace untyped dict with typed DatasetInfo dataclass (#3472)
* refactor(datasets): replace untyped dict with typed DatasetInfo dataclass Introduce typed DatasetInfo dataclass to replace untyped dict representation of info.json. Changes: - Add DatasetInfo dataclass with explicit fields and validation - Implement __post_init__ for shape conversion (list ↔ tuple) - Add dict-style compatibility layer (__getitem__, __setitem__, .get()) - Add from_dict() and to_dict() for JSON serialization - Update io_utils to use load_info/write_info with DatasetInfo - Update dataset utilities and metadata to use attribute access - Remove aggregate.py dict-style field access - Add test fixture support for DatasetInfo Benefits: - Type safety with IDE auto-completion - Validation at construction time - Explicit schema documentation * fix pre-commit * update docstring inside DatasetInfo.from_dict() * sorts the unknown to have deterministic output Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net> * refactoring the last few old fields * fix crop dataset roi type mismatch * consistently use int for data and video_files_size_in_mb --------- Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net> Co-authored-by: jjolla93 <jjolla93@gmail.com>
This commit is contained in:
@@ -113,7 +113,7 @@ def assert_metadata_consistency(aggr_ds, ds_0, ds_1):
|
||||
"""Test that metadata is correctly aggregated."""
|
||||
# Test basic info
|
||||
assert aggr_ds.fps == ds_0.fps == ds_1.fps, "FPS should be the same across all datasets"
|
||||
assert aggr_ds.meta.info["robot_type"] == ds_0.meta.info["robot_type"] == ds_1.meta.info["robot_type"], (
|
||||
assert aggr_ds.meta.info.robot_type == ds_0.meta.info.robot_type == ds_1.meta.info.robot_type, (
|
||||
"Robot type should be the same"
|
||||
)
|
||||
|
||||
@@ -153,8 +153,8 @@ def assert_video_frames_integrity(aggr_ds, ds_0, ds_1):
|
||||
|
||||
video_keys = list(
|
||||
filter(
|
||||
lambda key: aggr_ds.meta.info["features"][key]["dtype"] == "video",
|
||||
aggr_ds.meta.info["features"].keys(),
|
||||
lambda key: aggr_ds.meta.info.features[key]["dtype"] == "video",
|
||||
aggr_ds.meta.info.features.keys(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -161,7 +161,7 @@ def test_init_loads_existing_metadata(tmp_path, lerobot_dataset_metadata_factory
|
||||
|
||||
assert meta.total_episodes == 3
|
||||
assert meta.total_frames == 150
|
||||
assert meta.fps == info["fps"]
|
||||
assert meta.fps == info.fps
|
||||
|
||||
|
||||
# ── Property accessors ───────────────────────────────────────────────
|
||||
|
||||
@@ -80,18 +80,18 @@ def _write_dataset_tree(
|
||||
)
|
||||
tasks = tasks_factory(total_tasks=1)
|
||||
episodes = episodes_factory(
|
||||
features=info["features"],
|
||||
fps=info["fps"],
|
||||
features=info.features,
|
||||
fps=info.fps,
|
||||
total_episodes=1,
|
||||
total_frames=3,
|
||||
tasks=tasks,
|
||||
)
|
||||
stats = stats_factory(features=info["features"])
|
||||
stats = stats_factory(features=info.features)
|
||||
hf_dataset = hf_dataset_factory(
|
||||
features=info["features"],
|
||||
features=info.features,
|
||||
tasks=tasks,
|
||||
episodes=episodes,
|
||||
fps=info["fps"],
|
||||
fps=info.fps,
|
||||
)
|
||||
|
||||
create_info(root, info)
|
||||
|
||||
Reference in New Issue
Block a user