From 6a8878a6391d8d1343c47632c831e10a8e7b2d54 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=9B=9B=E4=B8=83?= <41624527+SevenFo@users.noreply.github.com>
Date: Tue, 19 May 2026 22:53:19 +0800
Subject: [PATCH] fix(datasets): normalize shape=(1,) numeric values before HF encoding (#3344)

* fix(datasets): normalize shape=(1,) numeric values before save

* test(datasets): cover shape=(1,) int/bool and finalize

Co-authored-by: Copilot
---
  Reviewer notes (placed after the three-dash separator, so they are not
  part of the commit message):

  * dataset_writer.py: for every feature that is neither an index column
    ("index", "episode_index", "task_index") nor of dtype "image"/"video",
    the buffered per-frame values are stacked with `np.stack`. When the
    declared feature shape is `(1,)` and the dtype is not "string", the
    stacked array is additionally reshaped to `(episode_length,)` before it
    is written back to `episode_buffer[key]`. Every other feature keeps the
    stacked array unchanged, exactly as before this patch.
  * `episode_length` is not defined inside the hunk; presumably it is bound
    earlier in the same method -- confirm against the full file.
  * test_datasets.py: the new test is parametrized over float32 / int64 /
    bool. It replaces `datasets.Dataset.from_dict` with a spy through
    `monkeypatch`, records `mapping["state"]`, and asserts that the recorded
    value is a NumPy array of shape `(2,)` equal to the two frames that were
    added. The spy forwards to the original `from_dict`; the test therefore
    expects `save_episode()` to go through `Dataset.from_dict`.
  * NOTE(review): the `Co-authored-by` trailer carries no e-mail address in
    this copy of the patch -- confirm against the upstream commit.
  * NOTE(review): this copy of the patch had lost its line breaks; the line
    structure and indentation below were reconstructed from the hunk headers
    and Python syntax. Run `git apply --check` before applying.

 src/lerobot/datasets/dataset_writer.py |  9 ++++++-
 tests/datasets/test_datasets.py        | 36 ++++++++++++++++++++++++++
 2 files changed, 44 insertions(+), 1 deletion(-)

diff --git a/src/lerobot/datasets/dataset_writer.py b/src/lerobot/datasets/dataset_writer.py
index 6be63194f..633c00c1a 100644
--- a/src/lerobot/datasets/dataset_writer.py
+++ b/src/lerobot/datasets/dataset_writer.py
@@ -250,7 +250,14 @@ class DatasetWriter:
         for key, ft in self._meta.features.items():
             if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
                 continue
-            episode_buffer[key] = np.stack(episode_buffer[key])
+            stacked_values = np.stack(episode_buffer[key])
+
+            # `shape=(1,)` numeric features are serialized as `datasets.Value`, which expects scalars.
+            # Normalizing to `(N,)` keeps save semantics stable across dependency versions.
+            if tuple(ft["shape"]) == (1,) and ft["dtype"] != "string":
+                stacked_values = stacked_values.reshape(episode_length)
+
+            episode_buffer[key] = stacked_values
 
         # Wait for image writer to end, so that episode stats over images can be computed
         self._wait_image_writer()
diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py
index ba9b64812..19c314fd6 100644
--- a/tests/datasets/test_datasets.py
+++ b/tests/datasets/test_datasets.py
@@ -24,6 +24,7 @@ import torch
 
 pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
 
+import datasets
 from huggingface_hub import HfApi
 from PIL import Image
 from safetensors.torch import load_file
@@ -360,6 +361,41 @@ def test_add_frame_image_pil(image_dataset):
     assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
 
 
+@pytest.mark.parametrize(
+    "dtype,np_dtype,values,assert_fn",
+    [
+        ("float32", np.float32, [1.0, 2.0], np.testing.assert_allclose),
+        ("int64", np.int64, [1, 2], np.testing.assert_array_equal),
+        ("bool", np.bool_, [True, False], np.testing.assert_array_equal),
+    ],
+    ids=["float32", "int64", "bool"],
+)
+def test_save_episode_shape_1_scalar_is_scalarized_before_hf_encoding(
+    tmp_path, empty_lerobot_dataset_factory, monkeypatch, dtype, np_dtype, values, assert_fn
+):
+    features = {"state": {"dtype": dtype, "shape": (1,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    dataset.add_frame({"state": np.array([values[0]], dtype=np_dtype), "task": "Dummy task"})
+    dataset.add_frame({"state": np.array([values[1]], dtype=np_dtype), "task": "Dummy task"})
+
+    captured = {}
+    original_from_dict = datasets.Dataset.from_dict
+
+    def _from_dict_spy(cls, mapping, *args, **kwargs):
+        captured["state"] = mapping["state"]
+        return original_from_dict(mapping, *args, **kwargs)
+
+    monkeypatch.setattr(datasets.Dataset, "from_dict", classmethod(_from_dict_spy))
+
+    dataset.save_episode()
+    dataset.finalize()
+
+    assert "state" in captured
+    assert isinstance(captured["state"], np.ndarray)
+    assert captured["state"].shape == (2,)
+    assert_fn(captured["state"], np.array(values, dtype=np_dtype))
+
+
 def test_set_image_transforms_applies_transparently(image_dataset):
     dataset = image_dataset
     dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})