mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 13:09:43 +00:00
fix(datasets): normalize shape=(1,) numeric values before HF encoding (#3344)
* fix(datasets): normalize shape=(1,) numeric values before save * test(datasets): cover shape=(1,) int/bool and finalize Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
@@ -250,7 +250,14 @@ class DatasetWriter:
|
|||||||
for key, ft in self._meta.features.items():
|
for key, ft in self._meta.features.items():
|
||||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
|
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
|
||||||
continue
|
continue
|
||||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
stacked_values = np.stack(episode_buffer[key])
|
||||||
|
|
||||||
|
# `shape=(1,)` numeric features are serialized as `datasets.Value`, which expects scalars.
|
||||||
|
# Normalizing to `(N,)` keeps save semantics stable across dependency versions.
|
||||||
|
if tuple(ft["shape"]) == (1,) and ft["dtype"] != "string":
|
||||||
|
stacked_values = stacked_values.reshape(episode_length)
|
||||||
|
|
||||||
|
episode_buffer[key] = stacked_values
|
||||||
|
|
||||||
# Wait for image writer to end, so that episode stats over images can be computed
|
# Wait for image writer to end, so that episode stats over images can be computed
|
||||||
self._wait_image_writer()
|
self._wait_image_writer()
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import torch
|
|||||||
|
|
||||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||||
|
|
||||||
|
import datasets
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
@@ -360,6 +361,41 @@ def test_add_frame_image_pil(image_dataset):
|
|||||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"dtype,np_dtype,values,assert_fn",
|
||||||
|
[
|
||||||
|
("float32", np.float32, [1.0, 2.0], np.testing.assert_allclose),
|
||||||
|
("int64", np.int64, [1, 2], np.testing.assert_array_equal),
|
||||||
|
("bool", np.bool_, [True, False], np.testing.assert_array_equal),
|
||||||
|
],
|
||||||
|
ids=["float32", "int64", "bool"],
|
||||||
|
)
|
||||||
|
def test_save_episode_shape_1_scalar_is_scalarized_before_hf_encoding(
|
||||||
|
tmp_path, empty_lerobot_dataset_factory, monkeypatch, dtype, np_dtype, values, assert_fn
|
||||||
|
):
|
||||||
|
features = {"state": {"dtype": dtype, "shape": (1,), "names": None}}
|
||||||
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
|
dataset.add_frame({"state": np.array([values[0]], dtype=np_dtype), "task": "Dummy task"})
|
||||||
|
dataset.add_frame({"state": np.array([values[1]], dtype=np_dtype), "task": "Dummy task"})
|
||||||
|
|
||||||
|
captured = {}
|
||||||
|
original_from_dict = datasets.Dataset.from_dict
|
||||||
|
|
||||||
|
def _from_dict_spy(cls, mapping, *args, **kwargs):
|
||||||
|
captured["state"] = mapping["state"]
|
||||||
|
return original_from_dict(mapping, *args, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr(datasets.Dataset, "from_dict", classmethod(_from_dict_spy))
|
||||||
|
|
||||||
|
dataset.save_episode()
|
||||||
|
dataset.finalize()
|
||||||
|
|
||||||
|
assert "state" in captured
|
||||||
|
assert isinstance(captured["state"], np.ndarray)
|
||||||
|
assert captured["state"].shape == (2,)
|
||||||
|
assert_fn(captured["state"], np.array(values, dtype=np_dtype))
|
||||||
|
|
||||||
|
|
||||||
def test_set_image_transforms_applies_transparently(image_dataset):
|
def test_set_image_transforms_applies_transparently(image_dataset):
|
||||||
dataset = image_dataset
|
dataset = image_dataset
|
||||||
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
||||||
|
|||||||
Reference in New Issue
Block a user