fix(memory explosion) added delete to episodes and hf_dataset every time we reload while collecting a dataset to avoid memory explosion

This commit is contained in:
Michel Aractingi
2025-09-03 15:31:28 +02:00
parent 2a3d62259e
commit fdccf7774b
+17 -2
View File
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import gc
import logging
import shutil
import tempfile
@@ -314,6 +315,13 @@ class LeRobotDatasetMetadata:
# Update the Hugging Face dataset by reloading it.
# This process should be fast because only the latest Parquet file has been modified.
# Therefore, only this file needs to be converted to PyArrow; the rest is loaded from the PyArrow memory-mapped cache.
# Explicitly delete old dataset to free memory before reloading
if hasattr(self, 'episodes') and self.episodes is not None:
del self.episodes
self.episodes = None
gc.collect()
self.episodes = load_episodes(self.root)
def save_episode(
@@ -451,7 +459,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
root: str | Path | None = None,
episodes: list[int] | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
delta_timestamps: dict[str, list[float]] | None = None,
tolerance_s: float = 1e-4,
revision: str | None = None,
force_cache_sync: bool = False,
@@ -1051,6 +1059,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Update the Hugging Face dataset by reloading it.
# This process should be fast because only the latest Parquet file has been modified.
# Therefore, only this file needs to be converted to PyArrow; the rest is loaded from the PyArrow memory-mapped cache.
# Explicitly delete old dataset to free memory before reloading
if hasattr(self, 'hf_dataset') and self.hf_dataset is not None:
del self.hf_dataset
self.hf_dataset = None
gc.collect()
self.hf_dataset = self.load_hf_dataset()
metadata = {
@@ -1216,7 +1231,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
root: str | Path | None = None,
episodes: dict | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
delta_timestamps: dict[str, list[float]] | None = None,
tolerances_s: dict | None = None,
download_videos: bool = True,
video_backend: str | None = None,