Release the previous `episodes` / `hf_dataset` object before it is reloaded,
so the old and the new copy are not both referenced during the reload, and
correct the `delta_timestamps` annotations to `dict[str, list[float]]`.

Notes for reviewers:
- Rebinding the attribute to None is enough to drop the reference; no `del`
  is needed before it.
- The post-image blob hash on the `index` line below is carried over from the
  earlier revision of this patch and has not been recomputed.

diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py
index c4686f968..9036827fa 100644
--- a/src/lerobot/datasets/lerobot_dataset.py
+++ b/src/lerobot/datasets/lerobot_dataset.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
+import gc
 import logging
 import shutil
 import tempfile
@@ -314,6 +315,12 @@ class LeRobotDatasetMetadata:
         # Update the Hugging Face dataset by reloading it.
         # This process should be fast because only the latest Parquet file has been modified.
         # Therefore, only this file needs to be converted to PyArrow; the rest is loaded from the PyArrow memory-mapped cache.
+
+        # Release the previous episodes before reloading to avoid holding two copies in memory.
+        if getattr(self, "episodes", None) is not None:
+            self.episodes = None
+            gc.collect()
+
         self.episodes = load_episodes(self.root)
 
     def save_episode(
@@ -451,7 +458,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         root: str | Path | None = None,
         episodes: list[int] | None = None,
         image_transforms: Callable | None = None,
-        delta_timestamps: dict[list[float]] | None = None,
+        delta_timestamps: dict[str, list[float]] | None = None,
         tolerance_s: float = 1e-4,
         revision: str | None = None,
         force_cache_sync: bool = False,
@@ -1051,6 +1058,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # Update the Hugging Face dataset by reloading it.
         # This process should be fast because only the latest Parquet file has been modified.
         # Therefore, only this file needs to be converted to PyArrow; the rest is loaded from the PyArrow memory-mapped cache.
+
+        # Release the previous dataset before reloading to avoid holding two copies in memory.
+        if getattr(self, "hf_dataset", None) is not None:
+            self.hf_dataset = None
+            gc.collect()
+
         self.hf_dataset = self.load_hf_dataset()
 
         metadata = {
@@ -1216,7 +1229,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
         root: str | Path | None = None,
         episodes: dict | None = None,
         image_transforms: Callable | None = None,
-        delta_timestamps: dict[list[float]] | None = None,
+        delta_timestamps: dict[str, list[float]] | None = None,
         tolerances_s: dict | None = None,
         download_videos: bool = True,
         video_backend: str | None = None,