mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 01:30:14 +00:00
feat(datasets): improve image transform support (#2885)
* improve image transform support * add tests * Add stricter transform check and extra test * improve subclass check
This commit is contained in:
@@ -216,16 +216,17 @@ class ImageTransformsConfig:
|
||||
|
||||
|
||||
def make_transform_from_config(cfg: ImageTransformConfig):
|
||||
if cfg.type == "Identity":
|
||||
return v2.Identity(**cfg.kwargs)
|
||||
elif cfg.type == "ColorJitter":
|
||||
return v2.ColorJitter(**cfg.kwargs)
|
||||
elif cfg.type == "SharpnessJitter":
|
||||
if cfg.type == "SharpnessJitter":
|
||||
return SharpnessJitter(**cfg.kwargs)
|
||||
elif cfg.type == "RandomAffine":
|
||||
return v2.RandomAffine(**cfg.kwargs)
|
||||
else:
|
||||
raise ValueError(f"Transform '{cfg.type}' is not valid.")
|
||||
|
||||
transform_cls = getattr(v2, cfg.type, None)
|
||||
if isinstance(transform_cls, type) and issubclass(transform_cls, Transform):
|
||||
return transform_cls(**cfg.kwargs)
|
||||
|
||||
raise ValueError(
|
||||
f"Transform '{cfg.type}' is not valid. It must be a class in "
|
||||
f"torchvision.transforms.v2 or 'SharpnessJitter'."
|
||||
)
|
||||
|
||||
|
||||
class ImageTransforms(Transform):
|
||||
|
||||
Reference in New Issue
Block a user