feat(dataset): registering torchvision transforms (#3153)

* add: a flexible transformation registry

* fix: image transforms can be set both at init and after

* add: tests

* fix: take in review

* feat(datasets): add image transform setters

* fix: pre-commit

* fix: CI

---------

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
This commit is contained in:
Francesco Capuano
2026-04-07 15:59:11 +02:00
committed by GitHub
parent e2f27bf71b
commit 7c032f19fc
3 changed files with 90 additions and 5 deletions
+19 -4
View File
@@ -151,9 +151,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
``$HF_LEROBOT_HOME/hub``.
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
their episode_index in this list. Defaults to None.
image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
torchvision.transforms.v2 here which will be applied to visual modalities (whether they come
from videos or images). Defaults to None.
image_transforms (Callable | None, optional):
Transform applied to visual modalities inside `__getitem__` after image decoding / tensor
conversion. This works for both image-backed and video-backed observations and can later be
updated with `set_image_transforms()` or cleared with `clear_image_transforms()`.
Defaults to None.
delta_timestamps (dict[list[float]] | None, optional): _description_. Defaults to None.
tolerance_s (float, optional): Tolerance in seconds used to ensure data timestamps are actually in
sync with the fps value. It is used at the init of the dataset to make sure that each
@@ -192,7 +194,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
super().__init__()
self.repo_id = repo_id
self._requested_root = Path(root) if root else None
self.image_transforms = image_transforms
self.reader = None
self.set_image_transforms(image_transforms)
self.delta_timestamps = delta_timestamps
self.episodes = episodes
self.tolerance_s = tolerance_s
@@ -475,6 +478,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
f"}})"
)
def set_image_transforms(self, image_transforms: Callable | None) -> None:
"""Replace the transform applied to visual observations."""
if image_transforms is not None and not callable(image_transforms):
raise TypeError("image_transforms must be callable or None.")
self.image_transforms = image_transforms
if self.reader is not None:
self.reader._image_transforms = image_transforms
def clear_image_transforms(self) -> None:
"""Remove the transform applied to visual observations."""
self.set_image_transforms(None)
# ── Hub methods (stay on facade) ──────────────────────────────────
def push_to_hub(
+13 -1
View File
@@ -89,12 +89,24 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
)
self.disabled_features.update(extra_keys)
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
# TODO(rcadene, aliberts): We should not perform this aggregation for datasets
# with multiple robots of different ranges. Instead we should have one normalization
# per robot.
self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
self.set_image_transforms(image_transforms)
def set_image_transforms(self, image_transforms: Callable | None) -> None:
"""Replace the transform for this dataset and its children."""
if image_transforms is not None and not callable(image_transforms):
raise TypeError("image_transforms must be callable or None.")
self.image_transforms = image_transforms
for dataset in getattr(self, "_datasets", []):
dataset.set_image_transforms(self.image_transforms)
def clear_image_transforms(self) -> None:
"""Remove the transform from this dataset and its children."""
self.set_image_transforms(None)
@property
def repo_id_to_index(self):