Remove dataset consolidate (#752)

This commit is contained in:
Simon Alibert
2025-02-19 16:02:54 +01:00
committed by GitHub
parent 6fe42a72db
commit 969ef745a2
6 changed files with 93 additions and 128 deletions
+31 -70
View File
@@ -39,7 +39,6 @@ from lerobot.common.datasets.utils import (
append_jsonlines,
backward_compatible_episodes_stats,
check_delta_timestamps,
check_frame_features,
check_timestamps_sync,
check_version_compatibility,
create_empty_dataset_info,
@@ -55,6 +54,8 @@ from lerobot.common.datasets.utils import (
load_info,
load_stats,
load_tasks,
validate_episode_buffer,
validate_frame,
write_episode,
write_episode_stats,
write_info,
@@ -256,6 +257,9 @@ class LeRobotDatasetMetadata:
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
self.info["total_videos"] += len(self.video_keys)
if len(self.video_keys) > 0:
self.update_video_info()
write_info(self.info, self.root)
episode_dict = {
@@ -270,7 +274,7 @@ class LeRobotDatasetMetadata:
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
write_episode_stats(episode_index, episode_stats, self.root)
def write_video_info(self) -> None:
def update_video_info(self) -> None:
"""
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
been encoded the same way. Also, this means it assumes the first episode exists.
@@ -280,8 +284,6 @@ class LeRobotDatasetMetadata:
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
self.info["features"][key]["info"] = get_video_info(video_path)
write_json(self.info, self.root / INFO_PATH)
def __repr__(self):
feature_keys = list(self.features)
return (
@@ -506,9 +508,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
# Available stats implies all videos have been encoded and dataset is iterable
self.consolidated = self.meta.stats is not None
def push_to_hub(
self,
branch: str | None = None,
@@ -519,13 +518,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
allow_patterns: list[str] | str | None = None,
**card_kwargs,
) -> None:
if not self.consolidated:
logging.warning(
"You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet. "
"Consolidating first."
)
self.consolidate()
ignore_patterns = ["images/"]
if not push_videos:
ignore_patterns.append("videos/")
@@ -779,7 +771,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
if isinstance(frame[name], torch.Tensor):
frame[name] = frame[name].numpy()
check_frame_features(frame, self.features)
validate_frame(frame, self.features)
if self.episode_buffer is None:
self.episode_buffer = self.create_episode_buffer()
@@ -815,41 +807,25 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episode_buffer["size"] += 1
def save_episode(self, encode_videos: bool = True, episode_data: dict | None = None) -> None:
def save_episode(self, episode_data: dict | None = None) -> None:
"""
This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
the hub.
This will save to disk the current episode in self.episode_buffer.
Use 'encode_videos' if you want to encode videos during the saving of this episode. Otherwise,
you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
time for video encoding.
Args:
episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
None.
"""
if not episode_data:
episode_buffer = self.episode_buffer
validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)
# size and task are special cases that won't be added to hf_dataset
episode_length = episode_buffer.pop("size")
tasks = episode_buffer.pop("task")
episode_tasks = list(set(tasks))
episode_index = episode_buffer["episode_index"]
if episode_index != self.meta.total_episodes:
# TODO(aliberts): Add option to use existing episode_index
raise NotImplementedError(
"You might have manually provided the episode_buffer with an episode_index that doesn't "
"match the total number of episodes already in the dataset. This is not supported for now."
)
if episode_length == 0:
raise ValueError(
"You must add one or several frames with `add_frame` before calling `add_episode`."
)
if not set(episode_buffer.keys()) == set(self.features):
raise ValueError(
f"Features from `episode_buffer` don't match the ones in `self.features`: '{set(episode_buffer.keys())}' vs '{set(self.features)}'"
)
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
@@ -875,16 +851,29 @@ class LeRobotDataset(torch.utils.data.Dataset):
ep_stats = compute_episode_stats(episode_buffer, self.features)
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
if encode_videos and len(self.meta.video_keys) > 0:
if len(self.meta.video_keys) > 0:
video_paths = self.encode_episode_videos(episode_index)
for key in self.meta.video_keys:
episode_buffer[key] = video_paths[key]
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
video_files = list(self.root.rglob("*.mp4"))
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
parquet_files = list(self.root.rglob("*.parquet"))
assert len(parquet_files) == self.num_episodes
# delete images
img_dir = self.root / "images"
if img_dir.is_dir():
shutil.rmtree(self.root / "images")
if not episode_data: # Reset the buffer
self.episode_buffer = self.create_episode_buffer()
self.consolidated = False
def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
episode_dict = {key: episode_buffer[key] for key in self.hf_features}
ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
@@ -959,28 +948,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
return video_paths
def consolidate(self, keep_image_files: bool = False) -> None:
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
if len(self.meta.video_keys) > 0:
self.encode_videos()
self.meta.write_video_info()
if not keep_image_files:
img_dir = self.root / "images"
if img_dir.is_dir():
shutil.rmtree(self.root / "images")
video_files = list(self.root.rglob("*.mp4"))
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
parquet_files = list(self.root.rglob("*.parquet"))
assert len(parquet_files) == self.num_episodes
self.consolidated = True
@classmethod
def create(
cls,
@@ -1019,12 +986,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
obj.episode_buffer = obj.create_episode_buffer()
# This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
# is used to know when certain operations are need (for instance, computing dataset statistics). In
# order to be able to push the dataset to the hub, it needs to be consolidated first by calling
# self.consolidate().
obj.consolidated = True
obj.episodes = None
obj.hf_dataset = None
obj.image_transforms = None