diff --git a/src/lerobot/datasets/io_utils.py b/src/lerobot/datasets/io_utils.py index 4beee4686..34da75312 100644 --- a/src/lerobot/datasets/io_utils.py +++ b/src/lerobot/datasets/io_utils.py @@ -278,14 +278,12 @@ def write_table_one_row_group_per_episode(table: pa.Table, path: Path) -> None: mirroring the recording writer. ``table`` must carry a contiguous ``episode_index`` column. """ - episode_index = table.column("episode_index").to_pylist() + episode_index = table.column("episode_index").to_numpy(zero_copy_only=False) + starts = np.concatenate(([0], np.nonzero(np.diff(episode_index))[0] + 1)) writer = pq.ParquetWriter(str(path), table.schema, compression="snappy", use_dictionary=True) try: - start = 0 - for i in range(1, len(episode_index) + 1): - if i == len(episode_index) or episode_index[i] != episode_index[start]: - writer.write_table(table.slice(start, i - start)) # one episode -> one row group - start = i + for start, stop in zip(starts, np.append(starts[1:], len(episode_index))): + writer.write_table(table.slice(start, stop - start)) # one episode -> one row group finally: writer.close()