feat(dataset): registering torchvision transforms (#3153)

* add: a flexible transformation registry

* fix: image transforms can be set both at init and after

* add: tests

* fix: take in review

* feat(datasets): add image transform setters

* fix: pre-commit

* fix: CI

---------

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
This commit is contained in:
Francesco Capuano
2026-04-07 15:59:11 +02:00
committed by GitHub
parent e2f27bf71b
commit 7c032f19fc
3 changed files with 90 additions and 5 deletions
+58
View File
@@ -24,6 +24,7 @@ import torch
from huggingface_hub import HfApi
from PIL import Image
from safetensors.torch import load_file
from torchvision.transforms import v2
import lerobot
from lerobot.configs.default import DatasetConfig
@@ -34,6 +35,7 @@ from lerobot.datasets.image_writer import image_array_to_pil_image
from lerobot.datasets.io_utils import hf_transform_to_torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
from lerobot.datasets.transforms import ImageTransforms, ImageTransformsConfig
from lerobot.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
@@ -355,6 +357,62 @@ def test_add_frame_image_pil(image_dataset):
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
def test_set_image_transforms_applies_transparently(image_dataset):
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
dataset.save_episode()
dataset.finalize()
dataset.set_image_transforms(v2.Resize((224, 224)))
assert dataset[0]["image"].shape == torch.Size((3, 224, 224))
dataset.set_image_transforms(v2.Resize((128, 128)))
assert dataset[0]["image"].shape == torch.Size((3, 128, 128))
dataset.clear_image_transforms()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
def test_set_image_transforms_supports_lerobot_image_transforms(image_dataset):
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
dataset.save_episode()
dataset.finalize()
image_transforms = ImageTransforms(ImageTransformsConfig(enable=False))
dataset.set_image_transforms(image_transforms)
assert dataset.image_transforms is image_transforms
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
def test_set_image_transforms_supports_loaded_dataset(tmp_path, lerobot_dataset_factory):
dataset = lerobot_dataset_factory(root=tmp_path / "test", use_videos=False)
dataset.set_image_transforms(v2.Compose([v2.Resize((224, 224)), v2.Resize((112, 112))]))
camera_key = dataset.meta.camera_keys[0]
assert dataset[0][camera_key].shape == torch.Size((3, 112, 112))
def test_multilerobot_dataset_set_image_transforms_propagates(tmp_path, lerobot_dataset_factory):
root = tmp_path / "multi"
repo_ids = ["lerobot/test_multi_a", "lerobot/test_multi_b"]
for repo_id in repo_ids:
lerobot_dataset_factory(root=root / repo_id, repo_id=repo_id, use_videos=False)
dataset = MultiLeRobotDataset(repo_ids, root=root, download_videos=False)
dataset.set_image_transforms(v2.Resize((96, 96)))
camera_key = dataset.camera_keys[0]
assert dataset[0][camera_key].shape == torch.Size((3, 96, 96))
assert all(child.image_transforms is dataset.image_transforms for child in dataset._datasets)
dataset.clear_image_transforms()
assert dataset.image_transforms is None
assert all(child.image_transforms is None for child in dataset._datasets)
def test_image_array_to_pil_image_wrong_range_float_0_255():
image = np.random.rand(*DUMMY_HWC) * 255
with pytest.raises(ValueError):