mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 17:50:09 +00:00
feat(dataset): registering torchvision transforms (#3153)
* add: a flexible transformation registry * fix: image transforms can be set both at init and after * add: tests * fix: take in review * feat(datasets): add image transform setters * fix: pre-commit * fix: CI --------- Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
e2f27bf71b
commit
7c032f19fc
@@ -24,6 +24,7 @@ import torch
|
||||
from huggingface_hub import HfApi
|
||||
from PIL import Image
|
||||
from safetensors.torch import load_file
|
||||
from torchvision.transforms import v2
|
||||
|
||||
import lerobot
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
@@ -34,6 +35,7 @@ from lerobot.datasets.image_writer import image_array_to_pil_image
|
||||
from lerobot.datasets.io_utils import hf_transform_to_torch
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
|
||||
from lerobot.datasets.transforms import ImageTransforms, ImageTransformsConfig
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
@@ -355,6 +357,62 @@ def test_add_frame_image_pil(image_dataset):
|
||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||
|
||||
|
||||
def test_set_image_transforms_applies_transparently(image_dataset):
|
||||
dataset = image_dataset
|
||||
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
dataset.set_image_transforms(v2.Resize((224, 224)))
|
||||
assert dataset[0]["image"].shape == torch.Size((3, 224, 224))
|
||||
|
||||
dataset.set_image_transforms(v2.Resize((128, 128)))
|
||||
assert dataset[0]["image"].shape == torch.Size((3, 128, 128))
|
||||
|
||||
dataset.clear_image_transforms()
|
||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||
|
||||
|
||||
def test_set_image_transforms_supports_lerobot_image_transforms(image_dataset):
|
||||
dataset = image_dataset
|
||||
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
image_transforms = ImageTransforms(ImageTransformsConfig(enable=False))
|
||||
dataset.set_image_transforms(image_transforms)
|
||||
|
||||
assert dataset.image_transforms is image_transforms
|
||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||
|
||||
|
||||
def test_set_image_transforms_supports_loaded_dataset(tmp_path, lerobot_dataset_factory):
|
||||
dataset = lerobot_dataset_factory(root=tmp_path / "test", use_videos=False)
|
||||
dataset.set_image_transforms(v2.Compose([v2.Resize((224, 224)), v2.Resize((112, 112))]))
|
||||
|
||||
camera_key = dataset.meta.camera_keys[0]
|
||||
assert dataset[0][camera_key].shape == torch.Size((3, 112, 112))
|
||||
|
||||
|
||||
def test_multilerobot_dataset_set_image_transforms_propagates(tmp_path, lerobot_dataset_factory):
|
||||
root = tmp_path / "multi"
|
||||
repo_ids = ["lerobot/test_multi_a", "lerobot/test_multi_b"]
|
||||
|
||||
for repo_id in repo_ids:
|
||||
lerobot_dataset_factory(root=root / repo_id, repo_id=repo_id, use_videos=False)
|
||||
|
||||
dataset = MultiLeRobotDataset(repo_ids, root=root, download_videos=False)
|
||||
dataset.set_image_transforms(v2.Resize((96, 96)))
|
||||
|
||||
camera_key = dataset.camera_keys[0]
|
||||
assert dataset[0][camera_key].shape == torch.Size((3, 96, 96))
|
||||
assert all(child.image_transforms is dataset.image_transforms for child in dataset._datasets)
|
||||
|
||||
dataset.clear_image_transforms()
|
||||
assert dataset.image_transforms is None
|
||||
assert all(child.image_transforms is None for child in dataset._datasets)
|
||||
|
||||
|
||||
def test_image_array_to_pil_image_wrong_range_float_0_255():
|
||||
image = np.random.rand(*DUMMY_HWC) * 255
|
||||
with pytest.raises(ValueError):
|
||||
|
||||
Reference in New Issue
Block a user