Compare commits

...

1 Commits

Author SHA1 Message Date
CarolinePascal 2c78c46fcd fix(image transforms): cleaning up image_transforms implementation in LeRobotDataset 2026-04-07 22:15:34 +02:00
2 changed files with 17 additions and 7 deletions
+12
View File
@@ -72,6 +72,8 @@ class DatasetReader:
self.episodes = episodes
self._tolerance_s = tolerance_s
self._video_backend = video_backend
if image_transforms is not None and not callable(image_transforms):
raise TypeError("image_transforms must be callable or None.")
self._image_transforms = image_transforms
self.hf_dataset: datasets.Dataset | None = None
@@ -83,6 +85,16 @@ class DatasetReader:
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
def set_image_transforms(self, image_transforms: Callable | None) -> None:
"""Replace the transform applied to visual observations."""
if image_transforms is not None and not callable(image_transforms):
raise TypeError("image_transforms must be callable or None.")
self._image_transforms = image_transforms
def clear_image_transforms(self) -> None:
"""Remove the transform applied to visual observations."""
self._image_transforms = None
def try_load(self) -> bool:
"""Attempt to load from local cache. Returns True if data is sufficient."""
try:
+5 -7
View File
@@ -194,8 +194,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
super().__init__()
self.repo_id = repo_id
self._requested_root = Path(root) if root else None
self.reader = None
self.set_image_transforms(image_transforms)
self.delta_timestamps = delta_timestamps
self.episodes = episodes
self.tolerance_s = tolerance_s
@@ -225,6 +223,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
)
self.image_transforms = image_transforms
# Load actual data
if force_cache_sync or not self.reader.try_load():
@@ -480,15 +479,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
def set_image_transforms(self, image_transforms: Callable | None) -> None:
"""Replace the transform applied to visual observations."""
if image_transforms is not None and not callable(image_transforms):
raise TypeError("image_transforms must be callable or None.")
self._ensure_reader().set_image_transforms(image_transforms)
self.image_transforms = image_transforms
if self.reader is not None:
self.reader._image_transforms = image_transforms
def clear_image_transforms(self) -> None:
"""Remove the transform applied to visual observations."""
self.set_image_transforms(None)
if self.reader is not None:
self.reader.set_image_transforms(None)
self.image_transforms = None
# ── Hub methods (stay on facade) ──────────────────────────────────