fix(datasets)

This commit is contained in:
CarolinePascal
2026-03-07 01:14:19 +01:00
parent 5e74f06b20
commit 10c2e2fc87
@@ -45,6 +45,8 @@ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
if ft["dtype"] in ["image", "video"] and ep_ft_data.ndim == 3:
ep_ft_data = np.expand_dims(ep_ft_data, axis=0)
ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
if ft["dtype"] in ["image", "video"]: # remove batch dim