mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 13:09:43 +00:00
feat(datasets): improve image transform support (#2885)
* improve image transform support * add tests * Add stricter transform check and extra test * improve subclass check
This commit is contained in:
@@ -216,16 +216,17 @@ class ImageTransformsConfig:
|
|||||||
|
|
||||||
|
|
||||||
def make_transform_from_config(cfg: ImageTransformConfig):
|
def make_transform_from_config(cfg: ImageTransformConfig):
|
||||||
if cfg.type == "Identity":
|
if cfg.type == "SharpnessJitter":
|
||||||
return v2.Identity(**cfg.kwargs)
|
|
||||||
elif cfg.type == "ColorJitter":
|
|
||||||
return v2.ColorJitter(**cfg.kwargs)
|
|
||||||
elif cfg.type == "SharpnessJitter":
|
|
||||||
return SharpnessJitter(**cfg.kwargs)
|
return SharpnessJitter(**cfg.kwargs)
|
||||||
elif cfg.type == "RandomAffine":
|
|
||||||
return v2.RandomAffine(**cfg.kwargs)
|
transform_cls = getattr(v2, cfg.type, None)
|
||||||
else:
|
if isinstance(transform_cls, type) and issubclass(transform_cls, Transform):
|
||||||
raise ValueError(f"Transform '{cfg.type}' is not valid.")
|
return transform_cls(**cfg.kwargs)
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"Transform '{cfg.type}' is not valid. It must be a class in "
|
||||||
|
f"torchvision.transforms.v2 or 'SharpnessJitter'."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageTransforms(Transform):
|
class ImageTransforms(Transform):
|
||||||
|
|||||||
@@ -390,6 +390,30 @@ def test_sharpness_jitter_invalid_range_max_smaller():
|
|||||||
SharpnessJitter((2.0, 0.1))
|
SharpnessJitter((2.0, 0.1))
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_transform_from_config_with_v2_resize(img_tensor_factory):
|
||||||
|
img_tensor = img_tensor_factory()
|
||||||
|
tf_cfg = ImageTransformConfig(type="Resize", kwargs={"size": (32, 32)})
|
||||||
|
tf = make_transform_from_config(tf_cfg)
|
||||||
|
assert isinstance(tf, v2.Resize)
|
||||||
|
output = tf(img_tensor)
|
||||||
|
assert output.shape[-2:] == (32, 32)
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_transform_from_config_with_v2_identity(img_tensor_factory):
|
||||||
|
img_tensor = img_tensor_factory()
|
||||||
|
tf_cfg = ImageTransformConfig(type="Identity", kwargs={})
|
||||||
|
tf = make_transform_from_config(tf_cfg)
|
||||||
|
assert isinstance(tf, v2.Identity)
|
||||||
|
output = tf(img_tensor)
|
||||||
|
assert output.shape == img_tensor.shape
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_transform_from_config_invalid_type():
|
||||||
|
tf_cfg = ImageTransformConfig(type="NotARealTransform", kwargs={})
|
||||||
|
with pytest.raises(ValueError, match="not valid"):
|
||||||
|
make_transform_from_config(tf_cfg)
|
||||||
|
|
||||||
|
|
||||||
def test_save_all_transforms(img_tensor_factory, tmp_path):
|
def test_save_all_transforms(img_tensor_factory, tmp_path):
|
||||||
img_tensor = img_tensor_factory()
|
img_tensor = img_tensor_factory()
|
||||||
tf_cfg = ImageTransformsConfig(enable=True)
|
tf_cfg = ImageTransformsConfig(enable=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user