mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 22:20:06 +00:00
Commit before episodes episodes_stats merging
This commit is contained in:
@@ -17,15 +17,16 @@ import contextlib
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
from typing import Callable
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import packaging.version
|
||||
import PIL.Image
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.utils
|
||||
from datasets import concatenate_datasets, load_dataset
|
||||
from datasets import concatenate_datasets, load_dataset, Dataset
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub.constants import REPOCARD_NAME
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
@@ -34,37 +35,57 @@ from lerobot.common.constants import HF_LEROBOT_HOME
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
||||
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
|
||||
from lerobot.common.datasets.utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_EPISODES_STATS_PATH,
|
||||
DEFAULT_FEATURES,
|
||||
DEFAULT_IMAGE_PATH,
|
||||
EPISODES_DIR,
|
||||
EPISODES_STATS_DIR,
|
||||
INFO_PATH,
|
||||
TASKS_PATH,
|
||||
LEGACY_TASKS_PATH,
|
||||
append_jsonlines,
|
||||
backward_compatible_episodes_stats,
|
||||
check_delta_timestamps,
|
||||
check_timestamps_sync,
|
||||
check_version_compatibility,
|
||||
concat_video_files,
|
||||
create_empty_dataset_info,
|
||||
create_lerobot_dataset_card,
|
||||
embed_images,
|
||||
get_chunk_file_indices,
|
||||
get_delta_indices,
|
||||
get_episode_data_index,
|
||||
get_features_from_robot,
|
||||
get_hf_dataset_size_in_mb,
|
||||
get_hf_features_from_features,
|
||||
get_latest_parquet_path,
|
||||
get_latest_video_path,
|
||||
get_parquet_num_frames,
|
||||
get_pd_dataframe_size_in_mb,
|
||||
get_safe_version,
|
||||
get_video_duration_in_s,
|
||||
get_video_size_in_mb,
|
||||
hf_transform_to_torch,
|
||||
is_valid_version,
|
||||
legacy_load_episodes,
|
||||
legacy_load_episodes_stats,
|
||||
load_episodes,
|
||||
load_episodes_stats,
|
||||
load_info,
|
||||
load_nested_dataset,
|
||||
load_stats,
|
||||
legacy_load_tasks,
|
||||
load_tasks,
|
||||
update_chunk_file_indices,
|
||||
validate_episode_buffer,
|
||||
validate_frame,
|
||||
write_episode,
|
||||
write_episode_stats,
|
||||
legacy_write_episode_stats,
|
||||
write_info,
|
||||
write_json,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.common.datasets.v30.convert_dataset_v21_to_v30 import get_parquet_file_size_in_mb
|
||||
from lerobot.common.datasets.video_utils import (
|
||||
VideoFrame,
|
||||
decode_video_frames_torchvision,
|
||||
@@ -105,12 +126,9 @@ class LeRobotDatasetMetadata:
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
self.tasks, self.task_to_task_index = load_tasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
if self._version < packaging.version.parse("v2.1"):
|
||||
self.stats = load_stats(self.root)
|
||||
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
|
||||
else:
|
||||
self.episodes_stats = load_episodes_stats(self.root)
|
||||
self.stats = aggregate_stats(list(self.episodes_stats.values()))
|
||||
self.episodes_stats = load_episodes_stats(self.root)
|
||||
# TODO(rcadene): https://huggingface.slack.com/archives/C02V51Q3800/p1743517952388249?thread_ts=1742896075.499119&cid=C02V51Q3800
|
||||
# self.stats = aggregate_stats(list(self.episodes_stats.values()))
|
||||
|
||||
def pull_from_repo(
|
||||
self,
|
||||
@@ -132,17 +150,19 @@ class LeRobotDatasetMetadata:
|
||||
return packaging.version.parse(self.info["codebase_version"])
|
||||
|
||||
def get_data_file_path(self, ep_index: int) -> Path:
|
||||
ep_chunk = self.get_episode_chunk(ep_index)
|
||||
fpath = self.data_path.format(episode_chunk=ep_chunk, episode_index=ep_index)
|
||||
chunk_idx = self.episodes[f"data/chunk_index"][ep_index]
|
||||
file_idx = self.episodes[f"data/file_index"][ep_index]
|
||||
fpath = self.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
return Path(fpath)
|
||||
|
||||
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
|
||||
ep_chunk = self.get_episode_chunk(ep_index)
|
||||
fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
|
||||
chunk_idx = self.episodes[f"{vid_key}/chunk_index"][ep_index]
|
||||
file_idx = self.episodes[f"{vid_key}/file_index"][ep_index]
|
||||
fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
|
||||
return Path(fpath)
|
||||
|
||||
def get_episode_chunk(self, ep_index: int) -> int:
|
||||
return ep_index // self.chunks_size
|
||||
# def get_episode_chunk(self, ep_index: int) -> int:
|
||||
# return ep_index // self.chunks_size
|
||||
|
||||
@property
|
||||
def data_path(self) -> str:
|
||||
@@ -209,40 +229,85 @@ class LeRobotDatasetMetadata:
|
||||
"""Total number of different tasks performed in this dataset."""
|
||||
return self.info["total_tasks"]
|
||||
|
||||
@property
|
||||
def total_chunks(self) -> int:
|
||||
"""Total number of chunks (groups of episodes)."""
|
||||
return self.info["total_chunks"]
|
||||
|
||||
@property
|
||||
def chunks_size(self) -> int:
|
||||
"""Max number of episodes per chunk."""
|
||||
"""Max number of files per chunk."""
|
||||
return self.info["chunks_size"]
|
||||
|
||||
@property
def files_size_in_mb(self) -> int:
    """Max size of a file in megabytes, read from the `files_size_in_mb` entry of `self.info`."""
    return self.info["files_size_in_mb"]
|
||||
|
||||
def get_task_index(self, task: str) -> int | None:
|
||||
"""
|
||||
Given a task in natural language, returns its task_index if the task already exists in the dataset,
|
||||
otherwise return None.
|
||||
"""
|
||||
return self.task_to_task_index.get(task, None)
|
||||
return self.tasks.index[task] if task in self.tasks.index else None
|
||||
|
||||
def has_task(self, task: str) -> bool:
    """Return True when `task` (a natural-language task string) is a key of `self.task_to_task_index`."""
    # NOTE(review): other methods in this class look tasks up through `self.tasks`;
    # confirm that `task_to_task_index` is kept in sync with it.
    return task in self.task_to_task_index
|
||||
|
||||
def add_task(self, task: str):
|
||||
"""
|
||||
Given a task in natural language, add it to the dictionary of tasks.
|
||||
"""
|
||||
if task in self.task_to_task_index:
|
||||
raise ValueError(f"The task '{task}' already exists and can't be added twice.")
|
||||
def save_episode_tasks(self, tasks: list[str]):
|
||||
new_tasks = [task for task in tasks if not self.has_task(task)]
|
||||
|
||||
task_index = self.info["total_tasks"]
|
||||
self.task_to_task_index[task] = task_index
|
||||
self.tasks[task_index] = task
|
||||
self.info["total_tasks"] += 1
|
||||
for task in new_tasks:
|
||||
task_index = len(self.tasks)
|
||||
self.tasks.loc[task] = task_index
|
||||
|
||||
task_dict = {
|
||||
"task_index": task_index,
|
||||
"task": task,
|
||||
}
|
||||
append_jsonlines(task_dict, self.root / TASKS_PATH)
|
||||
if len(new_tasks) > 0:
|
||||
# Update on disk
|
||||
write_tasks(self.tasks, self.root)
|
||||
|
||||
def _save_episode(self, episode_dict: dict):
    """Persist one episode's metadata to the episodes parquet files and reload `self.episodes`.

    The episode is either appended to the latest parquet file found under
    `self.root / EPISODES_DIR`, or written to a new file when appending would reach
    the `files_size_in_mb` limit.

    Args:
        episode_dict: Column-name -> value mapping describing the episode.
            NOTE(review): it is passed as-is to `Dataset.from_dict`; confirm the
            caller provides values in the column layout that function expects.
    """
    ep_dataset = Dataset.from_dict(episode_dict)
    ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
    # Built once up front: both branches below need the episode as a DataFrame.
    ep_df = pd.DataFrame(ep_dataset)

    # Access latest parquet file information
    latest_path = get_latest_parquet_path(self.root / EPISODES_DIR)
    latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
    chunk_idx, file_idx = get_chunk_file_indices(latest_path)

    # This method lives on the metadata object, so the limits are read from its own
    # properties rather than through a `meta` attribute.
    if latest_size_in_mb + ep_size_in_mb >= self.files_size_in_mb:
        # Create new parquet file
        chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
        new_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
        new_path.parent.mkdir(parents=True, exist_ok=True)
        ep_df.to_parquet(new_path, index=False)
    else:
        # Update latest parquet file with new row: the existing rows must be loaded
        # first (held in RAM), then the file is rewritten with the extra row.
        latest_df = pd.read_parquet(latest_path)
        latest_df = pd.concat([latest_df, ep_df], ignore_index=True)
        latest_df.to_parquet(latest_path, index=False)

    # Refresh the in-memory view of the episodes from disk.
    self.episodes = load_episodes(self.root)
|
||||
|
||||
def _save_episode_stats(self, episodes_stats: dict):
    """Persist one episode's statistics to the episodes-stats parquet files and reload `self.episodes_stats`.

    The statistics are either appended to the latest parquet file found under
    `self.root / EPISODES_STATS_DIR`, or written to a new file when appending would
    reach the `files_size_in_mb` limit.

    Args:
        episodes_stats: Column-name -> value mapping holding the statistics.
            NOTE(review): it is passed as-is to `Dataset.from_dict`; confirm the
            caller provides values in the column layout that function expects.
    """
    ep_dataset = Dataset.from_dict(episodes_stats)
    ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
    # Built once up front: both branches below need the stats as a DataFrame.
    ep_df = pd.DataFrame(ep_dataset)

    # Access latest parquet file information
    latest_path = get_latest_parquet_path(self.root / EPISODES_STATS_DIR)
    latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
    chunk_idx, file_idx = get_chunk_file_indices(latest_path)

    # This method lives on the metadata object, so the limits are read from its own
    # properties rather than through a `meta` attribute.
    if latest_size_in_mb + ep_size_in_mb >= self.files_size_in_mb:
        # Create new parquet file
        chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
        new_path = self.root / DEFAULT_EPISODES_STATS_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
        new_path.parent.mkdir(parents=True, exist_ok=True)
        ep_df.to_parquet(new_path, index=False)
    else:
        # Update latest parquet file with new row: the existing rows must be loaded
        # first (held in RAM), then the file is rewritten with the extra row.
        latest_df = pd.read_parquet(latest_path)
        latest_df = pd.concat([latest_df, ep_df], ignore_index=True)
        latest_df.to_parquet(latest_path, index=False)

    # Refresh the in-memory view of the episode statistics from disk.
    self.episodes_stats = load_episodes_stats(self.root)
|
||||
|
||||
def save_episode(
|
||||
self,
|
||||
@@ -250,19 +315,14 @@ class LeRobotDatasetMetadata:
|
||||
episode_length: int,
|
||||
episode_tasks: list[str],
|
||||
episode_stats: dict[str, dict],
|
||||
episode_metadata: dict,
|
||||
) -> None:
|
||||
# Update info
|
||||
self.info["total_episodes"] += 1
|
||||
self.info["total_frames"] += episode_length
|
||||
|
||||
chunk = self.get_episode_chunk(episode_index)
|
||||
if chunk >= self.total_chunks:
|
||||
self.info["total_chunks"] += 1
|
||||
|
||||
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
||||
self.info["total_videos"] += len(self.video_keys)
|
||||
if len(self.video_keys) > 0:
|
||||
self.update_video_info()
|
||||
|
||||
write_info(self.info, self.root)
|
||||
|
||||
episode_dict = {
|
||||
@@ -270,12 +330,12 @@ class LeRobotDatasetMetadata:
|
||||
"tasks": episode_tasks,
|
||||
"length": episode_length,
|
||||
}
|
||||
self.episodes[episode_index] = episode_dict
|
||||
write_episode(episode_dict, self.root)
|
||||
episode_dict.update(episode_metadata)
|
||||
self._save_episode(episode_dict)
|
||||
self._save_episode_stats(episode_stats)
|
||||
|
||||
self.episodes_stats[episode_index] = episode_stats
|
||||
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
|
||||
write_episode_stats(episode_index, episode_stats, self.root)
|
||||
# TODO: write stats
|
||||
|
||||
def update_video_info(self) -> None:
|
||||
"""
|
||||
@@ -340,8 +400,11 @@ class LeRobotDatasetMetadata:
|
||||
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
|
||||
obj.tasks, obj.task_to_task_index = {}, {}
|
||||
obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
|
||||
obj.tasks = None
|
||||
obj.episodes_stats = None
|
||||
obj.episodes = None
|
||||
# TODO(rcadene) stats
|
||||
obj.stats = {}
|
||||
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
|
||||
if len(obj.video_keys) > 0 and not use_videos:
|
||||
raise ValueError()
|
||||
@@ -486,29 +549,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.meta = LeRobotDatasetMetadata(
|
||||
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
||||
)
|
||||
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
|
||||
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
|
||||
self.stats = aggregate_stats(episodes_stats)
|
||||
|
||||
# Load actual data
|
||||
try:
|
||||
if force_cache_sync:
|
||||
raise FileNotFoundError
|
||||
assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
except (AssertionError, FileNotFoundError, NotADirectoryError):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
self.download_episodes(download_videos)
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
|
||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
||||
|
||||
# Check timestamps
|
||||
timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
|
||||
episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
|
||||
ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
|
||||
check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
|
||||
|
||||
# Setup delta_indices
|
||||
if self.delta_timestamps is not None:
|
||||
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
||||
@@ -592,11 +643,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
# TODO(rcadene, aliberts): implement faster transfer
|
||||
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
||||
files = None
|
||||
ignore_patterns = None if download_videos else "videos/"
|
||||
files = None
|
||||
if self.episodes is not None:
|
||||
files = self.get_episodes_file_paths()
|
||||
|
||||
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
|
||||
|
||||
def get_episodes_file_paths(self) -> list[Path]:
|
||||
@@ -609,31 +659,21 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
for ep_idx in episodes
|
||||
]
|
||||
fpaths += video_files
|
||||
|
||||
# episodes are stored in the same files, so we return unique paths only
|
||||
fpaths = list(set(fpaths))
|
||||
return fpaths
|
||||
|
||||
def load_hf_dataset(self) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
if self.episodes is None:
|
||||
path = str(self.root / "data")
|
||||
# TODO(rcadene): load_dataset convert parquet to arrow.
|
||||
# set num_proc to accelerate this conversion
|
||||
hf_dataset = load_dataset("parquet", data_dir=path, split="train")
|
||||
else:
|
||||
files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
|
||||
hf_dataset = load_dataset("parquet", data_files=files, split="train")
|
||||
|
||||
# TODO(aliberts): hf_dataset.set_format("torch")
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
hf_dataset = load_nested_dataset(self.root / "data")
|
||||
hf_dataset.set_format("torch")
|
||||
return hf_dataset
|
||||
|
||||
def create_hf_dataset(self) -> datasets.Dataset:
|
||||
features = get_hf_features_from_features(self.features)
|
||||
ft_dict = {col: [] for col in features}
|
||||
hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")
|
||||
|
||||
# TODO(aliberts): hf_dataset.set_format("torch")
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
hf_dataset.set_format("torch")
|
||||
return hf_dataset
|
||||
|
||||
@property
|
||||
@@ -664,15 +704,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return get_hf_features_from_features(self.features)
|
||||
|
||||
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
|
||||
ep_start = self.episode_data_index["from"][ep_idx]
|
||||
ep_end = self.episode_data_index["to"][ep_idx]
|
||||
ep_start = self.meta.episodes["data/from_index"][ep_idx]
|
||||
ep_end = self.meta.episodes["data/to_index"][ep_idx]
|
||||
query_indices = {
|
||||
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
|
||||
key: [max(ep_start, min(ep_end - 1, idx + delta)) for delta in delta_idx]
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
padding = { # Pad values outside of current episode range
|
||||
f"{key}_is_pad": torch.BoolTensor(
|
||||
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
|
||||
[(idx + delta < ep_start) | (idx + delta >= ep_end) for delta in delta_idx]
|
||||
)
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
@@ -687,7 +727,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
for key in self.meta.video_keys:
|
||||
if query_indices is not None and key in query_indices:
|
||||
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
|
||||
query_timestamps[key] = torch.stack(timestamps).tolist()
|
||||
query_timestamps[key] = timestamps.tolist()
|
||||
else:
|
||||
query_timestamps[key] = [current_ts]
|
||||
|
||||
@@ -695,7 +735,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
||||
return {
|
||||
key: torch.stack(self.hf_dataset.select(q_idx)[key])
|
||||
key: self.hf_dataset.select(q_idx)[key]
|
||||
for key, q_idx in query_indices.items()
|
||||
if key not in self.meta.video_keys
|
||||
}
|
||||
@@ -708,9 +748,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
item = {}
|
||||
for vid_key, query_ts in query_timestamps.items():
|
||||
# Episodes are stored sequentially on a single mp4 to reduce the number of files.
|
||||
# Thus we load the start timestamp of the episode on this mp4 and,
|
||||
# shift the query timestamp accordingly.
|
||||
from_timestamp = self.meta.episodes[f"{vid_key}/from_timestamp"][ep_idx]
|
||||
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
|
||||
|
||||
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
|
||||
frames = decode_video_frames_torchvision(
|
||||
video_path, query_ts, self.tolerance_s, self.video_backend
|
||||
video_path, shifted_query_ts, self.tolerance_s, self.video_backend
|
||||
)
|
||||
item[vid_key] = frames.squeeze(0)
|
||||
|
||||
@@ -749,7 +795,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
# Add task as a string
|
||||
task_idx = item["task_index"].item()
|
||||
item["task"] = self.meta.tasks[task_idx]
|
||||
if self.meta.tasks["task_index"][task_idx] != task_idx:
|
||||
raise ValueError("Sanity check on task index failed.")
|
||||
item["task"] = self.meta.tasks["task"][task_idx]
|
||||
|
||||
return item
|
||||
|
||||
@@ -779,6 +827,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
return self.root / fpath
|
||||
|
||||
def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path:
|
||||
return self._get_image_file_path(episode_index, image_key, frame_index=0).parent
|
||||
|
||||
def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
|
||||
if self.image_writer is None:
|
||||
@@ -858,11 +909,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
|
||||
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
|
||||
|
||||
# Add new tasks to the tasks dictionary
|
||||
for task in episode_tasks:
|
||||
task_index = self.meta.get_task_index(task)
|
||||
if task_index is None:
|
||||
self.meta.add_task(task)
|
||||
# Update tasks and task indices with new tasks if any
|
||||
self.meta.save_episode_tasks(episode_tasks)
|
||||
|
||||
# Given tasks in natural language, find their corresponding task indices
|
||||
episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
|
||||
@@ -874,51 +922,107 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
continue
|
||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||
|
||||
# Wait for image writer to end, so that episode stats over images can be computed
|
||||
self._wait_image_writer()
|
||||
self._save_episode_table(episode_buffer, episode_index)
|
||||
ep_stats = compute_episode_stats(episode_buffer, self.features)
|
||||
|
||||
if len(self.meta.video_keys) > 0:
|
||||
video_paths = self.encode_episode_videos(episode_index)
|
||||
for key in self.meta.video_keys:
|
||||
episode_buffer[key] = video_paths[key]
|
||||
ep_metadata = self._save_episode_data(episode_buffer, episode_index)
|
||||
for video_key in self.meta.video_keys:
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index))
|
||||
|
||||
# `meta.save_episode` needs to be executed after encoding the videos
|
||||
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
|
||||
|
||||
# `meta.save_episode` be executed after encoding the videos
|
||||
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
|
||||
|
||||
ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
|
||||
ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
|
||||
check_timestamps_sync(
|
||||
episode_buffer["timestamp"],
|
||||
episode_buffer["episode_index"],
|
||||
ep_data_index_np,
|
||||
self.fps,
|
||||
self.tolerance_s,
|
||||
)
|
||||
|
||||
video_files = list(self.root.rglob("*.mp4"))
|
||||
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
|
||||
|
||||
parquet_files = list(self.root.rglob("*.parquet"))
|
||||
assert len(parquet_files) == self.num_episodes
|
||||
# TODO(rcadene): remove? there is only one episode in the episode buffer, no need for ep_data_index
|
||||
# ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
|
||||
# ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
|
||||
# check_timestamps_sync(
|
||||
# episode_buffer["timestamp"],
|
||||
# episode_buffer["episode_index"],
|
||||
# ep_data_index_np,
|
||||
# self.fps,
|
||||
# self.tolerance_s,
|
||||
# )
|
||||
|
||||
# TODO(rcadene): images are also deleted in clear_episode_buffer
|
||||
# delete images
|
||||
img_dir = self.root / "images"
|
||||
if img_dir.is_dir():
|
||||
shutil.rmtree(self.root / "images")
|
||||
|
||||
if not episode_data: # Reset the buffer
|
||||
if not episode_data:
|
||||
# Reset episode buffer
|
||||
self.episode_buffer = self.create_episode_buffer()
|
||||
|
||||
def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
|
||||
episode_dict = {key: episode_buffer[key] for key in self.hf_features}
|
||||
ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
|
||||
def _save_episode_data(self, episode_buffer: dict) -> None:
|
||||
# Convert buffer into HF Dataset
|
||||
ep_dict = {key: episode_buffer[key] for key in self.hf_features}
|
||||
ep_dataset = datasets.Dataset.from_dict(ep_dict, features=self.hf_features, split="train")
|
||||
ep_dataset = embed_images(ep_dataset)
|
||||
self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset])
|
||||
self.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
|
||||
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
ep_dataset.to_parquet(ep_data_path)
|
||||
ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
|
||||
ep_num_frames = len(ep_dataset)
|
||||
|
||||
# Access latest parquet file information
|
||||
latest_path = get_latest_parquet_path(self.root / "data")
|
||||
latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
|
||||
latest_num_frames = get_parquet_num_frames(latest_path)
|
||||
chunk_idx, file_idx = get_chunk_file_indices(latest_path)
|
||||
|
||||
if latest_size_in_mb + ep_size_in_mb >= self.meta.files_size_in_mb:
|
||||
# Create new parquet file
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
|
||||
new_path = self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
ep_df.to_parquet(new_path, index=False)
|
||||
else:
|
||||
# Update latest parquet file with new rows
|
||||
ep_df = pd.DataFrame(ep_dataset)
|
||||
latest_df = pd.concat([latest_df, ep_df], ignore_index=True) # RAM
|
||||
latest_df.to_parquet(latest_path, index=False)
|
||||
|
||||
# Update the Hugging Face dataset by reloading it.
|
||||
# This process should be fast because only the latest Parquet file has been modified.
|
||||
# Therefore, only this file needs to be converted to PyArrow; the rest is loaded from the PyArrow memory-mapped cache.
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
|
||||
metadata = {
|
||||
"data/chunk_index": chunk_idx,
|
||||
"data/file_index": file_idx,
|
||||
"data/from_index": latest_num_frames,
|
||||
"data/to_index": latest_num_frames + ep_num_frames,
|
||||
}
|
||||
return metadata
|
||||
|
||||
def _save_episode_video(self, video_key: str, episode_index: int):
    """Encode one episode's video for `video_key` and store it in the dataset's video files.

    The episode is first encoded to a temporary mp4. It is then either concatenated
    onto the latest video file for this key, or moved to a new file when the
    combined size would reach `self.meta.files_size_in_mb`.

    Args:
        video_key: Key of the video feature being saved.
        episode_index: Index of the episode being saved.

    Returns:
        A dict with `episode_index` and, prefixed by `video_key`, the chunk index,
        file index, and the from/to timestamps of the episode inside its video file.
    """
    # Encode episode frames into a temporary video
    ep_path = self._encode_temporary_episode_video(video_key, episode_index)
    ep_size_in_mb = get_video_size_in_mb(ep_path)
    ep_duration_in_s = get_video_duration_in_s(ep_path)

    # Access latest video file information
    latest_path = get_latest_video_path(self.root / "videos", video_key)
    latest_size_in_mb = get_video_size_in_mb(latest_path)
    latest_duration_in_s = get_video_duration_in_s(latest_path)
    chunk_idx, file_idx = get_chunk_file_indices(latest_path)

    if latest_size_in_mb + ep_size_in_mb >= self.meta.files_size_in_mb:
        # Move temporary episode video to a new video file in the dataset
        chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
        # `video_path` is a format string, so the result is rooted at `self.root`
        # and turned into a Path before `.parent` is used.
        new_path = self.root / self.meta.video_path.format(
            video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
        )
        new_path.parent.mkdir(parents=True, exist_ok=True)
        # The temporary file may live on another filesystem than the dataset root,
        # where a plain rename fails; `shutil.move` falls back to copy + delete.
        shutil.move(str(ep_path), str(new_path))
        # The episode is the first content of the new file.
        from_timestamp = 0.0
    else:
        # Update latest video file
        concat_video_files([latest_path, ep_path], self.root, video_key, chunk_idx, file_idx)
        # The episode starts where the previous content of the file ended.
        from_timestamp = latest_duration_in_s

    metadata = {
        "episode_index": episode_index,
        f"{video_key}/chunk_index": chunk_idx,
        f"{video_key}/file_index": file_idx,
        f"{video_key}/from_timestamp": from_timestamp,
        f"{video_key}/to_timestamp": from_timestamp + ep_duration_in_s,
    }
    return metadata
|
||||
|
||||
def clear_episode_buffer(self) -> None:
|
||||
episode_index = self.episode_buffer["episode_index"]
|
||||
@@ -958,34 +1062,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if self.image_writer is not None:
|
||||
self.image_writer.wait_until_done()
|
||||
|
||||
def encode_videos(self) -> None:
|
||||
# TODO(rcadene): this method is currently not used
|
||||
# def encode_videos(self) -> None:
|
||||
# """
|
||||
# Use ffmpeg to convert frames stored as png into mp4 videos.
|
||||
# Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
# since video encoding with ffmpeg is already using multithreading.
|
||||
# """
|
||||
# for ep_idx in range(self.meta.total_episodes):
|
||||
# self.encode_episode_videos(ep_idx)
|
||||
|
||||
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> dict:
|
||||
"""
|
||||
Use ffmpeg to convert frames stored as png into mp4 videos.
|
||||
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
since video encoding with ffmpeg is already using multithreading.
|
||||
"""
|
||||
for ep_idx in range(self.meta.total_episodes):
|
||||
self.encode_episode_videos(ep_idx)
|
||||
|
||||
def encode_episode_videos(self, episode_index: int) -> dict:
|
||||
"""
|
||||
Use ffmpeg to convert frames stored as png into mp4 videos.
|
||||
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
since video encoding with ffmpeg is already using multithreading.
|
||||
"""
|
||||
video_paths = {}
|
||||
for key in self.meta.video_keys:
|
||||
video_path = self.root / self.meta.get_video_file_path(episode_index, key)
|
||||
video_paths[key] = str(video_path)
|
||||
if video_path.is_file():
|
||||
# Skip if video is already encoded. Could be the case when resuming data recording.
|
||||
continue
|
||||
img_dir = self._get_image_file_path(
|
||||
episode_index=episode_index, image_key=key, frame_index=0
|
||||
).parent
|
||||
encode_video_frames(img_dir, video_path, self.fps, overwrite=True)
|
||||
|
||||
return video_paths
|
||||
temp_path = Path(tempfile.mkdtemp()) / f"{video_key}_{episode_index:3d}.mp4"
|
||||
img_dir = self._get_image_file_dir(episode_index, video_key)
|
||||
encode_video_frames(img_dir, temp_path, self.fps, overwrite=True)
|
||||
return temp_path
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@@ -1030,7 +1126,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.image_transforms = None
|
||||
obj.delta_timestamps = None
|
||||
obj.delta_indices = None
|
||||
obj.episode_data_index = None
|
||||
obj.video_backend = video_backend if video_backend is not None else "pyav"
|
||||
return obj
|
||||
|
||||
|
||||
Reference in New Issue
Block a user