Compare commits

...

1 Commits

2 changed files with 103 additions and 70 deletions
+46 -32
View File
@@ -21,6 +21,8 @@ from pathlib import Path
import datasets import datasets
import pandas as pd import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import tqdm import tqdm
from lerobot.datasets.compute_stats import aggregate_stats from lerobot.datasets.compute_stats import aggregate_stats
@@ -35,7 +37,6 @@ from lerobot.datasets.utils import (
get_file_size_in_mb, get_file_size_in_mb,
get_hf_features_from_features, get_hf_features_from_features,
get_parquet_file_size_in_mb, get_parquet_file_size_in_mb,
to_parquet_with_hf_images,
update_chunk_file_indices, update_chunk_file_indices,
write_info, write_info,
write_stats, write_stats,
@@ -80,28 +81,41 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
return fps, robot_type, features return fps, robot_type, features
def _offset_int_column(table: pa.Table, name: str, offset: int) -> pa.Table:
    """Return ``table`` with the integer column ``name`` shifted by ``offset``.

    The addition runs inside Arrow, so the column is never materialized as a
    list of Python objects.
    """
    # Local import: the module-level imports visible here only bring in
    # ``pyarrow`` and ``pyarrow.parquet``.
    import pyarrow.compute as pc

    column = table.column(name)
    # A scalar typed like the column keeps the result in the column's own
    # type (an untyped int would promote narrower columns to int64). The
    # checked kernel raises on overflow instead of silently wrapping.
    shifted = pc.add_checked(column, pa.scalar(int(offset), type=column.type))
    return table.set_column(table.column_names.index(name), name, shifted)


def update_data_table(table: pa.Table, src_meta, dst_meta) -> pa.Table:
    """Updates a pyarrow Table with new indices and task mappings for aggregation.

    Adjusts episode indices, frame indices, and task indices to account for
    previously aggregated data in the destination dataset.

    Args:
        table: pyarrow Table containing the data to be updated. It must have
            "episode_index", "index" and "task_index" columns.
        src_meta: Source dataset metadata; ``src_meta.tasks`` is used to turn
            the source task indices into task names.
        dst_meta: Destination dataset metadata; ``dst_meta.info`` supplies the
            episode/frame offsets and ``dst_meta.tasks`` the new task indices.

    Returns:
        pa.Table: Updated Table with adjusted indices. The input table is not
        modified (pyarrow tables are immutable).
    """
    table = _offset_int_column(table, "episode_index", dst_meta.info["total_episodes"])
    table = _offset_int_column(table, "index", dst_meta.info["total_frames"])

    # Remap tasks through their names: source index -> task name -> destination index.
    task_type = table.schema.field("task_index").type
    src_task_indices = table.column("task_index").to_numpy()
    src_task_names = src_meta.tasks.index.take(src_task_indices)
    dst_task_indices = dst_meta.tasks.loc[src_task_names, "task_index"].to_numpy()

    return table.set_column(
        table.column_names.index("task_index"),
        "task_index",
        pa.array(dst_task_indices, type=task_type),
    )
def update_meta_data( def update_meta_data(
@@ -468,18 +482,13 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
src_path = src_meta.root / DEFAULT_DATA_PATH.format( src_path = src_meta.root / DEFAULT_DATA_PATH.format(
chunk_index=src_chunk_idx, file_index=src_file_idx chunk_index=src_chunk_idx, file_index=src_file_idx
) )
if contains_images: table = pq.read_table(src_path)
# Use HuggingFace datasets to read source data to preserve image format table = update_data_table(table, src_meta, dst_meta)
src_ds = datasets.Dataset.from_parquet(str(src_path))
df = src_ds.to_pandas()
else:
df = pd.read_parquet(src_path)
df = update_data_df(df, src_meta, dst_meta)
# Write data and get the actual destination file it was written to # Write data and get the actual destination file it was written to
# This avoids duplicating the rotation logic here # This avoids duplicating the rotation logic here
data_idx, (dst_chunk, dst_file) = append_or_create_parquet_file( data_idx, (dst_chunk, dst_file) = append_or_create_parquet_file(
df, table,
src_path, src_path,
data_idx, data_idx,
data_files_size_in_mb, data_files_size_in_mb,
@@ -554,8 +563,16 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
return meta_idx return meta_idx
def _write_table_with_hf_images(
    table: pa.Table, path: Path, features: datasets.Features | None = None
) -> None:
    """Write a pyarrow Table to a parquet file through a HF ``datasets.Dataset``.

    Args:
        table: pyarrow Table holding the rows to write.
        path: Destination parquet file path.
        features: Optional HF features passed to ``Dataset.from_dict``.
    """
    columns = table.to_pydict()
    hf_dataset = datasets.Dataset.from_dict(columns, features=features)
    hf_dataset.to_parquet(path)
def append_or_create_parquet_file( def append_or_create_parquet_file(
df: pd.DataFrame, data: pd.DataFrame | pa.Table,
src_path: Path, src_path: Path,
idx: dict[str, int], idx: dict[str, int],
max_mb: float, max_mb: float,
@@ -571,7 +588,7 @@ def append_or_create_parquet_file(
from becoming too large. Handles both regular parquet files and those containing images. from becoming too large. Handles both regular parquet files and those containing images.
Args: Args:
df: DataFrame to write to the parquet file. data: Data to write, as a pandas DataFrame or pyarrow Table.
src_path: Path to the source file (used for size estimation). src_path: Path to the source file (used for size estimation).
idx: Dictionary containing current 'chunk' and 'file' indices. idx: Dictionary containing current 'chunk' and 'file' indices.
max_mb: Maximum allowed file size in MB before rotation. max_mb: Maximum allowed file size in MB before rotation.
@@ -585,15 +602,17 @@ def append_or_create_parquet_file(
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
and (dst_chunk, dst_file) is the actual destination file the data was written to. and (dst_chunk, dst_file) is the actual destination file the data was written to.
""" """
table = data if isinstance(data, pa.Table) else pa.Table.from_pandas(data)
dst_chunk, dst_file = idx["chunk"], idx["file"] dst_chunk, dst_file = idx["chunk"], idx["file"]
dst_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file) dst_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
if not dst_path.exists(): if not dst_path.exists():
dst_path.parent.mkdir(parents=True, exist_ok=True) dst_path.parent.mkdir(parents=True, exist_ok=True)
if contains_images: if contains_images:
to_parquet_with_hf_images(df, dst_path, features=hf_features) _write_table_with_hf_images(table, dst_path, features=hf_features)
else: else:
df.to_parquet(dst_path) pq.write_table(table, dst_path)
return idx, (dst_chunk, dst_file) return idx, (dst_chunk, dst_file)
src_size = get_parquet_file_size_in_mb(src_path) src_size = get_parquet_file_size_in_mb(src_path)
@@ -604,22 +623,17 @@ def append_or_create_parquet_file(
dst_chunk, dst_file = idx["chunk"], idx["file"] dst_chunk, dst_file = idx["chunk"], idx["file"]
new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file) new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
new_path.parent.mkdir(parents=True, exist_ok=True) new_path.parent.mkdir(parents=True, exist_ok=True)
final_df = df final_table = table
target_path = new_path target_path = new_path
else: else:
if contains_images: existing_table = pq.read_table(dst_path)
# Use HuggingFace datasets to read existing data to preserve image format final_table = pa.concat_tables([existing_table, table], promote_options="permissive")
existing_ds = datasets.Dataset.from_parquet(str(dst_path))
existing_df = existing_ds.to_pandas()
else:
existing_df = pd.read_parquet(dst_path)
final_df = pd.concat([existing_df, df], ignore_index=True)
target_path = dst_path target_path = dst_path
if contains_images: if contains_images:
to_parquet_with_hf_images(final_df, target_path, features=hf_features) _write_table_with_hf_images(final_table, target_path, features=hf_features)
else: else:
final_df.to_parquet(target_path) pq.write_table(final_table, target_path)
return idx, (dst_chunk, dst_file) return idx, (dst_chunk, dst_file)
+57 -38
View File
@@ -32,6 +32,9 @@ from pathlib import Path
import datasets import datasets
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as pa_ds
import pyarrow.parquet as pq import pyarrow.parquet as pq
import torch import torch
from tqdm import tqdm from tqdm import tqdm
@@ -496,13 +499,16 @@ def _copy_and_reindex_data(
global_index = 0 global_index = 0
episode_data_metadata: dict[int, dict] = {} episode_data_metadata: dict[int, dict] = {}
episode_keys = list(episode_mapping.keys())
ep_filter = pa_ds.field("episode_index").isin(episode_keys)
if dst_meta.tasks is None: if dst_meta.tasks is None:
all_task_indices = set() all_task_indices: set[int] = set()
for src_path in file_to_episodes: for src_path in file_to_episodes:
df = pd.read_parquet(src_dataset.root / src_path) table = pq.read_table(
mask = df["episode_index"].isin(list(episode_mapping.keys())) src_dataset.root / src_path, columns=["episode_index", "task_index"], filters=ep_filter
task_series: pd.Series = df[mask]["task_index"] )
all_task_indices.update(task_series.unique().tolist()) all_task_indices.update(pc.unique(table.column("task_index")).to_pylist())
tasks = [src_dataset.meta.tasks.iloc[idx].name for idx in all_task_indices] tasks = [src_dataset.meta.tasks.iloc[idx].name for idx in all_task_indices]
dst_meta.save_episode_tasks(list(set(tasks))) dst_meta.save_episode_tasks(list(set(tasks)))
@@ -514,52 +520,41 @@ def _copy_and_reindex_data(
task_mapping[old_task_idx] = new_task_idx task_mapping[old_task_idx] = new_task_idx
for src_path in tqdm(sorted(file_to_episodes.keys()), desc="Processing data files"): for src_path in tqdm(sorted(file_to_episodes.keys()), desc="Processing data files"):
df = pd.read_parquet(src_dataset.root / src_path) table = pq.read_table(src_dataset.root / src_path, filters=ep_filter)
all_episodes_in_file = set(df["episode_index"].unique())
episodes_to_keep = file_to_episodes[src_path] episodes_to_keep = file_to_episodes[src_path]
if all_episodes_in_file == episodes_to_keep: if table.num_rows == 0:
df["episode_index"] = df["episode_index"].replace(episode_mapping) continue
df["index"] = range(global_index, global_index + len(df))
df["task_index"] = df["task_index"].replace(task_mapping)
first_ep_old_idx = min(episodes_to_keep) table = _replace_column_values(table, "episode_index", episode_mapping)
src_ep = src_dataset.meta.episodes[first_ep_old_idx] col_pos = table.column_names.index("index")
chunk_idx = src_ep["data/chunk_index"] new_indices = pa.array(range(global_index, global_index + table.num_rows), type=pa.int64())
file_idx = src_ep["data/file_index"] table = table.set_column(col_pos, "index", new_indices)
else: table = _replace_column_values(table, "task_index", task_mapping)
mask = df["episode_index"].isin(list(episode_mapping.keys()))
df = df[mask].copy().reset_index(drop=True)
if len(df) == 0: first_ep_old_idx = min(episodes_to_keep)
continue src_ep = src_dataset.meta.episodes[first_ep_old_idx]
chunk_idx = src_ep["data/chunk_index"]
df["episode_index"] = df["episode_index"].replace(episode_mapping) file_idx = src_ep["data/file_index"]
df["index"] = range(global_index, global_index + len(df))
df["task_index"] = df["task_index"].replace(task_mapping)
first_ep_old_idx = min(episodes_to_keep)
src_ep = src_dataset.meta.episodes[first_ep_old_idx]
chunk_idx = src_ep["data/chunk_index"]
file_idx = src_ep["data/file_index"]
dst_path = dst_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx) dst_path = dst_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
dst_path.parent.mkdir(parents=True, exist_ok=True) dst_path.parent.mkdir(parents=True, exist_ok=True)
_write_parquet(df, dst_path, dst_meta) _write_parquet(table, dst_path, dst_meta)
ep_col = table.column("episode_index").to_pylist()
idx_col = table.column("index").to_pylist()
for ep_old_idx in episodes_to_keep: for ep_old_idx in episodes_to_keep:
ep_new_idx = episode_mapping[ep_old_idx] ep_new_idx = episode_mapping[ep_old_idx]
ep_df = df[df["episode_index"] == ep_new_idx] ep_indices = [idx_col[i] for i, e in enumerate(ep_col) if e == ep_new_idx]
episode_data_metadata[ep_new_idx] = { episode_data_metadata[ep_new_idx] = {
"data/chunk_index": chunk_idx, "data/chunk_index": chunk_idx,
"data/file_index": file_idx, "data/file_index": file_idx,
"dataset_from_index": int(ep_df["index"].min()), "dataset_from_index": min(ep_indices),
"dataset_to_index": int(ep_df["index"].max() + 1), "dataset_to_index": max(ep_indices) + 1,
} }
global_index += len(df) global_index += table.num_rows
return episode_data_metadata return episode_data_metadata
@@ -910,15 +905,39 @@ def _copy_and_reindex_episodes_metadata(
write_stats(filtered_stats, dst_meta.root) write_stats(filtered_stats, dst_meta.root)
def _replace_column_values(table: pa.Table, column: str, mapping: dict) -> pa.Table:
    """Replace values in a pyarrow Table column using a mapping dict.

    Values that appear as keys in ``mapping`` are replaced by the mapped
    value; all other values (including nulls) are left untouched. The lookup
    runs inside Arrow instead of looping over Python objects row by row.

    Args:
        table: Table whose column should be rewritten.
        column: Name of the column to rewrite.
        mapping: Old value -> new value. Keys and values must be convertible
            to the column's Arrow type.

    Returns:
        pa.Table: Table with the column replaced; the column keeps its type.
    """
    if not mapping:
        # Nothing to replace; tables are immutable so returning it as-is is safe.
        return table

    source = table.column(column)
    keys = pa.array(list(mapping.keys()), type=source.type)
    replacements = pa.array(list(mapping.values()), type=source.type)

    # Position of each row's value within ``keys``; null when the value is unmapped.
    key_positions = pc.index_in(source, value_set=keys)
    # Null positions yield null entries, which coalesce fills with the original value.
    mapped = pc.take(replacements, key_positions)
    remapped = pc.coalesce(mapped, source)

    return table.set_column(table.column_names.index(column), column, remapped)
def _write_parquet(
data: pd.DataFrame | pa.Table | dict, path: Path, meta: LeRobotDatasetMetadata
) -> None:
"""Write data to parquet.
This ensures images are properly embedded and the file can be loaded correctly by HF datasets. This ensures images are properly embedded and the file can be loaded correctly by HF datasets.
Args:
data: Input data as a pandas DataFrame, pyarrow Table, or dict of lists.
path: Destination parquet file path.
meta: Dataset metadata for feature schema.
""" """
from lerobot.datasets.utils import embed_images, get_hf_features_from_features from lerobot.datasets.utils import embed_images, get_hf_features_from_features
if isinstance(data, pd.DataFrame):
data_dict = data.to_dict(orient="list")
elif isinstance(data, pa.Table):
data_dict = data.to_pydict()
elif isinstance(data, dict):
data_dict = data
else:
raise TypeError(f"Unsupported data type: {type(data)}")
hf_features = get_hf_features_from_features(meta.features) hf_features = get_hf_features_from_features(meta.features)
ep_dataset = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=hf_features, split="train") ep_dataset = datasets.Dataset.from_dict(data_dict, features=hf_features, split="train")
if len(meta.image_keys) > 0: if len(meta.image_keys) > 0:
ep_dataset = embed_images(ep_dataset) ep_dataset = embed_images(ep_dataset)