diff --git a/src/lerobot/datasets/io_utils.py b/src/lerobot/datasets/io_utils.py index a41f34704..3af32a1a1 100644 --- a/src/lerobot/datasets/io_utils.py +++ b/src/lerobot/datasets/io_utils.py @@ -153,7 +153,7 @@ def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]: Returns: dict: The statistics dictionary with values cast to numpy arrays. """ - stats = {key: np.array(value) for key, value in flatten_dict(stats).items()} + stats = {key: np.atleast_1d(np.array(value)) for key, value in flatten_dict(stats).items()} return unflatten_dict(stats)