mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 05:59:52 +00:00
Remove dataset consolidate (#752)
This commit is contained in:
@@ -39,7 +39,6 @@ from lerobot.common.datasets.utils import (
|
||||
append_jsonlines,
|
||||
backward_compatible_episodes_stats,
|
||||
check_delta_timestamps,
|
||||
check_frame_features,
|
||||
check_timestamps_sync,
|
||||
check_version_compatibility,
|
||||
create_empty_dataset_info,
|
||||
@@ -55,6 +54,8 @@ from lerobot.common.datasets.utils import (
|
||||
load_info,
|
||||
load_stats,
|
||||
load_tasks,
|
||||
validate_episode_buffer,
|
||||
validate_frame,
|
||||
write_episode,
|
||||
write_episode_stats,
|
||||
write_info,
|
||||
@@ -256,6 +257,9 @@ class LeRobotDatasetMetadata:
|
||||
|
||||
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
||||
self.info["total_videos"] += len(self.video_keys)
|
||||
if len(self.video_keys) > 0:
|
||||
self.update_video_info()
|
||||
|
||||
write_info(self.info, self.root)
|
||||
|
||||
episode_dict = {
|
||||
@@ -270,7 +274,7 @@ class LeRobotDatasetMetadata:
|
||||
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
|
||||
write_episode_stats(episode_index, episode_stats, self.root)
|
||||
|
||||
def write_video_info(self) -> None:
|
||||
def update_video_info(self) -> None:
|
||||
"""
|
||||
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
|
||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||
@@ -280,8 +284,6 @@ class LeRobotDatasetMetadata:
|
||||
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
|
||||
self.info["features"][key]["info"] = get_video_info(video_path)
|
||||
|
||||
write_json(self.info, self.root / INFO_PATH)
|
||||
|
||||
def __repr__(self):
|
||||
feature_keys = list(self.features)
|
||||
return (
|
||||
@@ -506,9 +508,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
||||
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
||||
|
||||
# Available stats implies all videos have been encoded and dataset is iterable
|
||||
self.consolidated = self.meta.stats is not None
|
||||
|
||||
def push_to_hub(
|
||||
self,
|
||||
branch: str | None = None,
|
||||
@@ -519,13 +518,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
**card_kwargs,
|
||||
) -> None:
|
||||
if not self.consolidated:
|
||||
logging.warning(
|
||||
"You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet. "
|
||||
"Consolidating first."
|
||||
)
|
||||
self.consolidate()
|
||||
|
||||
ignore_patterns = ["images/"]
|
||||
if not push_videos:
|
||||
ignore_patterns.append("videos/")
|
||||
@@ -779,7 +771,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if isinstance(frame[name], torch.Tensor):
|
||||
frame[name] = frame[name].numpy()
|
||||
|
||||
check_frame_features(frame, self.features)
|
||||
validate_frame(frame, self.features)
|
||||
|
||||
if self.episode_buffer is None:
|
||||
self.episode_buffer = self.create_episode_buffer()
|
||||
@@ -815,41 +807,25 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.episode_buffer["size"] += 1
|
||||
|
||||
def save_episode(self, encode_videos: bool = True, episode_data: dict | None = None) -> None:
|
||||
def save_episode(self, episode_data: dict | None = None) -> None:
|
||||
"""
|
||||
This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
|
||||
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
|
||||
the hub.
|
||||
This will save to disk the current episode in self.episode_buffer.
|
||||
|
||||
Use 'encode_videos' if you want to encode videos during the saving of this episode. Otherwise,
|
||||
you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
|
||||
time for video encoding.
|
||||
Args:
|
||||
episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
|
||||
save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
|
||||
None.
|
||||
"""
|
||||
if not episode_data:
|
||||
episode_buffer = self.episode_buffer
|
||||
|
||||
validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)
|
||||
|
||||
# size and task are special cases that won't be added to hf_dataset
|
||||
episode_length = episode_buffer.pop("size")
|
||||
tasks = episode_buffer.pop("task")
|
||||
episode_tasks = list(set(tasks))
|
||||
|
||||
episode_index = episode_buffer["episode_index"]
|
||||
if episode_index != self.meta.total_episodes:
|
||||
# TODO(aliberts): Add option to use existing episode_index
|
||||
raise NotImplementedError(
|
||||
"You might have manually provided the episode_buffer with an episode_index that doesn't "
|
||||
"match the total number of episodes already in the dataset. This is not supported for now."
|
||||
)
|
||||
|
||||
if episode_length == 0:
|
||||
raise ValueError(
|
||||
"You must add one or several frames with `add_frame` before calling `add_episode`."
|
||||
)
|
||||
|
||||
if not set(episode_buffer.keys()) == set(self.features):
|
||||
raise ValueError(
|
||||
f"Features from `episode_buffer` don't match the ones in `self.features`: '{set(episode_buffer.keys())}' vs '{set(self.features)}'"
|
||||
)
|
||||
|
||||
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
|
||||
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
|
||||
@@ -875,16 +851,29 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
ep_stats = compute_episode_stats(episode_buffer, self.features)
|
||||
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
|
||||
|
||||
if encode_videos and len(self.meta.video_keys) > 0:
|
||||
if len(self.meta.video_keys) > 0:
|
||||
video_paths = self.encode_episode_videos(episode_index)
|
||||
for key in self.meta.video_keys:
|
||||
episode_buffer[key] = video_paths[key]
|
||||
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
||||
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
|
||||
|
||||
video_files = list(self.root.rglob("*.mp4"))
|
||||
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
|
||||
|
||||
parquet_files = list(self.root.rglob("*.parquet"))
|
||||
assert len(parquet_files) == self.num_episodes
|
||||
|
||||
# delete images
|
||||
img_dir = self.root / "images"
|
||||
if img_dir.is_dir():
|
||||
shutil.rmtree(self.root / "images")
|
||||
|
||||
if not episode_data: # Reset the buffer
|
||||
self.episode_buffer = self.create_episode_buffer()
|
||||
|
||||
self.consolidated = False
|
||||
|
||||
def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
|
||||
episode_dict = {key: episode_buffer[key] for key in self.hf_features}
|
||||
ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
|
||||
@@ -959,28 +948,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
return video_paths
|
||||
|
||||
def consolidate(self, keep_image_files: bool = False) -> None:
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
||||
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
|
||||
|
||||
if len(self.meta.video_keys) > 0:
|
||||
self.encode_videos()
|
||||
self.meta.write_video_info()
|
||||
|
||||
if not keep_image_files:
|
||||
img_dir = self.root / "images"
|
||||
if img_dir.is_dir():
|
||||
shutil.rmtree(self.root / "images")
|
||||
|
||||
video_files = list(self.root.rglob("*.mp4"))
|
||||
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
|
||||
|
||||
parquet_files = list(self.root.rglob("*.parquet"))
|
||||
assert len(parquet_files) == self.num_episodes
|
||||
|
||||
self.consolidated = True
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
@@ -1019,12 +986,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
|
||||
obj.episode_buffer = obj.create_episode_buffer()
|
||||
|
||||
# This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
|
||||
# is used to know when certain operations are need (for instance, computing dataset statistics). In
|
||||
# order to be able to push the dataset to the hub, it needs to be consolidated first by calling
|
||||
# self.consolidate().
|
||||
obj.consolidated = True
|
||||
|
||||
obj.episodes = None
|
||||
obj.hf_dataset = None
|
||||
obj.image_transforms = None
|
||||
|
||||
Reference in New Issue
Block a user