mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-28 15:09:51 +00:00
fix(caching) remove cache dir when collecting a dataset with each call to load_episodes and load_hf_dataset
This commit is contained in:
@@ -28,7 +28,6 @@ import pandas as pd
|
|||||||
import PIL.Image
|
import PIL.Image
|
||||||
import torch
|
import torch
|
||||||
import torch.utils
|
import torch.utils
|
||||||
from datasets import Dataset, concatenate_datasets
|
|
||||||
from huggingface_hub import HfApi, snapshot_download
|
from huggingface_hub import HfApi, snapshot_download
|
||||||
from huggingface_hub.constants import REPOCARD_NAME
|
from huggingface_hub.constants import REPOCARD_NAME
|
||||||
from huggingface_hub.errors import RevisionNotFoundError
|
from huggingface_hub.errors import RevisionNotFoundError
|
||||||
@@ -49,6 +48,7 @@ from lerobot.datasets.utils import (
|
|||||||
embed_images,
|
embed_images,
|
||||||
flatten_dict,
|
flatten_dict,
|
||||||
get_delta_indices,
|
get_delta_indices,
|
||||||
|
get_hf_dataset_cache_dir,
|
||||||
get_hf_dataset_size_in_mb,
|
get_hf_dataset_size_in_mb,
|
||||||
get_hf_features_from_features,
|
get_hf_features_from_features,
|
||||||
get_parquet_file_size_in_mb,
|
get_parquet_file_size_in_mb,
|
||||||
@@ -271,7 +271,7 @@ class LeRobotDatasetMetadata:
|
|||||||
"""
|
"""
|
||||||
# Convert buffer into HF Dataset
|
# Convert buffer into HF Dataset
|
||||||
episode_dict = {key: [value] for key, value in episode_dict.items()}
|
episode_dict = {key: [value] for key, value in episode_dict.items()}
|
||||||
ep_dataset = Dataset.from_dict(episode_dict)
|
ep_dataset = datasets.Dataset.from_dict(episode_dict)
|
||||||
ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
|
ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
|
||||||
df = pd.DataFrame(ep_dataset)
|
df = pd.DataFrame(ep_dataset)
|
||||||
num_frames = episode_dict["length"][0]
|
num_frames = episode_dict["length"][0]
|
||||||
@@ -316,16 +316,13 @@ class LeRobotDatasetMetadata:
|
|||||||
path.parent.mkdir(parents=True, exist_ok=True)
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
df.to_parquet(path, index=False)
|
df.to_parquet(path, index=False)
|
||||||
|
|
||||||
# Update the Hugging Face dataset incrementally instead of reloading from disk
|
if self.episodes is not None:
|
||||||
# This eliminates repeated load_episodes calls that cause cache bloat
|
# Remove the episodes cache directory, necessary to avoid cache bloat
|
||||||
if self.episodes is None:
|
cached_dir = get_hf_dataset_cache_dir(self.episodes)
|
||||||
self.episodes = load_episodes(self.root)
|
if cached_dir is not None:
|
||||||
return
|
shutil.rmtree(cached_dir)
|
||||||
|
|
||||||
# Remove columns from df that start with 'stats/'
|
self.episodes = load_episodes(self.root)
|
||||||
df = df.drop(columns=[col for col in df.columns if col.startswith("stats/")])
|
|
||||||
new_episode_dataset = Dataset.from_pandas(df)
|
|
||||||
self.episodes = concatenate_datasets([self.episodes, new_episode_dataset])
|
|
||||||
|
|
||||||
def save_episode(
|
def save_episode(
|
||||||
self,
|
self,
|
||||||
@@ -1063,10 +1060,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
else:
|
else:
|
||||||
df.to_parquet(path)
|
df.to_parquet(path)
|
||||||
|
|
||||||
self.hf_dataset = (
|
if self.hf_dataset is not None:
|
||||||
concatenate_datasets([self.hf_dataset, ep_dataset]) if self.hf_dataset is not None else ep_dataset
|
# Remove hf dataset cache directory, necessary to avoid cache bloat
|
||||||
)
|
cached_dir = get_hf_dataset_cache_dir(self.hf_dataset)
|
||||||
self.hf_dataset.set_transform(hf_transform_to_torch)
|
if cached_dir is not None:
|
||||||
|
shutil.rmtree(cached_dir)
|
||||||
|
|
||||||
|
self.hf_dataset = self.load_hf_dataset()
|
||||||
|
|
||||||
metadata = {
|
metadata = {
|
||||||
"data/chunk_index": chunk_idx,
|
"data/chunk_index": chunk_idx,
|
||||||
|
|||||||
@@ -102,6 +102,12 @@ def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int:
|
|||||||
return hf_ds.data.nbytes // (1024**2)
|
return hf_ds.data.nbytes // (1024**2)
|
||||||
|
|
||||||
|
|
||||||
|
def get_hf_dataset_cache_dir(hf_ds: Dataset) -> Path | None:
|
||||||
|
if hf_ds.cache_files is None or len(hf_ds.cache_files) == 0:
|
||||||
|
return None
|
||||||
|
return Path(hf_ds.cache_files[0]["filename"]).parents[2]
|
||||||
|
|
||||||
|
|
||||||
def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]:
|
def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]:
|
||||||
if file_idx == chunks_size - 1:
|
if file_idx == chunks_size - 1:
|
||||||
file_idx = 0
|
file_idx = 0
|
||||||
|
|||||||
Reference in New Issue
Block a user