Compare commits

...

1 Commits

Author SHA1 Message Date
CarolinePascal 06858a1b40 fix(image transforms): cleaning up image_transforms implementation in LeRobotDataset 2026-06-16 19:32:44 +02:00
2 changed files with 17 additions and 7 deletions
+12
View File
@@ -74,6 +74,8 @@ class DatasetReader:
self.episodes = episodes self.episodes = episodes
self._tolerance_s = tolerance_s self._tolerance_s = tolerance_s
self._video_backend = video_backend self._video_backend = video_backend
if image_transforms is not None and not callable(image_transforms):
raise TypeError("image_transforms must be callable or None.")
self._image_transforms = image_transforms self._image_transforms = image_transforms
self._return_uint8 = return_uint8 self._return_uint8 = return_uint8
@@ -86,6 +88,16 @@ class DatasetReader:
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s) check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps) self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
def set_image_transforms(self, image_transforms: Callable | None) -> None:
"""Replace the transform applied to visual observations."""
if image_transforms is not None and not callable(image_transforms):
raise TypeError("image_transforms must be callable or None.")
self._image_transforms = image_transforms
def clear_image_transforms(self) -> None:
"""Remove the transform applied to visual observations."""
self._image_transforms = None
def try_load(self) -> bool: def try_load(self) -> bool:
"""Attempt to load from local cache. Returns True if data is sufficient.""" """Attempt to load from local cache. Returns True if data is sufficient."""
try: try:
+5 -7
View File
@@ -201,8 +201,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
super().__init__() super().__init__()
self.repo_id = repo_id self.repo_id = repo_id
self._requested_root = Path(root) if root else None self._requested_root = Path(root) if root else None
self.reader = None
self.set_image_transforms(image_transforms)
self.delta_timestamps = delta_timestamps self.delta_timestamps = delta_timestamps
self.tolerance_s = tolerance_s self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION self.revision = revision if revision else CODEBASE_VERSION
@@ -249,6 +247,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
image_transforms=image_transforms, image_transforms=image_transforms,
return_uint8=self._return_uint8, return_uint8=self._return_uint8,
) )
self.image_transforms = image_transforms
# Load actual data # Load actual data
if force_cache_sync or not self.reader.try_load(): if force_cache_sync or not self.reader.try_load():
@@ -505,15 +504,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
def set_image_transforms(self, image_transforms: Callable | None) -> None: def set_image_transforms(self, image_transforms: Callable | None) -> None:
"""Replace the transform applied to visual observations.""" """Replace the transform applied to visual observations."""
if image_transforms is not None and not callable(image_transforms): self._ensure_reader().set_image_transforms(image_transforms)
raise TypeError("image_transforms must be callable or None.")
self.image_transforms = image_transforms self.image_transforms = image_transforms
if self.reader is not None:
self.reader._image_transforms = image_transforms
def clear_image_transforms(self) -> None: def clear_image_transforms(self) -> None:
"""Remove the transform applied to visual observations.""" """Remove the transform applied to visual observations."""
self.set_image_transforms(None) if self.reader is not None:
self.reader.set_image_transforms(None)
self.image_transforms = None
# ── Hub methods (stay on facade) ────────────────────────────────── # ── Hub methods (stay on facade) ──────────────────────────────────