mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
feat(dataset): registering torchvision transforms (#3153)
* add: a flexible transformation registry * fix: image transforms can be set both at init and after * add: tests * fix: take in review * feat(datasets): add image transform setters * fix: pre-commit * fix: CI --------- Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
e2f27bf71b
commit
7c032f19fc
@@ -151,9 +151,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
``$HF_LEROBOT_HOME/hub``.
|
``$HF_LEROBOT_HOME/hub``.
|
||||||
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||||
their episode_index in this list. Defaults to None.
|
their episode_index in this list. Defaults to None.
|
||||||
image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
|
image_transforms (Callable | None, optional):
|
||||||
torchvision.transforms.v2 here which will be applied to visual modalities (whether they come
|
Transform applied to visual modalities inside `__getitem__` after image decoding / tensor
|
||||||
from videos or images). Defaults to None.
|
conversion. This works for both image-backed and video-backed observations and can later be
|
||||||
|
updated with `set_image_transforms()` or cleared with `clear_image_transforms()`.
|
||||||
|
Defaults to None.
|
||||||
delta_timestamps (dict[list[float]] | None, optional): _description_. Defaults to None.
|
delta_timestamps (dict[list[float]] | None, optional): _description_. Defaults to None.
|
||||||
tolerance_s (float, optional): Tolerance in seconds used to ensure data timestamps are actually in
|
tolerance_s (float, optional): Tolerance in seconds used to ensure data timestamps are actually in
|
||||||
sync with the fps value. It is used at the init of the dataset to make sure that each
|
sync with the fps value. It is used at the init of the dataset to make sure that each
|
||||||
@@ -192,7 +194,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.repo_id = repo_id
|
self.repo_id = repo_id
|
||||||
self._requested_root = Path(root) if root else None
|
self._requested_root = Path(root) if root else None
|
||||||
self.image_transforms = image_transforms
|
self.reader = None
|
||||||
|
self.set_image_transforms(image_transforms)
|
||||||
self.delta_timestamps = delta_timestamps
|
self.delta_timestamps = delta_timestamps
|
||||||
self.episodes = episodes
|
self.episodes = episodes
|
||||||
self.tolerance_s = tolerance_s
|
self.tolerance_s = tolerance_s
|
||||||
@@ -475,6 +478,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
f"}})"
|
f"}})"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def set_image_transforms(self, image_transforms: Callable | None) -> None:
|
||||||
|
"""Replace the transform applied to visual observations."""
|
||||||
|
if image_transforms is not None and not callable(image_transforms):
|
||||||
|
raise TypeError("image_transforms must be callable or None.")
|
||||||
|
self.image_transforms = image_transforms
|
||||||
|
if self.reader is not None:
|
||||||
|
self.reader._image_transforms = image_transforms
|
||||||
|
|
||||||
|
def clear_image_transforms(self) -> None:
|
||||||
|
"""Remove the transform applied to visual observations."""
|
||||||
|
self.set_image_transforms(None)
|
||||||
|
|
||||||
# ── Hub methods (stay on facade) ──────────────────────────────────
|
# ── Hub methods (stay on facade) ──────────────────────────────────
|
||||||
|
|
||||||
def push_to_hub(
|
def push_to_hub(
|
||||||
|
|||||||
@@ -89,12 +89,24 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||||||
)
|
)
|
||||||
self.disabled_features.update(extra_keys)
|
self.disabled_features.update(extra_keys)
|
||||||
|
|
||||||
self.image_transforms = image_transforms
|
|
||||||
self.delta_timestamps = delta_timestamps
|
self.delta_timestamps = delta_timestamps
|
||||||
# TODO(rcadene, aliberts): We should not perform this aggregation for datasets
|
# TODO(rcadene, aliberts): We should not perform this aggregation for datasets
|
||||||
# with multiple robots of different ranges. Instead we should have one normalization
|
# with multiple robots of different ranges. Instead we should have one normalization
|
||||||
# per robot.
|
# per robot.
|
||||||
self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
|
self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
|
||||||
|
self.set_image_transforms(image_transforms)
|
||||||
|
|
||||||
|
def set_image_transforms(self, image_transforms: Callable | None) -> None:
|
||||||
|
"""Replace the transform for this dataset and its children."""
|
||||||
|
if image_transforms is not None and not callable(image_transforms):
|
||||||
|
raise TypeError("image_transforms must be callable or None.")
|
||||||
|
self.image_transforms = image_transforms
|
||||||
|
for dataset in getattr(self, "_datasets", []):
|
||||||
|
dataset.set_image_transforms(self.image_transforms)
|
||||||
|
|
||||||
|
def clear_image_transforms(self) -> None:
|
||||||
|
"""Remove the transform from this dataset and its children."""
|
||||||
|
self.set_image_transforms(None)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def repo_id_to_index(self):
|
def repo_id_to_index(self):
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import torch
|
|||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
from torchvision.transforms import v2
|
||||||
|
|
||||||
import lerobot
|
import lerobot
|
||||||
from lerobot.configs.default import DatasetConfig
|
from lerobot.configs.default import DatasetConfig
|
||||||
@@ -34,6 +35,7 @@ from lerobot.datasets.image_writer import image_array_to_pil_image
|
|||||||
from lerobot.datasets.io_utils import hf_transform_to_torch
|
from lerobot.datasets.io_utils import hf_transform_to_torch
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
|
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
|
||||||
|
from lerobot.datasets.transforms import ImageTransforms, ImageTransformsConfig
|
||||||
from lerobot.datasets.utils import (
|
from lerobot.datasets.utils import (
|
||||||
DEFAULT_CHUNK_SIZE,
|
DEFAULT_CHUNK_SIZE,
|
||||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||||
@@ -355,6 +357,62 @@ def test_add_frame_image_pil(image_dataset):
|
|||||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_image_transforms_applies_transparently(image_dataset):
|
||||||
|
dataset = image_dataset
|
||||||
|
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
||||||
|
dataset.save_episode()
|
||||||
|
dataset.finalize()
|
||||||
|
|
||||||
|
dataset.set_image_transforms(v2.Resize((224, 224)))
|
||||||
|
assert dataset[0]["image"].shape == torch.Size((3, 224, 224))
|
||||||
|
|
||||||
|
dataset.set_image_transforms(v2.Resize((128, 128)))
|
||||||
|
assert dataset[0]["image"].shape == torch.Size((3, 128, 128))
|
||||||
|
|
||||||
|
dataset.clear_image_transforms()
|
||||||
|
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_image_transforms_supports_lerobot_image_transforms(image_dataset):
|
||||||
|
dataset = image_dataset
|
||||||
|
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
||||||
|
dataset.save_episode()
|
||||||
|
dataset.finalize()
|
||||||
|
|
||||||
|
image_transforms = ImageTransforms(ImageTransformsConfig(enable=False))
|
||||||
|
dataset.set_image_transforms(image_transforms)
|
||||||
|
|
||||||
|
assert dataset.image_transforms is image_transforms
|
||||||
|
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_image_transforms_supports_loaded_dataset(tmp_path, lerobot_dataset_factory):
|
||||||
|
dataset = lerobot_dataset_factory(root=tmp_path / "test", use_videos=False)
|
||||||
|
dataset.set_image_transforms(v2.Compose([v2.Resize((224, 224)), v2.Resize((112, 112))]))
|
||||||
|
|
||||||
|
camera_key = dataset.meta.camera_keys[0]
|
||||||
|
assert dataset[0][camera_key].shape == torch.Size((3, 112, 112))
|
||||||
|
|
||||||
|
|
||||||
|
def test_multilerobot_dataset_set_image_transforms_propagates(tmp_path, lerobot_dataset_factory):
|
||||||
|
root = tmp_path / "multi"
|
||||||
|
repo_ids = ["lerobot/test_multi_a", "lerobot/test_multi_b"]
|
||||||
|
|
||||||
|
for repo_id in repo_ids:
|
||||||
|
lerobot_dataset_factory(root=root / repo_id, repo_id=repo_id, use_videos=False)
|
||||||
|
|
||||||
|
dataset = MultiLeRobotDataset(repo_ids, root=root, download_videos=False)
|
||||||
|
dataset.set_image_transforms(v2.Resize((96, 96)))
|
||||||
|
|
||||||
|
camera_key = dataset.camera_keys[0]
|
||||||
|
assert dataset[0][camera_key].shape == torch.Size((3, 96, 96))
|
||||||
|
assert all(child.image_transforms is dataset.image_transforms for child in dataset._datasets)
|
||||||
|
|
||||||
|
dataset.clear_image_transforms()
|
||||||
|
assert dataset.image_transforms is None
|
||||||
|
assert all(child.image_transforms is None for child in dataset._datasets)
|
||||||
|
|
||||||
|
|
||||||
def test_image_array_to_pil_image_wrong_range_float_0_255():
|
def test_image_array_to_pil_image_wrong_range_float_0_255():
|
||||||
image = np.random.rand(*DUMMY_HWC) * 255
|
image = np.random.rand(*DUMMY_HWC) * 255
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
|
|||||||
Reference in New Issue
Block a user