refactor(datasets): replace untyped dict with typed DatasetInfo dataclass (#3472)

* refactor(datasets): replace untyped dict with typed DatasetInfo dataclass

Introduce typed DatasetInfo dataclass to replace untyped dict representation of info.json.

Changes:
- Add DatasetInfo dataclass with explicit fields and validation
- Implement __post_init__ for shape conversion (list ↔ tuple)
- Add dict-style compatibility layer (__getitem__, __setitem__, .get())
- Add from_dict() and to_dict() for JSON serialization
- Update io_utils to use load_info/write_info with DatasetInfo
- Update dataset utilities and metadata to use attribute access
- Remove aggregate.py dict-style field access
- Add tests fixture support for DatasetInfo

Benefits:
- Type safety with IDE auto-completion
- Validation at construction time
- Explicit schema documentation

* fix pre-commit

* update docstring inside DatasetInfo.from_dict()

* sorts the unknown to have deterministic output

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* refactoring the last few old fieds


* fix crop dataset roi type mismatch


* use consistantly int for data and video_files_size_in_mb

---------

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>
Co-authored-by: jjolla93 <jjolla93@gmail.com>
This commit is contained in:
Maxime Ellerbach
2026-04-28 18:40:30 +02:00
committed by GitHub
parent 8a3d64033f
commit cb0a944941
15 changed files with 275 additions and 163 deletions
+8 -8
View File
@@ -62,19 +62,19 @@ def mock_snapshot_download_factory(
if info is None:
info = info_factory(data_files_size_in_mb=data_files_size_in_mb, chunks_size=chunks_size)
if stats is None:
stats = stats_factory(features=info["features"])
stats = stats_factory(features=info.features)
if tasks is None:
tasks = tasks_factory(total_tasks=info["total_tasks"])
tasks = tasks_factory(total_tasks=info.total_tasks)
if episodes is None:
episodes = episodes_factory(
features=info["features"],
fps=info["fps"],
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
features=info.features,
fps=info.fps,
total_episodes=info.total_episodes,
total_frames=info.total_frames,
tasks=tasks,
)
if hf_dataset is None:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info.fps)
def _mock_snapshot_download(
repo_id: str, # TODO(rcadene): repo_id should be used no?
@@ -97,7 +97,7 @@ def mock_snapshot_download_factory(
DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0),
]
video_keys = [key for key, feats in info["features"].items() if feats["dtype"] == "video"]
video_keys = [key for key, feats in info.features.items() if feats["dtype"] == "video"]
for key in video_keys:
all_files.append(DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0))