refactor(datasets): replace untyped dict with typed DatasetInfo dataclass (#3472)

* refactor(datasets): replace untyped dict with typed DatasetInfo dataclass

Introduce typed DatasetInfo dataclass to replace untyped dict representation of info.json.

Changes:
- Add DatasetInfo dataclass with explicit fields and validation
- Implement __post_init__ for shape conversion (list ↔ tuple)
- Add dict-style compatibility layer (__getitem__, __setitem__, .get())
- Add from_dict() and to_dict() for JSON serialization
- Update io_utils to use load_info/write_info with DatasetInfo
- Update dataset utilities and metadata to use attribute access
- Remove aggregate.py dict-style field access
- Add tests fixture support for DatasetInfo

Benefits:
- Type safety with IDE auto-completion
- Validation at construction time
- Explicit schema documentation

* fix pre-commit

* update docstring inside DatasetInfo.from_dict()

* sort unknown fields to produce deterministic output

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* refactoring the last few old fields


* fix crop dataset roi type mismatch


* consistently use int for data_files_size_in_mb and video_files_size_in_mb

---------

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>
Co-authored-by: jjolla93 <jjolla93@gmail.com>
This commit is contained in:
Maxime Ellerbach
2026-04-28 18:40:30 +02:00
committed by GitHub
parent 8a3d64033f
commit cb0a944941
15 changed files with 275 additions and 163 deletions
+3 -3
View File
@@ -113,7 +113,7 @@ def assert_metadata_consistency(aggr_ds, ds_0, ds_1):
"""Test that metadata is correctly aggregated."""
# Test basic info
assert aggr_ds.fps == ds_0.fps == ds_1.fps, "FPS should be the same across all datasets"
assert aggr_ds.meta.info["robot_type"] == ds_0.meta.info["robot_type"] == ds_1.meta.info["robot_type"], (
assert aggr_ds.meta.info.robot_type == ds_0.meta.info.robot_type == ds_1.meta.info.robot_type, (
"Robot type should be the same"
)
@@ -153,8 +153,8 @@ def assert_video_frames_integrity(aggr_ds, ds_0, ds_1):
video_keys = list(
filter(
lambda key: aggr_ds.meta.info["features"][key]["dtype"] == "video",
aggr_ds.meta.info["features"].keys(),
lambda key: aggr_ds.meta.info.features[key]["dtype"] == "video",
aggr_ds.meta.info.features.keys(),
)
)
+1 -1
View File
@@ -161,7 +161,7 @@ def test_init_loads_existing_metadata(tmp_path, lerobot_dataset_metadata_factory
assert meta.total_episodes == 3
assert meta.total_frames == 150
assert meta.fps == info["fps"]
assert meta.fps == info.fps
# ── Property accessors ───────────────────────────────────────────────
+5 -5
View File
@@ -80,18 +80,18 @@ def _write_dataset_tree(
)
tasks = tasks_factory(total_tasks=1)
episodes = episodes_factory(
features=info["features"],
fps=info["fps"],
features=info.features,
fps=info.fps,
total_episodes=1,
total_frames=3,
tasks=tasks,
)
stats = stats_factory(features=info["features"])
stats = stats_factory(features=info.features)
hf_dataset = hf_dataset_factory(
features=info["features"],
features=info.features,
tasks=tasks,
episodes=episodes,
fps=info["fps"],
fps=info.fps,
)
create_info(root, info)
+38 -40
View File
@@ -28,7 +28,7 @@ from datasets import Dataset
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import get_hf_features_from_features
from lerobot.datasets.io_utils import hf_transform_to_torch
from lerobot.datasets.io_utils import flatten_dict, hf_transform_to_torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import (
DEFAULT_CHUNK_SIZE,
@@ -36,10 +36,10 @@ from lerobot.datasets.utils import (
DEFAULT_DATA_PATH,
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_PATH,
DatasetInfo,
)
from lerobot.datasets.video_utils import encode_video_frames
from lerobot.utils.constants import DEFAULT_FEATURES
from lerobot.utils.utils import flatten_dict
from tests.fixtures.constants import (
DEFAULT_FPS,
DUMMY_CAMERA_FEATURES,
@@ -157,33 +157,31 @@ def info_factory(features_factory):
total_episodes: int = 0,
total_frames: int = 0,
total_tasks: int = 0,
total_videos: int = 0,
chunks_size: int = DEFAULT_CHUNK_SIZE,
data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB,
video_files_size_in_mb: float = DEFAULT_VIDEO_FILE_SIZE_IN_MB,
data_files_size_in_mb: int = DEFAULT_DATA_FILE_SIZE_IN_MB,
video_files_size_in_mb: int = DEFAULT_VIDEO_FILE_SIZE_IN_MB,
data_path: str = DEFAULT_DATA_PATH,
video_path: str = DEFAULT_VIDEO_PATH,
motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES,
use_videos: bool = True,
) -> dict:
) -> DatasetInfo:
features = features_factory(motor_features, camera_features, use_videos)
return {
"codebase_version": codebase_version,
"robot_type": robot_type,
"total_episodes": total_episodes,
"total_frames": total_frames,
"total_tasks": total_tasks,
"total_videos": total_videos,
"chunks_size": chunks_size,
"data_files_size_in_mb": data_files_size_in_mb,
"video_files_size_in_mb": video_files_size_in_mb,
"fps": fps,
"splits": {},
"data_path": data_path,
"video_path": video_path if use_videos else None,
"features": features,
}
return DatasetInfo(
codebase_version=codebase_version,
robot_type=robot_type,
total_episodes=total_episodes,
total_frames=total_frames,
total_tasks=total_tasks,
chunks_size=chunks_size,
data_files_size_in_mb=data_files_size_in_mb,
video_files_size_in_mb=video_files_size_in_mb,
fps=fps,
splits={},
data_path=data_path,
video_path=video_path if use_videos else None,
features=features,
)
return _create_info
@@ -333,12 +331,12 @@ def create_videos(info_factory, img_array_factory):
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
)
video_feats = {key: feats for key, feats in info["features"].items() if feats["dtype"] == "video"}
video_feats = {key: feats for key, feats in info.features.items() if feats["dtype"] == "video"}
for key, ft in video_feats.items():
# create and save images with identifiable content
tmp_dir = root / "tmp_images"
tmp_dir.mkdir(parents=True, exist_ok=True)
for frame_index in range(info["total_frames"]):
for frame_index in range(info.total_frames):
content = f"{key}-{frame_index}"
img = img_array_factory(height=ft["shape"][0], width=ft["shape"][1], content=content)
pil_img = PIL.Image.fromarray(img)
@@ -348,7 +346,7 @@ def create_videos(info_factory, img_array_factory):
video_path = root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0)
video_path.parent.mkdir(parents=True, exist_ok=True)
# Use the global fps from info, not video-specific fps which might not exist
encode_video_frames(tmp_dir, video_path, fps=info["fps"])
encode_video_frames(tmp_dir, video_path, fps=info.fps)
shutil.rmtree(tmp_dir)
return _create_video_directory
@@ -433,16 +431,16 @@ def lerobot_dataset_metadata_factory(
if info is None:
info = info_factory()
if stats is None:
stats = stats_factory(features=info["features"])
stats = stats_factory(features=info.features)
if tasks is None:
tasks = tasks_factory(total_tasks=info["total_tasks"])
tasks = tasks_factory(total_tasks=info.total_tasks)
if episodes is None:
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
video_keys = [key for key, ft in info.features.items() if ft["dtype"] == "video"]
episodes = episodes_factory(
features=info["features"],
fps=info["fps"],
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
features=info.features,
fps=info.fps,
total_episodes=info.total_episodes,
total_frames=info.total_frames,
video_keys=video_keys,
tasks=tasks,
)
@@ -503,23 +501,23 @@ def lerobot_dataset_factory(
chunks_size=chunks_size,
)
if stats is None:
stats = stats_factory(features=info["features"])
stats = stats_factory(features=info.features)
if tasks is None:
tasks = tasks_factory(total_tasks=info["total_tasks"])
tasks = tasks_factory(total_tasks=info.total_tasks)
if episodes_metadata is None:
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
video_keys = [key for key, ft in info.features.items() if ft["dtype"] == "video"]
episodes_metadata = episodes_factory(
features=info["features"],
fps=info["fps"],
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
features=info.features,
fps=info.fps,
total_episodes=info.total_episodes,
total_frames=info.total_frames,
video_keys=video_keys,
tasks=tasks,
multi_task=multi_task,
)
if hf_dataset is None:
hf_dataset = hf_dataset_factory(
features=info["features"], tasks=tasks, episodes=episodes_metadata, fps=info["fps"]
features=info.features, tasks=tasks, episodes=episodes_metadata, fps=info.fps
)
# Write data on disk
+8 -8
View File
@@ -62,19 +62,19 @@ def mock_snapshot_download_factory(
if info is None:
info = info_factory(data_files_size_in_mb=data_files_size_in_mb, chunks_size=chunks_size)
if stats is None:
stats = stats_factory(features=info["features"])
stats = stats_factory(features=info.features)
if tasks is None:
tasks = tasks_factory(total_tasks=info["total_tasks"])
tasks = tasks_factory(total_tasks=info.total_tasks)
if episodes is None:
episodes = episodes_factory(
features=info["features"],
fps=info["fps"],
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
features=info.features,
fps=info.fps,
total_episodes=info.total_episodes,
total_frames=info.total_frames,
tasks=tasks,
)
if hf_dataset is None:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info.fps)
def _mock_snapshot_download(
repo_id: str, # TODO(rcadene): repo_id should be used no?
@@ -97,7 +97,7 @@ def mock_snapshot_download_factory(
DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0),
]
video_keys = [key for key, feats in info["features"].items() if feats["dtype"] == "video"]
video_keys = [key for key, feats in info.features.items() if feats["dtype"] == "video"]
for key in video_keys:
all_files.append(DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0))