mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
fix(datasets): normalize shape=(1,) numeric values before HF encoding (#3344)
* fix(datasets): normalize shape=(1,) numeric values before save * test(datasets): cover shape=(1,) int/bool and finalize Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
@@ -250,7 +250,14 @@ class DatasetWriter:
|
||||
for key, ft in self._meta.features.items():
|
||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
|
||||
continue
|
||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||
stacked_values = np.stack(episode_buffer[key])
|
||||
|
||||
# `shape=(1,)` numeric features are serialized as `datasets.Value`, which expects scalars.
|
||||
# Normalizing to `(N,)` keeps save semantics stable across dependency versions.
|
||||
if tuple(ft["shape"]) == (1,) and ft["dtype"] != "string":
|
||||
stacked_values = stacked_values.reshape(episode_length)
|
||||
|
||||
episode_buffer[key] = stacked_values
|
||||
|
||||
# Wait for image writer to end, so that episode stats over images can be computed
|
||||
self._wait_image_writer()
|
||||
|
||||
@@ -24,6 +24,7 @@ import torch
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
import datasets
|
||||
from huggingface_hub import HfApi
|
||||
from PIL import Image
|
||||
from safetensors.torch import load_file
|
||||
@@ -360,6 +361,41 @@ def test_add_frame_image_pil(image_dataset):
|
||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"dtype,np_dtype,values,assert_fn",
|
||||
[
|
||||
("float32", np.float32, [1.0, 2.0], np.testing.assert_allclose),
|
||||
("int64", np.int64, [1, 2], np.testing.assert_array_equal),
|
||||
("bool", np.bool_, [True, False], np.testing.assert_array_equal),
|
||||
],
|
||||
ids=["float32", "int64", "bool"],
|
||||
)
|
||||
def test_save_episode_shape_1_scalar_is_scalarized_before_hf_encoding(
|
||||
tmp_path, empty_lerobot_dataset_factory, monkeypatch, dtype, np_dtype, values, assert_fn
|
||||
):
|
||||
features = {"state": {"dtype": dtype, "shape": (1,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": np.array([values[0]], dtype=np_dtype), "task": "Dummy task"})
|
||||
dataset.add_frame({"state": np.array([values[1]], dtype=np_dtype), "task": "Dummy task"})
|
||||
|
||||
captured = {}
|
||||
original_from_dict = datasets.Dataset.from_dict
|
||||
|
||||
def _from_dict_spy(cls, mapping, *args, **kwargs):
|
||||
captured["state"] = mapping["state"]
|
||||
return original_from_dict(mapping, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(datasets.Dataset, "from_dict", classmethod(_from_dict_spy))
|
||||
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert "state" in captured
|
||||
assert isinstance(captured["state"], np.ndarray)
|
||||
assert captured["state"].shape == (2,)
|
||||
assert_fn(captured["state"], np.array(values, dtype=np_dtype))
|
||||
|
||||
|
||||
def test_set_image_transforms_applies_transparently(image_dataset):
|
||||
dataset = image_dataset
|
||||
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
||||
|
||||
Reference in New Issue
Block a user