From 5c187a06f5d249bc4cf9d954d4d0382ef176eef7 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Wed, 24 Jun 2026 17:33:07 +0200 Subject: [PATCH] fix(copy&reindex): fixing metadat reshaping for single channel frames --- src/lerobot/datasets/dataset_tools.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py index c6f2b7f0c..4e6507240 100644 --- a/src/lerobot/datasets/dataset_tools.py +++ b/src/lerobot/datasets/dataset_tools.py @@ -883,11 +883,11 @@ def _copy_and_reindex_episodes_metadata( episode_meta.update(video_metadata[new_idx]) # Extract episode statistics from parquet metadata. - # Note (maractingi): When pandas/pyarrow serializes numpy arrays with shape (3, 1, 1) to parquet, + # When pandas/pyarrow serializes numpy arrays with shape (C, 1, 1) to parquet, # they are being deserialized as nested object arrays like: # array([array([array([0.])]), array([array([0.])]), array([array([0.])])]) # This happens particularly with image/video statistics. We need to detect and flatten - # these nested structures back to proper (3, 1, 1) arrays so aggregate_stats can process them. + # these nested structures back to proper (C, 1, 1) arrays so aggregate_stats can process them. episode_stats = {} for key in src_episode_full: if key.startswith("stats/"): @@ -903,15 +903,16 @@ def _copy_and_reindex_episodes_metadata( if feature_name in src_dataset.meta.features: feature_dtype = src_dataset.meta.features[feature_name]["dtype"] if feature_dtype in ["image", "video"] and stat_name != "count": + # Stats are channel-first (C, 1, 1) if isinstance(value, np.ndarray) and value.dtype == object: flat_values = [] for item in value: while isinstance(item, np.ndarray): item = item.flatten()[0] flat_values.append(item) - value = np.array(flat_values, dtype=np.float64).reshape(3, 1, 1) - elif isinstance(value, np.ndarray) and value.shape == (3,): - value = value.reshape(3, 1, 1) + value = np.array(flat_values, dtype=np.float64).reshape(-1, 1, 1) + elif isinstance(value, np.ndarray) and value.ndim == 1: + value = value.reshape(-1, 1, 1) episode_stats[feature_name][stat_name] = value