diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py index 0da1da964..cac009b64 100644 --- a/src/lerobot/datasets/aggregate.py +++ b/src/lerobot/datasets/aggregate.py @@ -97,8 +97,8 @@ def update_data_df(df, src_meta, dst_meta): pd.DataFrame: Updated DataFrame with adjusted indices. """ - df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"] - df["index"] = df["index"] + dst_meta.info["total_frames"] + df["episode_index"] = df["episode_index"] + dst_meta.info.total_episodes + df["index"] = df["index"] + dst_meta.info.total_frames src_task_names = src_meta.tasks.index.take(df["task_index"].to_numpy()) df["task_index"] = dst_meta.tasks.loc[src_task_names, "task_index"].to_numpy() @@ -225,9 +225,9 @@ def update_meta_data( # Clean up temporary columns df = df.drop(columns=["_orig_chunk", "_orig_file"]) - df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"] - df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"] - df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"] + df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info.total_frames + df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info.total_frames + df["episode_index"] = df["episode_index"] + dst_meta.info.total_episodes return df @@ -313,8 +313,8 @@ def aggregate_datasets( # to avoid interference between different source datasets data_idx.pop("src_to_dst", None) - dst_meta.info["total_episodes"] += src_meta.total_episodes - dst_meta.info["total_frames"] += src_meta.total_frames + dst_meta.info.total_episodes += src_meta.total_episodes + dst_meta.info.total_frames += src_meta.total_frames finalize_aggregation(dst_meta, all_metadata) logging.info("Aggregation complete.") @@ -640,14 +640,10 @@ def finalize_aggregation(aggr_meta, all_metadata): write_tasks(aggr_meta.tasks, aggr_meta.root) logging.info("write info") - aggr_meta.info.update( - { - "total_tasks": len(aggr_meta.tasks), - "total_episodes": sum(m.total_episodes for m in all_metadata), - "total_frames": sum(m.total_frames for m in all_metadata), - "splits": {"train": f"0:{sum(m.total_episodes for m in all_metadata)}"}, - } - ) + aggr_meta.info.total_tasks = len(aggr_meta.tasks) + aggr_meta.info.total_episodes = sum(m.total_episodes for m in all_metadata) + aggr_meta.info.total_frames = sum(m.total_frames for m in all_metadata) + aggr_meta.info.splits = {"train": f"0:{sum(m.total_episodes for m in all_metadata)}"} write_info(aggr_meta.info, aggr_meta.root) logging.info("write stats") diff --git a/src/lerobot/datasets/dataset_metadata.py b/src/lerobot/datasets/dataset_metadata.py index 8bf67fa39..4f89ba2a4 100644 --- a/src/lerobot/datasets/dataset_metadata.py +++ b/src/lerobot/datasets/dataset_metadata.py @@ -37,13 +37,11 @@ from .io_utils import ( load_subtasks, load_tasks, write_info, - write_json, write_stats, write_tasks, ) from .utils import ( DEFAULT_EPISODES_PATH, - INFO_PATH, check_version_compatibility, get_safe_version, has_legacy_hub_download_metadata, @@ -228,7 +226,7 @@ class LeRobotDatasetMetadata: @property def _version(self) -> packaging.version.Version: """Codebase version used to create this dataset.""" - return packaging.version.parse(self.info["codebase_version"]) + return packaging.version.parse(self.info.codebase_version) def get_data_file_path(self, ep_index: int) -> Path: """Return the relative parquet file path for the given episode index. 
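[Reviewer note: not part of the patch] The hunks above and below all apply one mechanical change: every dict-style read or write of the info mapping becomes attribute access on the new DatasetInfo dataclass. A minimal before/after sketch, assuming a loaded LeRobotDatasetMetadata instance named meta (illustrative only):

    fps = meta.info.fps                # was: meta.info["fps"]
    meta.info.total_episodes += 1      # was: meta.info["total_episodes"] += 1
    write_info(meta.info, meta.root)   # write_info() now expects a DatasetInfo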
@@ -283,27 +281,27 @@ class LeRobotDatasetMetadata: @property def data_path(self) -> str: """Formattable string for the parquet files.""" - return self.info["data_path"] + return self.info.data_path @property def video_path(self) -> str | None: """Formattable string for the video files.""" - return self.info["video_path"] + return self.info.video_path @property def robot_type(self) -> str | None: """Robot type used in recording this dataset.""" - return self.info["robot_type"] + return self.info.robot_type @property def fps(self) -> int: """Frames per second used during data collection.""" - return self.info["fps"] + return self.info.fps @property def features(self) -> dict[str, dict]: """All features contained in the dataset.""" - return self.info["features"] + return self.info.features @property def image_keys(self) -> list[str]: @@ -333,32 +331,32 @@ class LeRobotDatasetMetadata: @property def total_episodes(self) -> int: """Total number of episodes available.""" - return self.info["total_episodes"] + return self.info.total_episodes @property def total_frames(self) -> int: """Total number of frames saved in this dataset.""" - return self.info["total_frames"] + return self.info.total_frames @property def total_tasks(self) -> int: """Total number of different tasks performed in this dataset.""" - return self.info["total_tasks"] + return self.info.total_tasks @property def chunks_size(self) -> int: """Max number of files per chunk.""" - return self.info["chunks_size"] + return self.info.chunks_size @property def data_files_size_in_mb(self) -> int: """Max size of data file in mega bytes.""" - return self.info["data_files_size_in_mb"] + return self.info.data_files_size_in_mb @property def video_files_size_in_mb(self) -> int: """Max size of video file in mega bytes.""" - return self.info["video_files_size_in_mb"] + return self.info.video_files_size_in_mb def get_task_index(self, task: str) -> int | None: """ @@ -502,10 +500,10 @@ class LeRobotDatasetMetadata: self._save_episode_metadata(episode_dict) # Update info - self.info["total_episodes"] += 1 - self.info["total_frames"] += episode_length - self.info["total_tasks"] = len(self.tasks) - self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"} + self.info.total_episodes += 1 + self.info.total_frames += episode_length + self.info.total_tasks = len(self.tasks) + self.info.splits = {"train": f"0:{self.info.total_episodes}"} write_info(self.info, self.root) @@ -524,7 +522,7 @@ class LeRobotDatasetMetadata: for key in video_keys: if not self.features[key].get("info", None): video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0) - self.info["features"][key]["info"] = get_video_info(video_path) + self.info.features[key]["info"] = get_video_info(video_path) def update_chunk_settings( self, @@ -546,17 +544,17 @@ class LeRobotDatasetMetadata: if chunks_size is not None: if chunks_size <= 0: raise ValueError(f"chunks_size must be positive, got {chunks_size}") - self.info["chunks_size"] = chunks_size + self.info.chunks_size = chunks_size if data_files_size_in_mb is not None: if data_files_size_in_mb <= 0: raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}") - self.info["data_files_size_in_mb"] = data_files_size_in_mb + self.info.data_files_size_in_mb = data_files_size_in_mb if video_files_size_in_mb is not None: if video_files_size_in_mb <= 0: raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}") - self.info["video_files_size_in_mb"] 
= video_files_size_in_mb + self.info.video_files_size_in_mb = video_files_size_in_mb # Update the info file on disk write_info(self.info, self.root) @@ -653,7 +651,7 @@ class LeRobotDatasetMetadata: f"Features contain video keys {obj.video_keys}, but 'use_videos' is set to False. " "Either remove video features from the features dict, or set 'use_videos=True'." ) - write_json(obj.info, obj.root / INFO_PATH) + write_info(obj.info, obj.root) obj.revision = None obj._pq_writer = None obj.latest_episode = None diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py index cbf4e5c49..46dd9bff2 100644 --- a/src/lerobot/datasets/dataset_tools.py +++ b/src/lerobot/datasets/dataset_tools.py @@ -897,14 +897,10 @@ def _copy_and_reindex_episodes_metadata( dst_meta.finalize() - dst_meta.info.update( - { - "total_episodes": len(episode_mapping), - "total_frames": total_frames, - "total_tasks": len(dst_meta.tasks) if dst_meta.tasks is not None else 0, - "splits": {"train": f"0:{len(episode_mapping)}"}, - } - ) + dst_meta.info.total_episodes = len(episode_mapping) + dst_meta.info.total_frames = total_frames + dst_meta.info.total_tasks = len(dst_meta.tasks) if dst_meta.tasks is not None else 0 + dst_meta.info.splits = {"train": f"0:{len(episode_mapping)}"} write_info(dst_meta.info, dst_meta.root) if not all_stats: @@ -1069,21 +1065,20 @@ def _copy_episodes_metadata_and_stats( if episodes_dir.exists(): shutil.copytree(episodes_dir, dst_episodes_dir, dirs_exist_ok=True) - dst_meta.info.update( - { - "total_episodes": src_dataset.meta.total_episodes, - "total_frames": src_dataset.meta.total_frames, - "total_tasks": src_dataset.meta.total_tasks, - "splits": src_dataset.meta.info.get("splits", {"train": f"0:{src_dataset.meta.total_episodes}"}), - } + dst_meta.info.total_episodes = src_dataset.meta.total_episodes + dst_meta.info.total_frames = src_dataset.meta.total_frames + dst_meta.info.total_tasks = src_dataset.meta.total_tasks + # Preserve original splits if available, otherwise create default + dst_meta.info.splits = ( + src_dataset.meta.info.splits + if src_dataset.meta.info.splits + else {"train": f"0:{src_dataset.meta.total_episodes}"} ) if dst_meta.video_keys and src_dataset.meta.video_keys: for key in dst_meta.video_keys: if key in src_dataset.meta.features: - dst_meta.info["features"][key]["info"] = src_dataset.meta.info["features"][key].get( - "info", {} - ) + dst_meta.info.features[key]["info"] = src_dataset.meta.info.features[key].get("info", {}) write_info(dst_meta.info, dst_meta.root) @@ -1525,7 +1520,7 @@ def modify_tasks( write_tasks(new_task_df, root) # Update info.json - dataset.meta.info["total_tasks"] = len(unique_tasks) + dataset.meta.info.total_tasks = len(unique_tasks) write_info(dataset.meta.info, root) # Reload metadata to reflect changes @@ -1858,10 +1853,10 @@ def convert_image_to_video_dataset( episodes_df.to_parquet(episodes_path, index=False) # Update metadata info - new_meta.info["total_episodes"] = len(episode_indices) - new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata.values()) - new_meta.info["total_tasks"] = dataset.meta.total_tasks - new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"} + new_meta.info.total_episodes = len(episode_indices) + new_meta.info.total_frames = sum(ep["length"] for ep in all_episode_metadata.values()) + new_meta.info.total_tasks = dataset.meta.total_tasks + new_meta.info.splits = {"train": f"0:{len(episode_indices)}"} # Update video info for all image keys (now videos) # We 
need to manually set video info since update_video_info() checks video_keys first @@ -1870,7 +1865,7 @@ def convert_image_to_video_dataset( video_path = new_meta.root / new_meta.video_path.format( video_key=img_key, chunk_index=0, file_index=0 ) - new_meta.info["features"][img_key]["info"] = get_video_info(video_path) + new_meta.info.features[img_key]["info"] = get_video_info(video_path) write_info(new_meta.info, new_meta.root) diff --git a/src/lerobot/datasets/feature_utils.py b/src/lerobot/datasets/feature_utils.py index b05dbf2cc..2ab4b0ea6 100644 --- a/src/lerobot/datasets/feature_utils.py +++ b/src/lerobot/datasets/feature_utils.py @@ -28,6 +28,7 @@ from .utils import ( DEFAULT_DATA_PATH, DEFAULT_VIDEO_FILE_SIZE_IN_MB, DEFAULT_VIDEO_PATH, + DatasetInfo, ) @@ -78,8 +79,8 @@ def create_empty_dataset_info( chunks_size: int | None = None, data_files_size_in_mb: int | None = None, video_files_size_in_mb: int | None = None, -) -> dict: - """Create a template dictionary for a new dataset's `info.json`. +) -> DatasetInfo: + """Create a template ``DatasetInfo`` object for a new dataset's ``meta/info.json``. Args: codebase_version (str): The version of the LeRobot codebase. @@ -87,25 +88,24 @@ def create_empty_dataset_info( features (dict): The LeRobot features dictionary for the dataset. use_videos (bool): Whether the dataset will store videos. robot_type (str | None): The type of robot used, if any. + chunks_size (int | None): Max files per chunk directory. Defaults to ``DEFAULT_CHUNK_SIZE``. + data_files_size_in_mb (int | None): Max parquet file size in MB. Defaults to ``DEFAULT_DATA_FILE_SIZE_IN_MB``. + video_files_size_in_mb (int | None): Max video file size in MB. Defaults to ``DEFAULT_VIDEO_FILE_SIZE_IN_MB``. Returns: - dict: A dictionary with the initial dataset metadata. + DatasetInfo: A typed dataset information object with initial metadata. 
""" - return { - "codebase_version": codebase_version, - "robot_type": robot_type, - "total_episodes": 0, - "total_frames": 0, - "total_tasks": 0, - "chunks_size": chunks_size or DEFAULT_CHUNK_SIZE, - "data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB, - "video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB, - "fps": fps, - "splits": {}, - "data_path": DEFAULT_DATA_PATH, - "video_path": DEFAULT_VIDEO_PATH if use_videos else None, - "features": features, - } + return DatasetInfo( + codebase_version=codebase_version, + fps=fps, + features=features, + robot_type=robot_type, + chunks_size=chunks_size or DEFAULT_CHUNK_SIZE, + data_files_size_in_mb=data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB, + video_files_size_in_mb=video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB, + data_path=DEFAULT_DATA_PATH, + video_path=DEFAULT_VIDEO_PATH if use_videos else None, + ) def check_delta_timestamps( diff --git a/src/lerobot/datasets/io_utils.py b/src/lerobot/datasets/io_utils.py index 2ee859e97..f5681c7c0 100644 --- a/src/lerobot/datasets/io_utils.py +++ b/src/lerobot/datasets/io_utils.py @@ -39,6 +39,7 @@ from .utils import ( EPISODES_DIR, INFO_PATH, STATS_PATH, + DatasetInfo, serialize_dict, ) @@ -115,25 +116,21 @@ def embed_images(dataset: datasets.Dataset) -> datasets.Dataset: return dataset -def write_info(info: dict, local_dir: Path) -> None: - write_json(info, local_dir / INFO_PATH) +def write_info(info: DatasetInfo, local_dir: Path) -> None: + write_json(info.to_dict(), local_dir / INFO_PATH) -def load_info(local_dir: Path) -> dict: +def load_info(local_dir: Path) -> DatasetInfo: """Load dataset info metadata from its standard file path. - Also converts shape lists to tuples for consistency. - Args: local_dir (Path): The root directory of the dataset. Returns: - dict: The dataset information dictionary. + DatasetInfo: The typed dataset information object. """ - info = load_json(local_dir / INFO_PATH) - for ft in info["features"].values(): - ft["shape"] = tuple(ft["shape"]) - return info + raw = load_json(local_dir / INFO_PATH) + return DatasetInfo.from_dict(raw) def write_stats(stats: dict, local_dir: Path) -> None: diff --git a/src/lerobot/datasets/streaming_dataset.py b/src/lerobot/datasets/streaming_dataset.py index 4de2ed69c..3c1e4a73c 100644 --- a/src/lerobot/datasets/streaming_dataset.py +++ b/src/lerobot/datasets/streaming_dataset.py @@ -434,7 +434,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): def _make_padding_camera_frame(self, camera_key: str): """Variable-shape padding frame for given camera keys, given in (H, W, C)""" - return torch.zeros(self.meta.info["features"][camera_key]["shape"]).permute(-1, 0, 1) + return torch.zeros(self.meta.info.features[camera_key]["shape"]).permute(-1, 0, 1) def _get_video_frame_padding_mask( self, diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index c6815e0f5..93507ae71 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -14,9 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
 import contextlib
+import dataclasses
 import importlib.resources
 import json
 import logging
+from dataclasses import dataclass, field
 from pathlib import Path
 
 import datasets
@@ -70,6 +72,9 @@ class ForwardCompatibilityError(CompatibilityError):
         super().__init__(message)
 
 
+logger = logging.getLogger(__name__)
+
+
 DEFAULT_CHUNK_SIZE = 1000  # Max number of files per chunk
 DEFAULT_DATA_FILE_SIZE_IN_MB = 100  # Max size per file
 DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200  # Max size per file
@@ -94,6 +99,122 @@ LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
 LEGACY_TASKS_PATH = "meta/tasks.jsonl"
 
 
+@dataclass
+class DatasetInfo:
+    """Typed representation of the ``meta/info.json`` file for a LeRobot dataset.
+
+    Replaces the previously untyped ``dict`` returned by ``load_info()`` and
+    created by ``create_empty_dataset_info()``. Using a dataclass provides
+    explicit field definitions, IDE auto-completion, and validation at
+    construction time.
+    """
+
+    codebase_version: str
+    fps: int
+    features: dict[str, dict]
+
+    # Episode / frame counters — start at zero for new datasets
+    total_episodes: int = 0
+    total_frames: int = 0
+    total_tasks: int = 0
+
+    # Storage settings
+    chunks_size: int = field(default=DEFAULT_CHUNK_SIZE)
+    data_files_size_in_mb: int = field(default=DEFAULT_DATA_FILE_SIZE_IN_MB)
+    video_files_size_in_mb: int = field(default=DEFAULT_VIDEO_FILE_SIZE_IN_MB)
+
+    # File path templates
+    data_path: str = field(default=DEFAULT_DATA_PATH)
+    video_path: str | None = field(default=DEFAULT_VIDEO_PATH)
+
+    # Optional metadata
+    robot_type: str | None = None
+    splits: dict[str, str] = field(default_factory=dict)
+
+    def __post_init__(self) -> None:
+        # Coerce feature shapes from list to tuple — JSON deserialisation
+        # returns lists, but the rest of the codebase expects tuples.
+        for ft in self.features.values():
+            if isinstance(ft.get("shape"), list):
+                ft["shape"] = tuple(ft["shape"])
+
+        if self.fps <= 0:
+            raise ValueError(f"fps must be positive, got {self.fps}")
+        if self.chunks_size <= 0:
+            raise ValueError(f"chunks_size must be positive, got {self.chunks_size}")
+        if self.data_files_size_in_mb <= 0:
+            raise ValueError(f"data_files_size_in_mb must be positive, got {self.data_files_size_in_mb}")
+        if self.video_files_size_in_mb <= 0:
+            raise ValueError(f"video_files_size_in_mb must be positive, got {self.video_files_size_in_mb}")
+
+    def to_dict(self) -> dict:
+        """Return a JSON-serialisable dict.
+
+        Converts tuple shapes back to lists so the serialised output matches
+        the list-based layout stored in ``meta/info.json``.
+        """
+        d = dataclasses.asdict(self)
+        for ft in d["features"].values():
+            if isinstance(ft.get("shape"), tuple):
+                ft["shape"] = list(ft["shape"])
+        return d
+
+    @classmethod
+    def from_dict(cls, data: dict) -> "DatasetInfo":
+        """Construct from a raw dict (e.g. loaded directly from JSON).
+
+        Unknown keys are dropped with a logged warning, for forward
+        compatibility with datasets that carry additional fields
+        (e.g. ``total_videos`` from v2.x).
+        """
+        known = {f.name for f in dataclasses.fields(cls)}
+        unknown = {k for k in data if k not in known}
+        if unknown:
+            logger.warning(f"Unknown fields in DatasetInfo: {unknown}. These will be ignored.")
+        return cls(**{k: v for k, v in data.items() if k in known})
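+
+    # Illustrative round-trip (reviewer sketch; ``load_json``/``write_json`` and
+    # ``INFO_PATH`` live in the io layer, shown here only to document intent):
+    #
+    #     raw = load_json(root / INFO_PATH)       # raw dict from disk
+    #     info = DatasetInfo.from_dict(raw)       # shapes coerced to tuples
+    #     info.total_episodes += 1
+    #     write_json(info.to_dict(), root / INFO_PATH)  # tuples back to lists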
+ # --------------------------------------------------------------------------- + def __getitem__(self, key: str): + import warnings + + warnings.warn( + f"Accessing DatasetInfo with dict-style syntax info['{key}'] is deprecated. " + f"Use attribute access info.{key} instead.", + DeprecationWarning, + stacklevel=2, + ) + try: + return getattr(self, key) + except AttributeError as err: + raise KeyError(key) from err + + def __setitem__(self, key: str, value) -> None: + import warnings + + warnings.warn( + f"Setting DatasetInfo with dict-style syntax info['{key}'] = ... is deprecated. " + f"Use attribute assignment info.{key} = ... instead.", + DeprecationWarning, + stacklevel=2, + ) + if not hasattr(self, key): + raise KeyError(f"DatasetInfo has no field '{key}'") + setattr(self, key, value) + + def __contains__(self, key: str) -> bool: + """Check if a field exists (dict-like interface).""" + return hasattr(self, key) + + def get(self, key: str, default=None): + """Get attribute value with default fallback (dict-like interface).""" + try: + return getattr(self, key) + except AttributeError: + return default + + def has_legacy_hub_download_metadata(root: Path) -> bool: """Return ``True`` when *root* looks like a legacy Hub ``local_dir`` mirror. diff --git a/src/lerobot/rl/crop_dataset_roi.py b/src/lerobot/rl/crop_dataset_roi.py index b6bde2273..44614f75f 100644 --- a/src/lerobot/rl/crop_dataset_roi.py +++ b/src/lerobot/rl/crop_dataset_roi.py @@ -193,15 +193,15 @@ def convert_lerobot_dataset_to_cropped_lerobot_dataset( fps=int(original_dataset.fps), root=new_dataset_root, robot_type=original_dataset.meta.robot_type, - features=original_dataset.meta.info["features"], + features=original_dataset.meta.info.features, use_videos=len(original_dataset.meta.video_keys) > 0, ) # Update the metadata for every image key that will be cropped: # (Here we simply set the shape to be the final resize_size.) 
diff --git a/src/lerobot/rl/crop_dataset_roi.py b/src/lerobot/rl/crop_dataset_roi.py
index b6bde2273..44614f75f 100644
--- a/src/lerobot/rl/crop_dataset_roi.py
+++ b/src/lerobot/rl/crop_dataset_roi.py
@@ -193,15 +193,15 @@ def convert_lerobot_dataset_to_cropped_lerobot_dataset(
         fps=int(original_dataset.fps),
         root=new_dataset_root,
         robot_type=original_dataset.meta.robot_type,
-        features=original_dataset.meta.info["features"],
+        features=original_dataset.meta.info.features,
         use_videos=len(original_dataset.meta.video_keys) > 0,
     )
 
     # Update the metadata for every image key that will be cropped:
     # (Here we simply set the shape to be the final resize_size.)
     for key in crop_params_dict:
-        if key in new_dataset.meta.info["features"]:
-            new_dataset.meta.info["features"][key]["shape"] = [3] + list(resize_size)
+        if key in new_dataset.meta.info.features:
+            new_dataset.meta.info.features[key]["shape"] = [3] + list(resize_size)
 
     # TODO: Directly modify the mp4 video + meta info features, instead of recreating a dataset
     prev_episode_index = 0
diff --git a/src/lerobot/scripts/convert_dataset_v21_to_v30.py b/src/lerobot/scripts/convert_dataset_v21_to_v30.py
index 1f220dc20..26f341f55 100644
--- a/src/lerobot/scripts/convert_dataset_v21_to_v30.py
+++ b/src/lerobot/scripts/convert_dataset_v21_to_v30.py
@@ -70,6 +70,7 @@ from lerobot.datasets.io_utils import (
     get_parquet_file_size_in_mb,
     get_parquet_num_frames,
     load_info,
+    load_json,
     write_episodes,
     write_info,
     write_stats,
@@ -81,9 +82,11 @@ from lerobot.datasets.utils import (
     DEFAULT_DATA_PATH,
     DEFAULT_VIDEO_FILE_SIZE_IN_MB,
     DEFAULT_VIDEO_PATH,
+    INFO_PATH,
     LEGACY_EPISODES_PATH,
     LEGACY_EPISODES_STATS_PATH,
     LEGACY_TASKS_PATH,
+    DatasetInfo,
     update_chunk_file_indices,
 )
 from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
@@ -256,14 +259,14 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
 
 
 def get_video_keys(root):
     info = load_info(root)
-    features = info["features"]
+    features = info.features
     video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
     return video_keys
 
 
 def get_image_keys(root):
     info = load_info(root)
-    features = info["features"]
+    features = info.features
     image_keys = [key for key, ft in features.items() if ft["dtype"] == "image"]
     return image_keys
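[Reviewer note: not part of the patch] The convert_info() hunk below deliberately bypasses load_info(), which now returns a DatasetInfo, and reads the raw JSON instead: the legacy v2.1 file still carries fields the dataclass does not model, and they must be stripped before construction. A sketch of the intended flow (pop() used here defensively; the patch uses del):

    raw = load_json(root / INFO_PATH)  # legacy dict, still has v2.1-only keys
    raw.pop("total_chunks", None)      # fields DatasetInfo does not model
    raw.pop("total_videos", None)
    info = DatasetInfo.from_dict(raw)  # constructs without unknown-field warnings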
@@ -434,7 +437,8 @@ def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_
 
 
 def convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb):
-    info = load_info(root)
+    # Load as raw dict to remove legacy v2.1 fields before constructing DatasetInfo.
+    info = load_json(root / INFO_PATH)
     info["codebase_version"] = V30
     del info["total_chunks"]
     del info["total_videos"]
@@ -449,7 +453,9 @@ def convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb):
             # already has fps in video_info
             continue
         info["features"][key]["fps"] = info["fps"]
-    write_info(info, new_root)
+    # Convert raw dict to typed DatasetInfo before writing
+    dataset_info = DatasetInfo.from_dict(info)
+    write_info(dataset_info, new_root)
 
 
 def convert_dataset(
diff --git a/tests/datasets/test_aggregate.py b/tests/datasets/test_aggregate.py
index b74299311..6d646d4f7 100644
--- a/tests/datasets/test_aggregate.py
+++ b/tests/datasets/test_aggregate.py
@@ -113,7 +113,7 @@ def assert_metadata_consistency(aggr_ds, ds_0, ds_1):
     """Test that metadata is correctly aggregated."""
     # Test basic info
     assert aggr_ds.fps == ds_0.fps == ds_1.fps, "FPS should be the same across all datasets"
-    assert aggr_ds.meta.info["robot_type"] == ds_0.meta.info["robot_type"] == ds_1.meta.info["robot_type"], (
+    assert aggr_ds.meta.info.robot_type == ds_0.meta.info.robot_type == ds_1.meta.info.robot_type, (
         "Robot type should be the same"
     )
@@ -153,8 +153,8 @@ def assert_video_frames_integrity(aggr_ds, ds_0, ds_1):
 
     video_keys = list(
         filter(
-            lambda key: aggr_ds.meta.info["features"][key]["dtype"] == "video",
-            aggr_ds.meta.info["features"].keys(),
+            lambda key: aggr_ds.meta.info.features[key]["dtype"] == "video",
+            aggr_ds.meta.info.features.keys(),
         )
     )
diff --git a/tests/datasets/test_dataset_metadata.py b/tests/datasets/test_dataset_metadata.py
index 6db41d05c..6c784c90b 100644
--- a/tests/datasets/test_dataset_metadata.py
+++ b/tests/datasets/test_dataset_metadata.py
@@ -161,7 +161,7 @@ def test_init_loads_existing_metadata(tmp_path, lerobot_dataset_metadata_factory
 
     assert meta.total_episodes == 3
     assert meta.total_frames == 150
-    assert meta.fps == info["fps"]
+    assert meta.fps == info.fps
 
 
 # ── Property accessors ───────────────────────────────────────────────
diff --git a/tests/datasets/test_lerobot_dataset.py b/tests/datasets/test_lerobot_dataset.py
index 26406dea2..f3bda037f 100644
--- a/tests/datasets/test_lerobot_dataset.py
+++ b/tests/datasets/test_lerobot_dataset.py
@@ -80,18 +80,18 @@ def _write_dataset_tree(
     )
     tasks = tasks_factory(total_tasks=1)
     episodes = episodes_factory(
-        features=info["features"],
-        fps=info["fps"],
+        features=info.features,
+        fps=info.fps,
         total_episodes=1,
         total_frames=3,
         tasks=tasks,
     )
-    stats = stats_factory(features=info["features"])
+    stats = stats_factory(features=info.features)
     hf_dataset = hf_dataset_factory(
-        features=info["features"],
+        features=info.features,
         tasks=tasks,
         episodes=episodes,
-        fps=info["fps"],
+        fps=info.fps,
     )
 
     create_info(root, info)
diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py
index e068484b0..4212453b0 100644
--- a/tests/fixtures/dataset_factories.py
+++ b/tests/fixtures/dataset_factories.py
@@ -36,6 +36,7 @@ from lerobot.datasets.utils import (
     DEFAULT_VIDEO_PATH,
+    DatasetInfo,
 )
from lerobot.datasets.video_utils import encode_video_frames from lerobot.utils.constants import DEFAULT_FEATURES @@ -157,7 +158,6 @@ def info_factory(features_factory): total_episodes: int = 0, total_frames: int = 0, total_tasks: int = 0, - total_videos: int = 0, chunks_size: int = DEFAULT_CHUNK_SIZE, data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB, video_files_size_in_mb: float = DEFAULT_VIDEO_FILE_SIZE_IN_MB, @@ -166,24 +166,23 @@ def info_factory(features_factory): motor_features: dict = DUMMY_MOTOR_FEATURES, camera_features: dict = DUMMY_CAMERA_FEATURES, use_videos: bool = True, - ) -> dict: + ) -> DatasetInfo: features = features_factory(motor_features, camera_features, use_videos) - return { - "codebase_version": codebase_version, - "robot_type": robot_type, - "total_episodes": total_episodes, - "total_frames": total_frames, - "total_tasks": total_tasks, - "total_videos": total_videos, - "chunks_size": chunks_size, - "data_files_size_in_mb": data_files_size_in_mb, - "video_files_size_in_mb": video_files_size_in_mb, - "fps": fps, - "splits": {}, - "data_path": data_path, - "video_path": video_path if use_videos else None, - "features": features, - } + return DatasetInfo( + codebase_version=codebase_version, + robot_type=robot_type, + total_episodes=total_episodes, + total_frames=total_frames, + total_tasks=total_tasks, + chunks_size=chunks_size, + data_files_size_in_mb=data_files_size_in_mb, + video_files_size_in_mb=video_files_size_in_mb, + fps=fps, + splits={}, + data_path=data_path, + video_path=video_path if use_videos else None, + features=features, + ) return _create_info @@ -333,12 +332,12 @@ def create_videos(info_factory, img_array_factory): total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks ) - video_feats = {key: feats for key, feats in info["features"].items() if feats["dtype"] == "video"} + video_feats = {key: feats for key, feats in info.features.items() if feats["dtype"] == "video"} for key, ft in video_feats.items(): # create and save images with identifiable content tmp_dir = root / "tmp_images" tmp_dir.mkdir(parents=True, exist_ok=True) - for frame_index in range(info["total_frames"]): + for frame_index in range(info.total_frames): content = f"{key}-{frame_index}" img = img_array_factory(height=ft["shape"][0], width=ft["shape"][1], content=content) pil_img = PIL.Image.fromarray(img) @@ -348,7 +347,7 @@ def create_videos(info_factory, img_array_factory): video_path = root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0) video_path.parent.mkdir(parents=True, exist_ok=True) # Use the global fps from info, not video-specific fps which might not exist - encode_video_frames(tmp_dir, video_path, fps=info["fps"]) + encode_video_frames(tmp_dir, video_path, fps=info.fps) shutil.rmtree(tmp_dir) return _create_video_directory @@ -433,16 +432,16 @@ def lerobot_dataset_metadata_factory( if info is None: info = info_factory() if stats is None: - stats = stats_factory(features=info["features"]) + stats = stats_factory(features=info.features) if tasks is None: - tasks = tasks_factory(total_tasks=info["total_tasks"]) + tasks = tasks_factory(total_tasks=info.total_tasks) if episodes is None: - video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"] + video_keys = [key for key, ft in info.features.items() if ft["dtype"] == "video"] episodes = episodes_factory( - features=info["features"], - fps=info["fps"], - total_episodes=info["total_episodes"], - total_frames=info["total_frames"], + 
features=info.features, + fps=info.fps, + total_episodes=info.total_episodes, + total_frames=info.total_frames, video_keys=video_keys, tasks=tasks, ) @@ -503,23 +502,23 @@ def lerobot_dataset_factory( chunks_size=chunks_size, ) if stats is None: - stats = stats_factory(features=info["features"]) + stats = stats_factory(features=info.features) if tasks is None: - tasks = tasks_factory(total_tasks=info["total_tasks"]) + tasks = tasks_factory(total_tasks=info.total_tasks) if episodes_metadata is None: - video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"] + video_keys = [key for key, ft in info.features.items() if ft["dtype"] == "video"] episodes_metadata = episodes_factory( - features=info["features"], - fps=info["fps"], - total_episodes=info["total_episodes"], - total_frames=info["total_frames"], + features=info.features, + fps=info.fps, + total_episodes=info.total_episodes, + total_frames=info.total_frames, video_keys=video_keys, tasks=tasks, multi_task=multi_task, ) if hf_dataset is None: hf_dataset = hf_dataset_factory( - features=info["features"], tasks=tasks, episodes=episodes_metadata, fps=info["fps"] + features=info.features, tasks=tasks, episodes=episodes_metadata, fps=info.fps ) # Write data on disk diff --git a/tests/fixtures/hub.py b/tests/fixtures/hub.py index 4333b91a3..2f521c766 100644 --- a/tests/fixtures/hub.py +++ b/tests/fixtures/hub.py @@ -62,19 +62,19 @@ def mock_snapshot_download_factory( if info is None: info = info_factory(data_files_size_in_mb=data_files_size_in_mb, chunks_size=chunks_size) if stats is None: - stats = stats_factory(features=info["features"]) + stats = stats_factory(features=info.features) if tasks is None: - tasks = tasks_factory(total_tasks=info["total_tasks"]) + tasks = tasks_factory(total_tasks=info.total_tasks) if episodes is None: episodes = episodes_factory( - features=info["features"], - fps=info["fps"], - total_episodes=info["total_episodes"], - total_frames=info["total_frames"], + features=info.features, + fps=info.fps, + total_episodes=info.total_episodes, + total_frames=info.total_frames, tasks=tasks, ) if hf_dataset is None: - hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"]) + hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info.fps) def _mock_snapshot_download( repo_id: str, # TODO(rcadene): repo_id should be used no? @@ -97,7 +97,7 @@ def mock_snapshot_download_factory( DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0), ] - video_keys = [key for key, feats in info["features"].items() if feats["dtype"] == "video"] + video_keys = [key for key, feats in info.features.items() if feats["dtype"] == "video"] for key in video_keys: all_files.append(DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0))
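[Reviewer note: not part of the patch] One end-to-end behaviour worth a regression test: DatasetInfo.from_dict() drops unknown v2.x fields with a logged warning, and to_dict() converts shape tuples back to lists. A sketch assuming the patch as written, with illustrative field values:

    from lerobot.datasets.utils import DatasetInfo

    legacy = {
        "codebase_version": "v2.1",
        "fps": 30,
        "features": {"state": {"dtype": "float32", "shape": [6]}},
        "total_videos": 12,  # v2.x-only field, unknown to DatasetInfo
    }
    info = DatasetInfo.from_dict(legacy)  # logs: Unknown fields ... {'total_videos'}
    assert info.features["state"]["shape"] == (6,)  # list coerced to tuple on load
    round_tripped = info.to_dict()
    assert "total_videos" not in round_tripped  # unknown field dropped
    assert round_tripped["features"]["state"]["shape"] == [6]  # tuple back to list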