fix(copy&reindex): fixing metadat reshaping for single channel frames

This commit is contained in:
CarolinePascal
2026-06-24 17:33:07 +02:00
parent d4cff6b0cc
commit 5c187a06f5
+6 -5
View File
@@ -883,11 +883,11 @@ def _copy_and_reindex_episodes_metadata(
episode_meta.update(video_metadata[new_idx])
# Extract episode statistics from parquet metadata.
# Note (maractingi): When pandas/pyarrow serializes numpy arrays with shape (3, 1, 1) to parquet,
# When pandas/pyarrow serializes numpy arrays with shape (C, 1, 1) to parquet,
# they are being deserialized as nested object arrays like:
# array([array([array([0.])]), array([array([0.])]), array([array([0.])])])
# This happens particularly with image/video statistics. We need to detect and flatten
# these nested structures back to proper (3, 1, 1) arrays so aggregate_stats can process them.
# these nested structures back to proper (C, 1, 1) arrays so aggregate_stats can process them.
episode_stats = {}
for key in src_episode_full:
if key.startswith("stats/"):
@@ -903,15 +903,16 @@ def _copy_and_reindex_episodes_metadata(
if feature_name in src_dataset.meta.features:
feature_dtype = src_dataset.meta.features[feature_name]["dtype"]
if feature_dtype in ["image", "video"] and stat_name != "count":
# Stats are channel-first (C, 1, 1)
if isinstance(value, np.ndarray) and value.dtype == object:
flat_values = []
for item in value:
while isinstance(item, np.ndarray):
item = item.flatten()[0]
flat_values.append(item)
value = np.array(flat_values, dtype=np.float64).reshape(3, 1, 1)
elif isinstance(value, np.ndarray) and value.shape == (3,):
value = value.reshape(3, 1, 1)
value = np.array(flat_values, dtype=np.float64).reshape(-1, 1, 1)
elif isinstance(value, np.ndarray) and value.ndim == 1:
value = value.reshape(-1, 1, 1)
episode_stats[feature_name][stat_name] = value