mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 08:17:02 +00:00
fix(stats): handle scalar stats robustly
- Wrap cast_stats_to_numpy with np.atleast_1d to prevent 0-d arrays from scalar stats causing shape mismatches downstream.
This commit is contained in:
@@ -153,7 +153,7 @@ def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]:
|
||||
Returns:
|
||||
dict: The statistics dictionary with values cast to numpy arrays.
|
||||
"""
|
||||
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
|
||||
stats = {key: np.atleast_1d(np.array(value)) for key, value in flatten_dict(stats).items()}
|
||||
return unflatten_dict(stats)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user