diff --git a/src/lerobot/datasets/dataset_metadata.py b/src/lerobot/datasets/dataset_metadata.py index 39a1b6d2b..b496e4f65 100644 --- a/src/lerobot/datasets/dataset_metadata.py +++ b/src/lerobot/datasets/dataset_metadata.py @@ -15,6 +15,7 @@ # limitations under the License. import contextlib from collections.abc import Callable +from copy import deepcopy from pathlib import Path import numpy as np @@ -709,7 +710,7 @@ class LeRobotDatasetMetadata: obj.root.mkdir(parents=True, exist_ok=False) - features = {**features, **DEFAULT_FEATURES} + features = {**deepcopy(features), **DEFAULT_FEATURES} _validate_feature_names(features) obj.tasks = None diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py index 91dc66af2..9aca859b4 100644 --- a/src/lerobot/datasets/dataset_tools.py +++ b/src/lerobot/datasets/dataset_tools.py @@ -27,6 +27,7 @@ import logging import shutil from collections.abc import Callable from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed +from copy import deepcopy from pathlib import Path import datasets @@ -1101,7 +1102,9 @@ def _copy_episodes_metadata_and_stats( if dst_meta.video_keys and src_dataset.meta.video_keys: for key in dst_meta.video_keys: if key in src_dataset.meta.features: - dst_meta.info.features[key]["info"] = src_dataset.meta.info.features[key].get("info", {}) + dst_meta.info.features[key]["info"] = deepcopy( + src_dataset.meta.info.features[key].get("info", {}) + ) write_info(dst_meta.info, dst_meta.root) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 19c314fd6..1d2fb1d55 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -51,7 +51,7 @@ from lerobot.robots import make_robot_from_config from lerobot.transforms import ImageTransforms, ImageTransformsConfig from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD from lerobot.utils.feature_utils import 
hw_to_dataset_features -from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID +from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_MOTOR_FEATURES, DUMMY_REPO_ID from tests.mocks.mock_robot import MockRobotConfig from tests.utils import require_x86_64_kernel @@ -133,6 +133,21 @@ def test_dataset_feature_with_forward_slash_raises_error(): ) +def test_create_does_not_mutate_input_features(tmp_path, empty_lerobot_dataset_factory): + # ``create`` must deep-copy ``features``: edits to the copy's metadata must not leak into the original. + dataset = empty_lerobot_dataset_factory( + root=tmp_path / "ds1", features=DUMMY_MOTOR_FEATURES, use_videos=False + ) + dataset_copy = empty_lerobot_dataset_factory( + root=tmp_path / "ds2", features=dataset.meta.features, use_videos=False + ) + + original_shape = dataset.meta.info.features["state"]["shape"] + dataset_copy.meta.info.features["state"]["shape"] = (999,) + + assert dataset.meta.info.features["state"]["shape"] == original_shape + + def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)