mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-19 01:07:18 +00:00
Update src/lerobot/datasets/io_utils.py
Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com> Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>
This commit is contained in:
@@ -278,14 +278,12 @@ def write_table_one_row_group_per_episode(table: pa.Table, path: Path) -> None:
|
||||
mirroring the recording writer. ``table`` must carry a contiguous
|
||||
``episode_index`` column.
|
||||
"""
|
||||
episode_index = table.column("episode_index").to_pylist()
|
||||
episode_index = table.column("episode_index").to_numpy(zero_copy_only=False)
|
||||
starts = np.concatenate(([0], np.nonzero(np.diff(episode_index))[0] + 1))
|
||||
writer = pq.ParquetWriter(str(path), table.schema, compression="snappy", use_dictionary=True)
|
||||
try:
|
||||
start = 0
|
||||
for i in range(1, len(episode_index) + 1):
|
||||
if i == len(episode_index) or episode_index[i] != episode_index[start]:
|
||||
writer.write_table(table.slice(start, i - start)) # one episode -> one row group
|
||||
start = i
|
||||
for start, stop in zip(starts, np.append(starts[1:], len(episode_index))):
|
||||
writer.write_table(table.slice(start, stop - start)) # one episode -> one row group
|
||||
finally:
|
||||
writer.close()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user