From 7c032f19fc29c7082c9862382d5abfd14621054e Mon Sep 17 00:00:00 2001 From: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> Date: Tue, 7 Apr 2026 15:59:11 +0200 Subject: [PATCH] feat(dataset): registering torchvision transforms (#3153) * add: a flexible transformation registry * fix: image transforms can be set both at init and after * add: tests * fix: take in review * feat(datasets): add image transform setters * fix: pre-commit * fix: CI --------- Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> --- src/lerobot/datasets/lerobot_dataset.py | 23 ++++++++-- src/lerobot/datasets/multi_dataset.py | 14 +++++- tests/datasets/test_datasets.py | 58 +++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 5 deletions(-) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index f719222fd..1725046f2 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -151,9 +151,11 @@ class LeRobotDataset(torch.utils.data.Dataset): ``$HF_LEROBOT_HOME/hub``. episodes (list[int] | None, optional): If specified, this will only load episodes specified by their episode_index in this list. Defaults to None. - image_transforms (Callable | None, optional): You can pass standard v2 image transforms from - torchvision.transforms.v2 here which will be applied to visual modalities (whether they come - from videos or images). Defaults to None. + image_transforms (Callable | None, optional): + Transform applied to visual modalities inside `__getitem__` after image decoding / tensor + conversion. This works for both image-backed and video-backed observations and can later be + updated with `set_image_transforms()` or cleared with `clear_image_transforms()`. + Defaults to None. delta_timestamps (dict[list[float]] | None, optional): _description_. Defaults to None. tolerance_s (float, optional): Tolerance in seconds used to ensure data timestamps are actually in sync with the fps value. It is used at the init of the dataset to make sure that each @@ -192,7 +194,8 @@ class LeRobotDataset(torch.utils.data.Dataset): super().__init__() self.repo_id = repo_id self._requested_root = Path(root) if root else None - self.image_transforms = image_transforms + self.reader = None + self.set_image_transforms(image_transforms) self.delta_timestamps = delta_timestamps self.episodes = episodes self.tolerance_s = tolerance_s @@ -475,6 +478,18 @@ class LeRobotDataset(torch.utils.data.Dataset): f"}})" ) + def set_image_transforms(self, image_transforms: Callable | None) -> None: + """Replace the transform applied to visual observations.""" + if image_transforms is not None and not callable(image_transforms): + raise TypeError("image_transforms must be callable or None.") + self.image_transforms = image_transforms + if self.reader is not None: + self.reader._image_transforms = image_transforms + + def clear_image_transforms(self) -> None: + """Remove the transform applied to visual observations.""" + self.set_image_transforms(None) + # ── Hub methods (stay on facade) ────────────────────────────────── def push_to_hub( diff --git a/src/lerobot/datasets/multi_dataset.py b/src/lerobot/datasets/multi_dataset.py index d16c5bb07..092443077 100644 --- a/src/lerobot/datasets/multi_dataset.py +++ b/src/lerobot/datasets/multi_dataset.py @@ -89,12 +89,24 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): ) self.disabled_features.update(extra_keys) - self.image_transforms = image_transforms self.delta_timestamps = delta_timestamps # TODO(rcadene, aliberts): We should not perform this aggregation for datasets # with multiple robots of different ranges. Instead we should have one normalization # per robot. self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets]) + self.set_image_transforms(image_transforms) + + def set_image_transforms(self, image_transforms: Callable | None) -> None: + """Replace the transform for this dataset and its children.""" + if image_transforms is not None and not callable(image_transforms): + raise TypeError("image_transforms must be callable or None.") + self.image_transforms = image_transforms + for dataset in getattr(self, "_datasets", []): + dataset.set_image_transforms(self.image_transforms) + + def clear_image_transforms(self) -> None: + """Remove the transform from this dataset and its children.""" + self.set_image_transforms(None) @property def repo_id_to_index(self): diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index b2518149f..d4e9e88b8 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -24,6 +24,7 @@ import torch from huggingface_hub import HfApi from PIL import Image from safetensors.torch import load_file +from torchvision.transforms import v2 import lerobot from lerobot.configs.default import DatasetConfig @@ -34,6 +35,7 @@ from lerobot.datasets.image_writer import image_array_to_pil_image from lerobot.datasets.io_utils import hf_transform_to_torch from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.multi_dataset import MultiLeRobotDataset +from lerobot.datasets.transforms import ImageTransforms, ImageTransformsConfig from lerobot.datasets.utils import ( DEFAULT_CHUNK_SIZE, DEFAULT_DATA_FILE_SIZE_IN_MB, @@ -355,6 +357,62 @@ def test_add_frame_image_pil(image_dataset): assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) +def test_set_image_transforms_applies_transparently(image_dataset): + dataset = image_dataset + dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"}) + dataset.save_episode() + dataset.finalize() + + dataset.set_image_transforms(v2.Resize((224, 224))) + assert dataset[0]["image"].shape == torch.Size((3, 224, 224)) + + dataset.set_image_transforms(v2.Resize((128, 128))) + assert dataset[0]["image"].shape == torch.Size((3, 128, 128)) + + dataset.clear_image_transforms() + assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) + + +def test_set_image_transforms_supports_lerobot_image_transforms(image_dataset): + dataset = image_dataset + dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"}) + dataset.save_episode() + dataset.finalize() + + image_transforms = ImageTransforms(ImageTransformsConfig(enable=False)) + dataset.set_image_transforms(image_transforms) + + assert dataset.image_transforms is image_transforms + assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) + + +def test_set_image_transforms_supports_loaded_dataset(tmp_path, lerobot_dataset_factory): + dataset = lerobot_dataset_factory(root=tmp_path / "test", use_videos=False) + dataset.set_image_transforms(v2.Compose([v2.Resize((224, 224)), v2.Resize((112, 112))])) + + camera_key = dataset.meta.camera_keys[0] + assert dataset[0][camera_key].shape == torch.Size((3, 112, 112)) + + +def test_multilerobot_dataset_set_image_transforms_propagates(tmp_path, lerobot_dataset_factory): + root = tmp_path / "multi" + repo_ids = ["lerobot/test_multi_a", "lerobot/test_multi_b"] + + for repo_id in repo_ids: + lerobot_dataset_factory(root=root / repo_id, repo_id=repo_id, use_videos=False) + + dataset = MultiLeRobotDataset(repo_ids, root=root, download_videos=False) + dataset.set_image_transforms(v2.Resize((96, 96))) + + camera_key = dataset.camera_keys[0] + assert dataset[0][camera_key].shape == torch.Size((3, 96, 96)) + assert all(child.image_transforms is dataset.image_transforms for child in dataset._datasets) + + dataset.clear_image_transforms() + assert dataset.image_transforms is None + assert all(child.image_transforms is None for child in dataset._datasets) + + def test_image_array_to_pil_image_wrong_range_float_0_255(): image = np.random.rand(*DUMMY_HWC) * 255 with pytest.raises(ValueError):