fix(caching) remove cache dir when collecting a dataset with each call to load_episodes and load_hf_dataset

This commit is contained in:
Michel Aractingi
2025-09-08 12:44:43 +02:00
parent 952f455446
commit af79dda8d9
2 changed files with 21 additions and 15 deletions
+15 -15
View File
@@ -28,7 +28,6 @@ import pandas as pd
import PIL.Image
import torch
import torch.utils
from datasets import Dataset, concatenate_datasets
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.constants import REPOCARD_NAME
from huggingface_hub.errors import RevisionNotFoundError
@@ -49,6 +48,7 @@ from lerobot.datasets.utils import (
embed_images,
flatten_dict,
get_delta_indices,
get_hf_dataset_cache_dir,
get_hf_dataset_size_in_mb,
get_hf_features_from_features,
get_parquet_file_size_in_mb,
@@ -271,7 +271,7 @@ class LeRobotDatasetMetadata:
"""
# Convert buffer into HF Dataset
episode_dict = {key: [value] for key, value in episode_dict.items()}
ep_dataset = Dataset.from_dict(episode_dict)
ep_dataset = datasets.Dataset.from_dict(episode_dict)
ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
df = pd.DataFrame(ep_dataset)
num_frames = episode_dict["length"][0]
@@ -316,16 +316,13 @@ class LeRobotDatasetMetadata:
path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(path, index=False)
# Update the Hugging Face dataset incrementally instead of reloading from disk
# This eliminates repeated load_episodes calls that cause cache bloat
if self.episodes is None:
self.episodes = load_episodes(self.root)
return
if self.episodes is not None:
# Remove the episodes cache directory, necessary to avoid cache bloat
cached_dir = get_hf_dataset_cache_dir(self.episodes)
if cached_dir is not None:
shutil.rmtree(cached_dir)
# Remove columns from df that start with 'stats/'
df = df.drop(columns=[col for col in df.columns if col.startswith("stats/")])
new_episode_dataset = Dataset.from_pandas(df)
self.episodes = concatenate_datasets([self.episodes, new_episode_dataset])
self.episodes = load_episodes(self.root)
def save_episode(
self,
@@ -1063,10 +1060,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
else:
df.to_parquet(path)
self.hf_dataset = (
concatenate_datasets([self.hf_dataset, ep_dataset]) if self.hf_dataset is not None else ep_dataset
)
self.hf_dataset.set_transform(hf_transform_to_torch)
if self.hf_dataset is not None:
# Remove hf dataset cache directory, necessary to avoid cache bloat
cached_dir = get_hf_dataset_cache_dir(self.hf_dataset)
if cached_dir is not None:
shutil.rmtree(cached_dir)
self.hf_dataset = self.load_hf_dataset()
metadata = {
"data/chunk_index": chunk_idx,
+6
View File
@@ -102,6 +102,12 @@ def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int:
return hf_ds.data.nbytes // (1024**2)
def get_hf_dataset_cache_dir(hf_ds: Dataset) -> Path | None:
if hf_ds.cache_files is None or len(hf_ds.cache_files) == 0:
return None
return Path(hf_ds.cache_files[0]["filename"]).parents[2]
def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]:
if file_idx == chunks_size - 1:
file_idx = 0