diff --git a/src/lerobot/datasets/v21/convert_stats.py b/src/lerobot/datasets/v21/convert_stats.py index 462781c15..4effe9dd9 100644 --- a/src/lerobot/datasets/v21/convert_stats.py +++ b/src/lerobot/datasets/v21/convert_stats.py @@ -45,6 +45,8 @@ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int): axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0 keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1 + if ft["dtype"] in ["image", "video"] and ep_ft_data.ndim == 3: + ep_ft_data = np.expand_dims(ep_ft_data, axis=0) ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims) if ft["dtype"] in ["image", "video"]: # remove batch dim