fix(caching): remove the cache directory on each call to load_episodes and load_hf_dataset when collecting a dataset

This commit is contained in:
Michel Aractingi
2025-09-08 12:44:43 +02:00
parent 952f455446
commit af79dda8d9
2 changed files with 21 additions and 15 deletions
+15 -15
View File
@@ -28,7 +28,6 @@ import pandas as pd
import PIL.Image import PIL.Image
import torch import torch
import torch.utils import torch.utils
from datasets import Dataset, concatenate_datasets
from huggingface_hub import HfApi, snapshot_download from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.constants import REPOCARD_NAME from huggingface_hub.constants import REPOCARD_NAME
from huggingface_hub.errors import RevisionNotFoundError from huggingface_hub.errors import RevisionNotFoundError
@@ -49,6 +48,7 @@ from lerobot.datasets.utils import (
embed_images, embed_images,
flatten_dict, flatten_dict,
get_delta_indices, get_delta_indices,
get_hf_dataset_cache_dir,
get_hf_dataset_size_in_mb, get_hf_dataset_size_in_mb,
get_hf_features_from_features, get_hf_features_from_features,
get_parquet_file_size_in_mb, get_parquet_file_size_in_mb,
@@ -271,7 +271,7 @@ class LeRobotDatasetMetadata:
""" """
# Convert buffer into HF Dataset # Convert buffer into HF Dataset
episode_dict = {key: [value] for key, value in episode_dict.items()} episode_dict = {key: [value] for key, value in episode_dict.items()}
ep_dataset = Dataset.from_dict(episode_dict) ep_dataset = datasets.Dataset.from_dict(episode_dict)
ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset) ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
df = pd.DataFrame(ep_dataset) df = pd.DataFrame(ep_dataset)
num_frames = episode_dict["length"][0] num_frames = episode_dict["length"][0]
@@ -316,16 +316,13 @@ class LeRobotDatasetMetadata:
path.parent.mkdir(parents=True, exist_ok=True) path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(path, index=False) df.to_parquet(path, index=False)
# Update the Hugging Face dataset incrementally instead of reloading from disk if self.episodes is not None:
# This eliminates repeated load_episodes calls that cause cache bloat # Remove the episodes cache directory, necessary to avoid cache bloat
if self.episodes is None: cached_dir = get_hf_dataset_cache_dir(self.episodes)
self.episodes = load_episodes(self.root) if cached_dir is not None:
return shutil.rmtree(cached_dir)
# Remove columns from df that start with 'stats/' self.episodes = load_episodes(self.root)
df = df.drop(columns=[col for col in df.columns if col.startswith("stats/")])
new_episode_dataset = Dataset.from_pandas(df)
self.episodes = concatenate_datasets([self.episodes, new_episode_dataset])
def save_episode( def save_episode(
self, self,
@@ -1063,10 +1060,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
else: else:
df.to_parquet(path) df.to_parquet(path)
self.hf_dataset = ( if self.hf_dataset is not None:
concatenate_datasets([self.hf_dataset, ep_dataset]) if self.hf_dataset is not None else ep_dataset # Remove hf dataset cache directory, necessary to avoid cache bloat
) cached_dir = get_hf_dataset_cache_dir(self.hf_dataset)
self.hf_dataset.set_transform(hf_transform_to_torch) if cached_dir is not None:
shutil.rmtree(cached_dir)
self.hf_dataset = self.load_hf_dataset()
metadata = { metadata = {
"data/chunk_index": chunk_idx, "data/chunk_index": chunk_idx,
+6
View File
@@ -102,6 +102,12 @@ def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int:
return hf_ds.data.nbytes // (1024**2) return hf_ds.data.nbytes // (1024**2)
def get_hf_dataset_cache_dir(hf_ds: Dataset) -> Path | None:
    """Return the cache directory backing ``hf_ds``, or ``None`` if there is none.

    The directory is the third ancestor (``parents[2]``) of the first entry in
    ``hf_ds.cache_files``.
    NOTE(review): this presumably corresponds to the Hugging Face ``datasets``
    on-disk cache layout (``<config>/<version>/<hash>/<file>``) — confirm, since
    the returned directory is meant to be removed to avoid cache bloat.

    Args:
        hf_ds: Dataset whose ``cache_files`` attribute is inspected.

    Returns:
        The ancestor directory as a ``Path``, or ``None`` when the dataset has
        no cache files or the first cache file does not have three ancestors.
    """
    cache_files = hf_ds.cache_files
    # Covers both ``None`` and an empty list.
    if not cache_files:
        return None

    ancestors = Path(cache_files[0]["filename"]).parents
    # A shallow path has fewer than three ancestors; report "no cache dir"
    # instead of letting ``parents[2]`` raise IndexError.
    if len(ancestors) < 3:
        return None
    return ancestors[2]
def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]: def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]:
if file_idx == chunks_size - 1: if file_idx == chunks_size - 1:
file_idx = 0 file_idx = 0