feat(datasets): improve image transform support (#2885)

* improve image transform support

* add tests

* Add stricter transform check and extra test

* improve subclass check
This commit is contained in:
Reece O'Mahoney
2026-02-05 14:39:58 +00:00
committed by GitHub
parent 0f39248445
commit 97e7e0f9ed
2 changed files with 34 additions and 9 deletions
+10 -9
View File
@@ -216,16 +216,17 @@ class ImageTransformsConfig:
def make_transform_from_config(cfg: ImageTransformConfig):
    """Instantiate the image transform described by ``cfg``.

    ``cfg.type`` names the transform and ``cfg.kwargs`` is forwarded unchanged
    to its constructor. ``"SharpnessJitter"`` maps to the ``SharpnessJitter``
    class in scope here; any other name is looked up as an attribute of ``v2``
    and accepted only when it is a class deriving from ``Transform``.

    Args:
        cfg: Configuration carrying the transform name (``type``) and its
            constructor keyword arguments (``kwargs``).

    Returns:
        A new instance of the requested transform.

    Raises:
        ValueError: If ``cfg.type`` is neither ``"SharpnessJitter"`` nor the
            name of a ``Transform`` subclass exposed by ``v2``.
    """
    # SharpnessJitter is handled before the generic lookup because it is
    # referenced directly rather than through the ``v2`` namespace.
    if cfg.type == "SharpnessJitter":
        return SharpnessJitter(**cfg.kwargs)

    transform_cls = getattr(v2, cfg.type, None)
    # ``issubclass`` raises TypeError for non-class objects (functions,
    # submodules, constants), so confirm we actually got a class first.
    if isinstance(transform_cls, type) and issubclass(transform_cls, Transform):
        return transform_cls(**cfg.kwargs)

    raise ValueError(
        f"Transform '{cfg.type}' is not valid. It must be a class in "
        f"torchvision.transforms.v2 or 'SharpnessJitter'."
    )
class ImageTransforms(Transform):
+24
View File
@@ -390,6 +390,30 @@ def test_sharpness_jitter_invalid_range_max_smaller():
SharpnessJitter((2.0, 0.1))
def test_make_transform_from_config_with_v2_resize(img_tensor_factory):
    """A ``"Resize"`` config yields a ``v2.Resize`` whose output is 32x32 in its last two dims."""
    image = img_tensor_factory()
    resize_cfg = ImageTransformConfig(type="Resize", kwargs={"size": (32, 32)})

    transform = make_transform_from_config(resize_cfg)
    assert isinstance(transform, v2.Resize)

    # Only the trailing two dimensions are checked; leading dims are not constrained here.
    resized = transform(image)
    assert resized.shape[-2:] == (32, 32)
def test_make_transform_from_config_with_v2_identity(img_tensor_factory):
    """An ``"Identity"`` config yields a ``v2.Identity`` that leaves the input shape unchanged."""
    image = img_tensor_factory()
    identity_cfg = ImageTransformConfig(type="Identity", kwargs={})

    transform = make_transform_from_config(identity_cfg)
    assert isinstance(transform, v2.Identity)

    passthrough = transform(image)
    assert passthrough.shape == image.shape
def test_make_transform_from_config_invalid_type():
    """An unrecognised transform name raises ``ValueError`` whose message contains "not valid"."""
    bogus_cfg = ImageTransformConfig(type="NotARealTransform", kwargs={})

    with pytest.raises(ValueError, match="not valid"):
        make_transform_from_config(bogus_cfg)
def test_save_all_transforms(img_tensor_factory, tmp_path):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(enable=True)