diff --git a/src/lerobot/datasets/dataset_reader.py b/src/lerobot/datasets/dataset_reader.py index 59aaa40e5..d7289ac48 100644 --- a/src/lerobot/datasets/dataset_reader.py +++ b/src/lerobot/datasets/dataset_reader.py @@ -74,6 +74,8 @@ class DatasetReader: self.episodes = episodes self._tolerance_s = tolerance_s self._video_backend = video_backend + if image_transforms is not None and not callable(image_transforms): + raise TypeError("image_transforms must be callable or None.") self._image_transforms = image_transforms self._return_uint8 = return_uint8 @@ -86,6 +88,16 @@ class DatasetReader: check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s) self.delta_indices = get_delta_indices(delta_timestamps, meta.fps) + def set_image_transforms(self, image_transforms: Callable | None) -> None: + """Replace the transform applied to visual observations.""" + if image_transforms is not None and not callable(image_transforms): + raise TypeError("image_transforms must be callable or None.") + self._image_transforms = image_transforms + + def clear_image_transforms(self) -> None: + """Remove the transform applied to visual observations.""" + self._image_transforms = None + def try_load(self) -> bool: """Attempt to load from local cache. Returns True if data is sufficient.""" try: diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index d0dcf087d..d1e65fef1 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -201,8 +201,6 @@ class LeRobotDataset(torch.utils.data.Dataset): super().__init__() self.repo_id = repo_id self._requested_root = Path(root) if root else None - self.reader = None - self.set_image_transforms(image_transforms) self.delta_timestamps = delta_timestamps self.tolerance_s = tolerance_s self.revision = revision if revision else CODEBASE_VERSION @@ -249,6 +247,7 @@ class LeRobotDataset(torch.utils.data.Dataset): image_transforms=image_transforms, return_uint8=self._return_uint8, ) + self.image_transforms = image_transforms # Load actual data if force_cache_sync or not self.reader.try_load(): @@ -505,15 +504,14 @@ class LeRobotDataset(torch.utils.data.Dataset): def set_image_transforms(self, image_transforms: Callable | None) -> None: """Replace the transform applied to visual observations.""" - if image_transforms is not None and not callable(image_transforms): - raise TypeError("image_transforms must be callable or None.") + self._ensure_reader().set_image_transforms(image_transforms) self.image_transforms = image_transforms - if self.reader is not None: - self.reader._image_transforms = image_transforms def clear_image_transforms(self) -> None: """Remove the transform applied to visual observations.""" - self.set_image_transforms(None) + if self.reader is not None: + self.reader.set_image_transforms(None) + self.image_transforms = None # ── Hub methods (stay on facade) ──────────────────────────────────