Compare commits

...

3 Commits

Author | SHA1 | Message | Date
Michel Aractingi | 8008cb357d | remove bad typing | 2025-11-06 09:13:26 +01:00
Michel Aractingi | ca5a4a7ae5 | add typing hints | 2025-11-06 09:12:09 +01:00
Michel Aractingi | b5dcd70d2c | add embed images in conversion to v3 script; add parquet writer in conversion script | 2025-11-05 23:41:38 +01:00
@@ -50,9 +50,9 @@ from typing import Any
 import jsonlines
 import pandas as pd
 import pyarrow as pa
 import pyarrow.parquet as pq
 import tqdm
-from datasets import Dataset, Features, Image
+from datasets import Dataset, concatenate_datasets
 from huggingface_hub import HfApi, snapshot_download
 from requests import HTTPError
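The hunk above swaps the pandas-centric imports for the `datasets` ones that the rewritten `concat_data_files` relies on. As a minimal sketch of the pattern this enables (file names here are hypothetical, not from the diff): `Dataset.from_parquet` loads each file as an Arrow-backed `Dataset` rather than an in-memory DataFrame, and `concatenate_datasets` combines the pieces without copying row data.

```python
from datasets import Dataset, concatenate_datasets

# Hypothetical input files, for illustration only.
paths = ["file-000.parquet", "file-001.parquet"]

# Each part is backed by an Arrow table (memory-mapped from the datasets
# cache by default), not materialized as a pandas DataFrame.
parts = [Dataset.from_parquet(p) for p in paths]

# Concatenates the underlying Arrow tables without copying row data.
merged = concatenate_datasets(parts)
print(len(merged), merged.features)
```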
@@ -68,6 +68,7 @@ from lerobot.datasets.utils import (
     LEGACY_EPISODES_STATS_PATH,
     LEGACY_TASKS_PATH,
     cast_stats_to_numpy,
+    embed_images,
     flatten_dict,
     get_file_size_in_mb,
     get_parquet_file_size_in_mb,
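The newly imported `embed_images` is the lerobot utility used below to inline raw image bytes into the table before the parquet file is written, so the output no longer references external image files. A hedged sketch of the pattern such a helper typically follows, built on the `datasets` library's `embed_table_storage` (the actual lerobot implementation may differ in details):

```python
from datasets import Dataset
from datasets.table import embed_table_storage


def embed_images_sketch(dataset: Dataset) -> Dataset:
    """Replace external image references with inlined bytes (illustrative)."""
    fmt = dataset.format                     # remember the caller's format
    dataset = dataset.with_format("arrow")   # map over raw Arrow tables
    # embed_table_storage rewrites Image/Audio columns to embed raw bytes.
    dataset = dataset.map(embed_table_storage, batched=True)
    return dataset.with_format(**fmt)        # restore the original format
```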
@@ -174,25 +175,33 @@ def convert_tasks(root, new_root):
     write_tasks(df_tasks, new_root)
 
 
-def concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys):
-    # TODO(rcadene): to save RAM use Dataset.from_parquet(file) and concatenate_datasets
-    dataframes = [pd.read_parquet(file) for file in paths_to_cat]
-    # Concatenate all DataFrames along rows
-    concatenated_df = pd.concat(dataframes, ignore_index=True)
+def concat_data_files(
+    paths_to_cat: list[Path], new_root: Path, chunk_idx: int, file_idx: int, image_keys: list[str]
+):
+    """Concatenate multiple parquet data files into a single file.
+
+    Args:
+        paths_to_cat: List of parquet file paths to concatenate
+        new_root: Root directory for the new dataset
+        chunk_idx: Chunk index for the output file
+        file_idx: File index within the chunk
+        image_keys: List of feature keys that contain images
+    """
+    datasets_list: list[Dataset] = [Dataset.from_parquet(str(file)) for file in paths_to_cat]
+    concatenated_ds: Dataset = concatenate_datasets(datasets_list)
+
+    if len(image_keys) > 0:
+        logging.debug(f"Embedding {len(image_keys)} image features for optimal training performance")
+        concatenated_ds = embed_images(concatenated_ds)
 
     path = new_root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
     path.parent.mkdir(parents=True, exist_ok=True)
 
-    if len(image_keys) > 0:
-        schema = pa.Schema.from_pandas(concatenated_df)
-        features = Features.from_arrow_schema(schema)
-        for key in image_keys:
-            features[key] = Image()
-        schema = features.arrow_schema
-    else:
-        schema = None
-    concatenated_df.to_parquet(path, index=False, schema=schema)
+    table = concatenated_ds.with_format("arrow")[:]
+    writer = pq.ParquetWriter(path, schema=table.schema, compression="snappy", use_dictionary=True)
+    writer.write_table(table)
+    writer.close()
 
 
 def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
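For reference, a hypothetical call to the rewritten helper (the paths, indices, and image key are invented for illustration; the real output location is derived from `DEFAULT_DATA_PATH`):

```python
from pathlib import Path

concat_data_files(
    paths_to_cat=[
        Path("old_root/data/chunk-000/episode_000000.parquet"),
        Path("old_root/data/chunk-000/episode_000001.parquet"),
    ],
    new_root=Path("new_root"),
    chunk_idx=0,
    file_idx=0,
    image_keys=["observation.image"],  # pass [] for datasets without image columns
)
```

One design note on the new write path: `concatenated_ds.with_format("arrow")[:]` materializes the dataset as a single `pyarrow.Table`, which `pq.ParquetWriter` then writes with explicit snappy compression and dictionary encoding. `ParquetWriter` also supports use as a context manager, which would guarantee the file is closed even if `write_table` raises.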