mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 16:27:04 +00:00
fix(image transforms): cleaning up image_transforms implementation in LeRobotDataset (#3829)
This commit is contained in:
@@ -74,6 +74,8 @@ class DatasetReader:
|
||||
self.episodes = episodes
|
||||
self._tolerance_s = tolerance_s
|
||||
self._video_backend = video_backend
|
||||
if image_transforms is not None and not callable(image_transforms):
|
||||
raise TypeError("image_transforms must be callable or None.")
|
||||
self._image_transforms = image_transforms
|
||||
self._return_uint8 = return_uint8
|
||||
|
||||
@@ -86,6 +88,16 @@ class DatasetReader:
|
||||
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
|
||||
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
|
||||
|
||||
def set_image_transforms(self, image_transforms: Callable | None) -> None:
|
||||
"""Replace the transform applied to visual observations."""
|
||||
if image_transforms is not None and not callable(image_transforms):
|
||||
raise TypeError("image_transforms must be callable or None.")
|
||||
self._image_transforms = image_transforms
|
||||
|
||||
def clear_image_transforms(self) -> None:
|
||||
"""Remove the transform applied to visual observations."""
|
||||
self._image_transforms = None
|
||||
|
||||
def try_load(self) -> bool:
|
||||
"""Attempt to load from local cache. Returns True if data is sufficient."""
|
||||
try:
|
||||
|
||||
@@ -201,8 +201,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self._requested_root = Path(root) if root else None
|
||||
self.reader = None
|
||||
self.set_image_transforms(image_transforms)
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
@@ -249,6 +247,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
image_transforms=image_transforms,
|
||||
return_uint8=self._return_uint8,
|
||||
)
|
||||
self.image_transforms = image_transforms
|
||||
|
||||
# Load actual data
|
||||
if force_cache_sync or not self.reader.try_load():
|
||||
@@ -505,15 +504,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
def set_image_transforms(self, image_transforms: Callable | None) -> None:
|
||||
"""Replace the transform applied to visual observations."""
|
||||
if image_transforms is not None and not callable(image_transforms):
|
||||
raise TypeError("image_transforms must be callable or None.")
|
||||
self._ensure_reader().set_image_transforms(image_transforms)
|
||||
self.image_transforms = image_transforms
|
||||
if self.reader is not None:
|
||||
self.reader._image_transforms = image_transforms
|
||||
|
||||
def clear_image_transforms(self) -> None:
|
||||
"""Remove the transform applied to visual observations."""
|
||||
self.set_image_transforms(None)
|
||||
if self.reader is not None:
|
||||
self.reader.set_image_transforms(None)
|
||||
self.image_transforms = None
|
||||
|
||||
# ── Hub methods (stay on facade) ──────────────────────────────────
|
||||
|
||||
|
||||
Reference in New Issue
Block a user