mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2c78c46fcd |
@@ -72,6 +72,8 @@ class DatasetReader:
|
|||||||
self.episodes = episodes
|
self.episodes = episodes
|
||||||
self._tolerance_s = tolerance_s
|
self._tolerance_s = tolerance_s
|
||||||
self._video_backend = video_backend
|
self._video_backend = video_backend
|
||||||
|
if image_transforms is not None and not callable(image_transforms):
|
||||||
|
raise TypeError("image_transforms must be callable or None.")
|
||||||
self._image_transforms = image_transforms
|
self._image_transforms = image_transforms
|
||||||
|
|
||||||
self.hf_dataset: datasets.Dataset | None = None
|
self.hf_dataset: datasets.Dataset | None = None
|
||||||
@@ -83,6 +85,16 @@ class DatasetReader:
|
|||||||
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
|
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
|
||||||
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
|
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
|
||||||
|
|
||||||
|
def set_image_transforms(self, image_transforms: Callable | None) -> None:
|
||||||
|
"""Replace the transform applied to visual observations."""
|
||||||
|
if image_transforms is not None and not callable(image_transforms):
|
||||||
|
raise TypeError("image_transforms must be callable or None.")
|
||||||
|
self._image_transforms = image_transforms
|
||||||
|
|
||||||
|
def clear_image_transforms(self) -> None:
|
||||||
|
"""Remove the transform applied to visual observations."""
|
||||||
|
self._image_transforms = None
|
||||||
|
|
||||||
def try_load(self) -> bool:
|
def try_load(self) -> bool:
|
||||||
"""Attempt to load from local cache. Returns True if data is sufficient."""
|
"""Attempt to load from local cache. Returns True if data is sufficient."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -194,8 +194,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.repo_id = repo_id
|
self.repo_id = repo_id
|
||||||
self._requested_root = Path(root) if root else None
|
self._requested_root = Path(root) if root else None
|
||||||
self.reader = None
|
|
||||||
self.set_image_transforms(image_transforms)
|
|
||||||
self.delta_timestamps = delta_timestamps
|
self.delta_timestamps = delta_timestamps
|
||||||
self.episodes = episodes
|
self.episodes = episodes
|
||||||
self.tolerance_s = tolerance_s
|
self.tolerance_s = tolerance_s
|
||||||
@@ -225,6 +223,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
delta_timestamps=delta_timestamps,
|
delta_timestamps=delta_timestamps,
|
||||||
image_transforms=image_transforms,
|
image_transforms=image_transforms,
|
||||||
)
|
)
|
||||||
|
self.image_transforms = image_transforms
|
||||||
|
|
||||||
# Load actual data
|
# Load actual data
|
||||||
if force_cache_sync or not self.reader.try_load():
|
if force_cache_sync or not self.reader.try_load():
|
||||||
@@ -480,15 +479,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
def set_image_transforms(self, image_transforms: Callable | None) -> None:
|
def set_image_transforms(self, image_transforms: Callable | None) -> None:
|
||||||
"""Replace the transform applied to visual observations."""
|
"""Replace the transform applied to visual observations."""
|
||||||
if image_transforms is not None and not callable(image_transforms):
|
self._ensure_reader().set_image_transforms(image_transforms)
|
||||||
raise TypeError("image_transforms must be callable or None.")
|
|
||||||
self.image_transforms = image_transforms
|
self.image_transforms = image_transforms
|
||||||
if self.reader is not None:
|
|
||||||
self.reader._image_transforms = image_transforms
|
|
||||||
|
|
||||||
def clear_image_transforms(self) -> None:
|
def clear_image_transforms(self) -> None:
|
||||||
"""Remove the transform applied to visual observations."""
|
"""Remove the transform applied to visual observations."""
|
||||||
self.set_image_transforms(None)
|
if self.reader is not None:
|
||||||
|
self.reader.set_image_transforms(None)
|
||||||
|
self.image_transforms = None
|
||||||
|
|
||||||
# ── Hub methods (stay on facade) ──────────────────────────────────
|
# ── Hub methods (stay on facade) ──────────────────────────────────
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user