mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
fix(datasets): normalize shape=(1,) numeric values before HF encoding (#3344)
* fix(datasets): normalize shape=(1,) numeric values before save * test(datasets): cover shape=(1,) int/bool and finalize Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
@@ -24,6 +24,7 @@ import torch
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
import datasets
|
||||
from huggingface_hub import HfApi
|
||||
from PIL import Image
|
||||
from safetensors.torch import load_file
|
||||
@@ -360,6 +361,41 @@ def test_add_frame_image_pil(image_dataset):
|
||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"dtype,np_dtype,values,assert_fn",
|
||||
[
|
||||
("float32", np.float32, [1.0, 2.0], np.testing.assert_allclose),
|
||||
("int64", np.int64, [1, 2], np.testing.assert_array_equal),
|
||||
("bool", np.bool_, [True, False], np.testing.assert_array_equal),
|
||||
],
|
||||
ids=["float32", "int64", "bool"],
|
||||
)
|
||||
def test_save_episode_shape_1_scalar_is_scalarized_before_hf_encoding(
|
||||
tmp_path, empty_lerobot_dataset_factory, monkeypatch, dtype, np_dtype, values, assert_fn
|
||||
):
|
||||
features = {"state": {"dtype": dtype, "shape": (1,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": np.array([values[0]], dtype=np_dtype), "task": "Dummy task"})
|
||||
dataset.add_frame({"state": np.array([values[1]], dtype=np_dtype), "task": "Dummy task"})
|
||||
|
||||
captured = {}
|
||||
original_from_dict = datasets.Dataset.from_dict
|
||||
|
||||
def _from_dict_spy(cls, mapping, *args, **kwargs):
|
||||
captured["state"] = mapping["state"]
|
||||
return original_from_dict(mapping, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(datasets.Dataset, "from_dict", classmethod(_from_dict_spy))
|
||||
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert "state" in captured
|
||||
assert isinstance(captured["state"], np.ndarray)
|
||||
assert captured["state"].shape == (2,)
|
||||
assert_fn(captured["state"], np.array(values, dtype=np_dtype))
|
||||
|
||||
|
||||
def test_set_image_transforms_applies_transparently(image_dataset):
|
||||
dataset = image_dataset
|
||||
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
||||
|
||||
Reference in New Issue
Block a user