mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 05:29:55 +00:00
723 lines
28 KiB
Python
723 lines
28 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import contextlib
|
|
from collections.abc import Callable
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import packaging.version
|
|
import pandas as pd
|
|
import pyarrow as pa
|
|
import pyarrow.parquet as pq
|
|
from huggingface_hub import snapshot_download
|
|
|
|
from lerobot.configs import VideoEncoderConfig
|
|
from lerobot.utils.constants import DEFAULT_FEATURES, HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE
|
|
from lerobot.utils.feature_utils import _validate_feature_names
|
|
from lerobot.utils.utils import flatten_dict
|
|
|
|
from .compute_stats import aggregate_stats
|
|
from .feature_utils import create_empty_dataset_info
|
|
from .io_utils import (
|
|
get_file_size_in_mb,
|
|
load_episodes,
|
|
load_info,
|
|
load_stats,
|
|
load_subtasks,
|
|
load_tasks,
|
|
write_info,
|
|
write_stats,
|
|
write_tasks,
|
|
)
|
|
from .utils import (
|
|
DEFAULT_EPISODES_PATH,
|
|
check_version_compatibility,
|
|
get_safe_version,
|
|
has_legacy_hub_download_metadata,
|
|
is_valid_version,
|
|
update_chunk_file_indices,
|
|
)
|
|
from .video_utils import get_video_info
|
|
|
|
# Dataset codebase version targeted by this module. It is the default Hub
# revision when none is requested, the reference that loaded metadata is
# checked against via ``check_version_compatibility``, and the version passed
# to ``create_empty_dataset_info`` when a new dataset is created.
CODEBASE_VERSION = "v3.0"
|
|
|
|
|
|
class LeRobotDatasetMetadata:
|
|
"""Metadata container for a LeRobot dataset.
|
|
|
|
Manages the ``info.json``, ``stats.json``, ``tasks.parquet``, and
|
|
``episodes/`` parquet files that describe a dataset's structure, content,
|
|
and statistics.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
repo_id: str,
|
|
root: str | Path | None = None,
|
|
revision: str | None = None,
|
|
force_cache_sync: bool = False,
|
|
metadata_buffer_size: int = 10,
|
|
):
|
|
"""Load or download metadata for an existing LeRobot dataset.
|
|
|
|
Attempts to load metadata from local disk. If files are missing or
|
|
``force_cache_sync`` is ``True``, downloads the ``meta/`` directory from
|
|
the Hub.
|
|
|
|
Args:
|
|
repo_id: Repository identifier (e.g. ``'lerobot/aloha_sim'``).
|
|
root: Local directory for the dataset. When provided, Hub downloads
|
|
are materialized directly into this directory. When omitted,
|
|
existing local datasets are still looked up under
|
|
``$HF_LEROBOT_HOME/{repo_id}``, but Hub downloads use a
|
|
revision-safe snapshot cache under
|
|
``$HF_LEROBOT_HOME/hub``.
|
|
revision: Git revision (branch, tag, or commit hash). Defaults to
|
|
the current codebase version.
|
|
force_cache_sync: If ``True``, re-download metadata from the Hub
|
|
even when local files exist.
|
|
metadata_buffer_size: Number of episode metadata records to buffer
|
|
in memory before flushing to parquet.
|
|
"""
|
|
self.repo_id = repo_id
|
|
self.revision = revision if revision else CODEBASE_VERSION
|
|
self._requested_root = Path(root) if root is not None else None
|
|
self.root = self._requested_root if self._requested_root is not None else HF_LEROBOT_HOME / repo_id
|
|
self._pq_writer = None
|
|
self.latest_episode = None
|
|
self._metadata_buffer: list[dict] = []
|
|
self._metadata_buffer_size = metadata_buffer_size
|
|
self._finalized = False
|
|
|
|
try:
|
|
if force_cache_sync or (
|
|
self._requested_root is None and has_legacy_hub_download_metadata(self.root)
|
|
):
|
|
raise FileNotFoundError
|
|
self._load_metadata()
|
|
except (FileNotFoundError, NotADirectoryError):
|
|
if is_valid_version(self.revision):
|
|
self.revision = get_safe_version(self.repo_id, self.revision)
|
|
|
|
self._pull_from_repo(allow_patterns="meta/")
|
|
self._load_metadata()
|
|
|
|
def _flush_metadata_buffer(self) -> None:
|
|
"""Write all buffered episode metadata to parquet file."""
|
|
if not hasattr(self, "_metadata_buffer") or len(self._metadata_buffer) == 0:
|
|
return
|
|
|
|
combined_dict = {}
|
|
for episode_dict in self._metadata_buffer:
|
|
for key, value in episode_dict.items():
|
|
if key not in combined_dict:
|
|
combined_dict[key] = []
|
|
# Extract value and serialize numpy arrays
|
|
# because PyArrow's from_pydict function doesn't support numpy arrays
|
|
val = value[0] if isinstance(value, list) else value
|
|
combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val)
|
|
|
|
first_ep = self._metadata_buffer[0]
|
|
chunk_idx = first_ep["meta/episodes/chunk_index"][0]
|
|
file_idx = first_ep["meta/episodes/file_index"][0]
|
|
|
|
table = pa.Table.from_pydict(combined_dict)
|
|
|
|
if not self._pq_writer:
|
|
path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx))
|
|
path.parent.mkdir(parents=True, exist_ok=True)
|
|
self._pq_writer = pq.ParquetWriter(
|
|
path, schema=table.schema, compression="snappy", use_dictionary=True
|
|
)
|
|
|
|
self._pq_writer.write_table(table)
|
|
|
|
self.latest_episode = self._metadata_buffer[-1]
|
|
self._metadata_buffer.clear()
|
|
|
|
def _close_writer(self) -> None:
|
|
"""Close and cleanup the parquet writer if it exists."""
|
|
self._flush_metadata_buffer()
|
|
|
|
writer = getattr(self, "_pq_writer", None)
|
|
if writer is not None:
|
|
writer.close()
|
|
self._pq_writer = None
|
|
|
|
def finalize(self) -> None:
|
|
"""Flush metadata buffer and close the parquet writer.
|
|
|
|
Idempotent — safe to call multiple times.
|
|
"""
|
|
if getattr(self, "_finalized", False):
|
|
return
|
|
self._close_writer()
|
|
self._finalized = True
|
|
|
|
def __del__(self):
|
|
"""Safety net: flush and close parquet writer on garbage collection."""
|
|
# During interpreter shutdown, referenced objects may already be collected.
|
|
with contextlib.suppress(Exception):
|
|
self.finalize()
|
|
|
|
def _load_metadata(self):
|
|
self.info = load_info(self.root)
|
|
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
|
self.tasks = load_tasks(self.root)
|
|
self.subtasks = load_subtasks(self.root)
|
|
self.episodes = load_episodes(self.root)
|
|
self.stats = load_stats(self.root)
|
|
|
|
def ensure_readable(self) -> None:
|
|
"""Guarantee metadata is fully loaded for read operations.
|
|
|
|
Idempotent — when metadata is already in memory this is a single
|
|
``is None`` check. Call this before transitioning from write to
|
|
read mode on the same instance.
|
|
"""
|
|
if self.episodes is None:
|
|
self._load_metadata()
|
|
|
|
def filter_episodes(
|
|
self,
|
|
predicate: Callable[[dict], bool],
|
|
candidates: list[int] | None = None,
|
|
) -> list[int]:
|
|
"""Filter episodes whose metadata satisfies a given predicate.
|
|
|
|
Args:
|
|
predicate: Predicate over per-episode metadata rows used to select episodes.
|
|
candidates: Optional list of episode indices to restrict evaluation to.
|
|
|
|
Returns:
|
|
List of sorted episode indices that satisfy the predicate.
|
|
"""
|
|
self.ensure_readable()
|
|
if candidates is not None:
|
|
candidate_set = set(candidates)
|
|
combined = lambda ep: ep["episode_index"] in candidate_set and predicate(ep) # noqa: E731
|
|
else:
|
|
combined = predicate
|
|
filtered = self.episodes.filter(combined, keep_in_memory=True, load_from_cache_file=False)
|
|
return sorted(int(idx) for idx in filtered["episode_index"])
|
|
|
|
def _pull_from_repo(
|
|
self,
|
|
allow_patterns: list[str] | str | None = None,
|
|
ignore_patterns: list[str] | str | None = None,
|
|
) -> None:
|
|
if self._requested_root is None:
|
|
self.root = Path(
|
|
snapshot_download(
|
|
self.repo_id,
|
|
repo_type="dataset",
|
|
revision=self.revision,
|
|
cache_dir=HF_LEROBOT_HUB_CACHE,
|
|
allow_patterns=allow_patterns,
|
|
ignore_patterns=ignore_patterns,
|
|
)
|
|
)
|
|
return
|
|
|
|
self._requested_root.mkdir(exist_ok=True, parents=True)
|
|
snapshot_download(
|
|
self.repo_id,
|
|
repo_type="dataset",
|
|
revision=self.revision,
|
|
local_dir=self._requested_root,
|
|
allow_patterns=allow_patterns,
|
|
ignore_patterns=ignore_patterns,
|
|
)
|
|
self.root = self._requested_root
|
|
|
|
@property
|
|
def url_root(self) -> str:
|
|
"""Hugging Face Hub URL root for this dataset."""
|
|
return f"hf://datasets/{self.repo_id}"
|
|
|
|
@property
|
|
def _version(self) -> packaging.version.Version:
|
|
"""Codebase version used to create this dataset."""
|
|
return packaging.version.parse(self.info.codebase_version)
|
|
|
|
def get_data_file_path(self, ep_index: int) -> Path:
|
|
"""Return the relative parquet file path for the given episode index.
|
|
|
|
Args:
|
|
ep_index: Zero-based episode index.
|
|
|
|
Returns:
|
|
Path to the parquet file containing this episode's data.
|
|
|
|
Raises:
|
|
IndexError: If ``ep_index`` is out of range.
|
|
"""
|
|
if self.episodes is None:
|
|
self.episodes = load_episodes(self.root)
|
|
if ep_index >= len(self.episodes):
|
|
raise IndexError(
|
|
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
|
|
)
|
|
ep = self.episodes[ep_index]
|
|
chunk_idx = ep["data/chunk_index"]
|
|
file_idx = ep["data/file_index"]
|
|
fpath = self.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
|
return Path(fpath)
|
|
|
|
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
|
|
"""Return the relative video file path for the given episode and video key.
|
|
|
|
Args:
|
|
ep_index: Zero-based episode index.
|
|
vid_key: Feature key identifying the video stream
|
|
(e.g. ``'observation.images.laptop'``).
|
|
|
|
Returns:
|
|
Path to the video file containing this episode's frames.
|
|
|
|
Raises:
|
|
IndexError: If ``ep_index`` is out of range.
|
|
"""
|
|
if self.episodes is None:
|
|
self.episodes = load_episodes(self.root)
|
|
if ep_index >= len(self.episodes):
|
|
raise IndexError(
|
|
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
|
|
)
|
|
ep = self.episodes[ep_index]
|
|
chunk_idx = ep[f"videos/{vid_key}/chunk_index"]
|
|
file_idx = ep[f"videos/{vid_key}/file_index"]
|
|
fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
|
|
return Path(fpath)
|
|
|
|
@property
|
|
def data_path(self) -> str:
|
|
"""Formattable string for the parquet files."""
|
|
return self.info.data_path
|
|
|
|
@property
|
|
def video_path(self) -> str | None:
|
|
"""Formattable string for the video files."""
|
|
return self.info.video_path
|
|
|
|
@property
|
|
def robot_type(self) -> str | None:
|
|
"""Robot type used in recording this dataset."""
|
|
return self.info.robot_type
|
|
|
|
@property
|
|
def fps(self) -> int:
|
|
"""Frames per second used during data collection."""
|
|
return self.info.fps
|
|
|
|
@property
|
|
def features(self) -> dict[str, dict]:
|
|
"""All features contained in the dataset."""
|
|
return self.info.features
|
|
|
|
@property
|
|
def image_keys(self) -> list[str]:
|
|
"""Keys to access visual modalities stored as images."""
|
|
return [key for key, ft in self.features.items() if ft["dtype"] == "image"]
|
|
|
|
@property
|
|
def video_keys(self) -> list[str]:
|
|
"""Keys to access visual modalities stored as videos."""
|
|
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
|
|
|
|
@property
|
|
def depth_keys(self) -> list[str]:
|
|
"""Keys to access depth-map modalities stored as videos or images.
|
|
|
|
A depth key is a feature whose ``info`` dict carries ``"is_depth_map": True``
|
|
(or the legacy ``"video.is_depth_map"`` inside ``info`` or ``video_info``).
|
|
"""
|
|
|
|
def _is_depth(ft: dict) -> bool:
|
|
info = ft.get("info") or {}
|
|
video_info = ft.get("video_info") or {}
|
|
return (
|
|
info.get("is_depth_map", False)
|
|
or info.get("video.is_depth_map", False)
|
|
or video_info.get("video.is_depth_map", False)
|
|
)
|
|
|
|
return [key for key, ft in self.features.items() if _is_depth(ft)]
|
|
|
|
@property
|
|
def camera_keys(self) -> list[str]:
|
|
"""Keys to access visual modalities (regardless of their storage method)."""
|
|
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
|
|
|
@property
|
|
def names(self) -> dict[str, list | dict]:
|
|
"""Names of the various dimensions of vector modalities."""
|
|
return {key: ft["names"] for key, ft in self.features.items()}
|
|
|
|
@property
|
|
def shapes(self) -> dict:
|
|
"""Shapes for the different features."""
|
|
return {key: tuple(ft["shape"]) for key, ft in self.features.items()}
|
|
|
|
@property
|
|
def total_episodes(self) -> int:
|
|
"""Total number of episodes available."""
|
|
return self.info.total_episodes
|
|
|
|
@property
|
|
def total_frames(self) -> int:
|
|
"""Total number of frames saved in this dataset."""
|
|
return self.info.total_frames
|
|
|
|
@property
|
|
def total_tasks(self) -> int:
|
|
"""Total number of different tasks performed in this dataset."""
|
|
return self.info.total_tasks
|
|
|
|
@property
|
|
def chunks_size(self) -> int:
|
|
"""Max number of files per chunk."""
|
|
return self.info.chunks_size
|
|
|
|
@property
|
|
def data_files_size_in_mb(self) -> int:
|
|
"""Max size of data file in mega bytes."""
|
|
return self.info.data_files_size_in_mb
|
|
|
|
@property
|
|
def video_files_size_in_mb(self) -> int:
|
|
"""Max size of video file in mega bytes."""
|
|
return self.info.video_files_size_in_mb
|
|
|
|
def get_task_index(self, task: str) -> int | None:
|
|
"""
|
|
Given a task in natural language, returns its task_index if the task already exists in the dataset,
|
|
otherwise return None.
|
|
"""
|
|
if task in self.tasks.index:
|
|
return int(self.tasks.loc[task].task_index)
|
|
else:
|
|
return None
|
|
|
|
def save_episode_tasks(self, tasks: list[str]):
|
|
"""Register tasks for the current episode and persist to disk.
|
|
|
|
New tasks that do not already exist in the dataset are assigned
|
|
sequential task indices and appended to the tasks parquet file.
|
|
|
|
Args:
|
|
tasks: List of unique task descriptions in natural language.
|
|
|
|
Raises:
|
|
ValueError: If ``tasks`` contains duplicates.
|
|
"""
|
|
if len(set(tasks)) != len(tasks):
|
|
raise ValueError(f"Tasks are not unique: {tasks}")
|
|
|
|
if self.tasks is None:
|
|
new_tasks = tasks
|
|
task_indices = range(len(tasks))
|
|
self.tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(tasks, name="task"))
|
|
else:
|
|
new_tasks = [task for task in tasks if task not in self.tasks.index]
|
|
new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks))
|
|
for task_idx, task in zip(new_task_indices, new_tasks, strict=False):
|
|
self.tasks.loc[task] = task_idx
|
|
|
|
if len(new_tasks) > 0:
|
|
# Update on disk
|
|
write_tasks(self.tasks, self.root)
|
|
|
|
def _save_episode_metadata(self, episode_dict: dict) -> None:
|
|
"""Buffer episode metadata and write to parquet in batches for efficiency.
|
|
|
|
This function accumulates episode metadata in a buffer and flushes it when the buffer
|
|
reaches the configured size. This reduces I/O overhead by writing multiple episodes
|
|
at once instead of one row at a time.
|
|
|
|
Notes: We both need to update parquet files and HF dataset:
|
|
- `pandas` loads parquet file in RAM
|
|
- `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk,
|
|
or loads directly from pyarrow cache.
|
|
"""
|
|
# Convert to list format for each value
|
|
episode_dict = {key: [value] for key, value in episode_dict.items()}
|
|
num_frames = episode_dict["length"][0]
|
|
|
|
if self.latest_episode is None:
|
|
# Initialize indices and frame count for a new dataset made of the first episode data
|
|
chunk_idx, file_idx = 0, 0
|
|
if self.episodes is not None and len(self.episodes) > 0:
|
|
# It means we are resuming recording, so we need to load the latest episode
|
|
# Update the indices to avoid overwriting the latest episode
|
|
chunk_idx = self.episodes[-1]["meta/episodes/chunk_index"]
|
|
file_idx = self.episodes[-1]["meta/episodes/file_index"]
|
|
latest_num_frames = self.episodes[-1]["dataset_to_index"]
|
|
episode_dict["dataset_from_index"] = [latest_num_frames]
|
|
episode_dict["dataset_to_index"] = [latest_num_frames + num_frames]
|
|
|
|
# When resuming, move to the next file
|
|
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
|
|
else:
|
|
episode_dict["dataset_from_index"] = [0]
|
|
episode_dict["dataset_to_index"] = [num_frames]
|
|
|
|
episode_dict["meta/episodes/chunk_index"] = [chunk_idx]
|
|
episode_dict["meta/episodes/file_index"] = [file_idx]
|
|
else:
|
|
chunk_idx = self.latest_episode["meta/episodes/chunk_index"][0]
|
|
file_idx = self.latest_episode["meta/episodes/file_index"][0]
|
|
|
|
latest_path = (
|
|
self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
|
if self._pq_writer is None
|
|
else self._pq_writer.where
|
|
)
|
|
|
|
if Path(latest_path).exists():
|
|
latest_size_in_mb = get_file_size_in_mb(Path(latest_path))
|
|
latest_num_frames = self.latest_episode["episode_index"][0]
|
|
|
|
av_size_per_frame = latest_size_in_mb / latest_num_frames if latest_num_frames > 0 else 0.0
|
|
|
|
if latest_size_in_mb + av_size_per_frame * num_frames >= self.data_files_size_in_mb:
|
|
# Size limit is reached, flush buffer and prepare new parquet file
|
|
self._flush_metadata_buffer()
|
|
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
|
|
self._close_writer()
|
|
|
|
# Update the existing pandas dataframe with new row
|
|
episode_dict["meta/episodes/chunk_index"] = [chunk_idx]
|
|
episode_dict["meta/episodes/file_index"] = [file_idx]
|
|
episode_dict["dataset_from_index"] = [self.latest_episode["dataset_to_index"][0]]
|
|
episode_dict["dataset_to_index"] = [self.latest_episode["dataset_to_index"][0] + num_frames]
|
|
|
|
# Add to buffer
|
|
self._metadata_buffer.append(episode_dict)
|
|
self.latest_episode = episode_dict
|
|
|
|
if len(self._metadata_buffer) >= self._metadata_buffer_size:
|
|
self._flush_metadata_buffer()
|
|
|
|
def save_episode(
|
|
self,
|
|
episode_index: int,
|
|
episode_length: int,
|
|
episode_tasks: list[str],
|
|
episode_stats: dict[str, dict],
|
|
episode_metadata: dict,
|
|
) -> None:
|
|
"""Persist episode metadata, update dataset info, and aggregate stats.
|
|
|
|
Writes the episode's metadata to the buffered parquet writer, increments
|
|
the total episode/frame counters in ``info.json``, and merges the
|
|
episode's statistics into the running dataset statistics.
|
|
|
|
Args:
|
|
episode_index: Zero-based index of the episode being saved.
|
|
episode_length: Number of frames in this episode.
|
|
episode_tasks: List of task descriptions for this episode.
|
|
episode_stats: Per-feature statistics for this episode.
|
|
episode_metadata: Additional metadata (chunk/file indices, frame
|
|
ranges, video timestamps, etc.).
|
|
"""
|
|
episode_dict = {
|
|
"episode_index": episode_index,
|
|
"tasks": episode_tasks,
|
|
"length": episode_length,
|
|
}
|
|
episode_dict.update(episode_metadata)
|
|
episode_dict.update(flatten_dict({"stats": episode_stats}))
|
|
self._save_episode_metadata(episode_dict)
|
|
|
|
# Update info
|
|
self.info.total_episodes += 1
|
|
self.info.total_frames += episode_length
|
|
self.info.total_tasks = len(self.tasks)
|
|
self.info.splits = {"train": f"0:{self.info.total_episodes}"}
|
|
|
|
write_info(self.info, self.root)
|
|
|
|
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats
|
|
write_stats(self.stats, self.root)
|
|
|
|
def update_video_info(
|
|
self,
|
|
video_key: str | None = None,
|
|
video_encoder: VideoEncoderConfig | None = None,
|
|
) -> None:
|
|
"""Populate per-feature video info in ``info.json``.
|
|
|
|
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
|
|
been encoded the same way. Also, this means it assumes the first episode exists.
|
|
|
|
Args:
|
|
video_key: If provided, only update this video key. Otherwise update
|
|
all video keys in the dataset.
|
|
camera_encoder: Encoder configuration used to produce the
|
|
videos. When provided, its fields are recorded as
|
|
``video.<field>`` entries alongside the stream-derived
|
|
``video.*`` entries (see :func:`get_video_info`).
|
|
"""
|
|
if video_key is not None and video_key not in self.video_keys:
|
|
raise ValueError(f"Video key {video_key} not found in dataset")
|
|
|
|
video_keys = [video_key] if video_key is not None else self.video_keys
|
|
for key in video_keys:
|
|
existing = self.features[key].get("info") or {}
|
|
# Skip only if real video info has already been written. The ``is_depth_map`` entry (created at feature creation) is not blocking.
|
|
if set(existing.keys()) - {"is_depth_map"}:
|
|
continue
|
|
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
|
new_info = get_video_info(video_path, video_encoder=video_encoder)
|
|
self.info.features[key]["info"] = {**existing, **new_info}
|
|
|
|
def update_chunk_settings(
|
|
self,
|
|
chunks_size: int | None = None,
|
|
data_files_size_in_mb: int | None = None,
|
|
video_files_size_in_mb: int | None = None,
|
|
) -> None:
|
|
"""Update chunk and file size settings after dataset creation.
|
|
|
|
This allows users to customize storage organization without modifying the constructor.
|
|
These settings control how episodes are chunked and how large files can grow before
|
|
creating new ones.
|
|
|
|
Args:
|
|
chunks_size: Maximum number of files per chunk directory. If None, keeps current value.
|
|
data_files_size_in_mb: Maximum size for data parquet files in MB. If None, keeps current value.
|
|
video_files_size_in_mb: Maximum size for video files in MB. If None, keeps current value.
|
|
"""
|
|
if chunks_size is not None:
|
|
if chunks_size <= 0:
|
|
raise ValueError(f"chunks_size must be positive, got {chunks_size}")
|
|
self.info.chunks_size = chunks_size
|
|
|
|
if data_files_size_in_mb is not None:
|
|
if data_files_size_in_mb <= 0:
|
|
raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}")
|
|
self.info.data_files_size_in_mb = data_files_size_in_mb
|
|
|
|
if video_files_size_in_mb is not None:
|
|
if video_files_size_in_mb <= 0:
|
|
raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}")
|
|
self.info.video_files_size_in_mb = video_files_size_in_mb
|
|
|
|
# Update the info file on disk
|
|
write_info(self.info, self.root)
|
|
|
|
def get_chunk_settings(self) -> dict[str, int]:
|
|
"""Get current chunk and file size settings.
|
|
|
|
Returns:
|
|
Dict containing chunks_size, data_files_size_in_mb, and video_files_size_in_mb.
|
|
"""
|
|
return {
|
|
"chunks_size": self.chunks_size,
|
|
"data_files_size_in_mb": self.data_files_size_in_mb,
|
|
"video_files_size_in_mb": self.video_files_size_in_mb,
|
|
}
|
|
|
|
def __repr__(self):
|
|
feature_keys = list(self.features)
|
|
return (
|
|
f"{self.__class__.__name__}({{\n"
|
|
f" Repository ID: '{self.repo_id}',\n"
|
|
f" Total episodes: '{self.total_episodes}',\n"
|
|
f" Total frames: '{self.total_frames}',\n"
|
|
f" Features: '{feature_keys}',\n"
|
|
"})',\n"
|
|
)
|
|
|
|
@classmethod
|
|
def create(
|
|
cls,
|
|
repo_id: str,
|
|
fps: int,
|
|
features: dict,
|
|
robot_type: str | None = None,
|
|
root: str | Path | None = None,
|
|
use_videos: bool = True,
|
|
metadata_buffer_size: int = 10,
|
|
chunks_size: int | None = None,
|
|
data_files_size_in_mb: int | None = None,
|
|
video_files_size_in_mb: int | None = None,
|
|
) -> "LeRobotDatasetMetadata":
|
|
"""Create metadata for a new LeRobot dataset from scratch.
|
|
|
|
Initializes the ``info.json`` file on disk with the provided feature
|
|
schema and dataset settings. No episode data is written yet.
|
|
|
|
Args:
|
|
repo_id: Repository identifier (e.g. ``'user/my_dataset'``).
|
|
fps: Frames per second used during data collection.
|
|
features: Feature specification dict mapping feature names to their
|
|
type/shape metadata.
|
|
robot_type: Optional robot type string stored in metadata.
|
|
root: Local directory for the dataset. Defaults to
|
|
``$HF_LEROBOT_HOME/{repo_id}``. Must not already exist.
|
|
use_videos: If ``True``, visual modalities are encoded as MP4 videos.
|
|
metadata_buffer_size: Number of episode metadata records to buffer
|
|
before flushing to parquet.
|
|
chunks_size: Max number of files per chunk directory. ``None`` uses
|
|
the default.
|
|
data_files_size_in_mb: Max parquet file size in MB. ``None`` uses the
|
|
default.
|
|
video_files_size_in_mb: Max video file size in MB. ``None`` uses the
|
|
default.
|
|
|
|
Returns:
|
|
A new :class:`LeRobotDatasetMetadata` instance.
|
|
"""
|
|
obj = cls.__new__(cls)
|
|
obj.repo_id = repo_id
|
|
obj._requested_root = Path(root) if root is not None else None
|
|
obj.root = obj._requested_root if obj._requested_root is not None else HF_LEROBOT_HOME / repo_id
|
|
|
|
obj.root.mkdir(parents=True, exist_ok=False)
|
|
|
|
features = {**features, **DEFAULT_FEATURES}
|
|
_validate_feature_names(features)
|
|
|
|
obj.tasks = None
|
|
obj.subtasks = None
|
|
obj.episodes = None
|
|
obj.stats = None
|
|
obj.info = create_empty_dataset_info(
|
|
CODEBASE_VERSION,
|
|
fps,
|
|
features,
|
|
use_videos,
|
|
robot_type,
|
|
chunks_size,
|
|
data_files_size_in_mb,
|
|
video_files_size_in_mb,
|
|
)
|
|
if len(obj.video_keys) > 0 and not use_videos:
|
|
raise ValueError(
|
|
f"Features contain video keys {obj.video_keys}, but 'use_videos' is set to False. "
|
|
"Either remove video features from the features dict, or set 'use_videos=True'."
|
|
)
|
|
write_info(obj.info, obj.root)
|
|
obj.revision = None
|
|
obj._pq_writer = None
|
|
obj.latest_episode = None
|
|
obj._metadata_buffer = []
|
|
obj._metadata_buffer_size = metadata_buffer_size
|
|
obj._finalized = False
|
|
return obj
|