mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6906178b39 |
@@ -21,6 +21,8 @@ from pathlib import Path
|
|||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import pyarrow as pa
|
||||||
|
import pyarrow.parquet as pq
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
from lerobot.datasets.compute_stats import aggregate_stats
|
from lerobot.datasets.compute_stats import aggregate_stats
|
||||||
@@ -35,7 +37,6 @@ from lerobot.datasets.utils import (
|
|||||||
get_file_size_in_mb,
|
get_file_size_in_mb,
|
||||||
get_hf_features_from_features,
|
get_hf_features_from_features,
|
||||||
get_parquet_file_size_in_mb,
|
get_parquet_file_size_in_mb,
|
||||||
to_parquet_with_hf_images,
|
|
||||||
update_chunk_file_indices,
|
update_chunk_file_indices,
|
||||||
write_info,
|
write_info,
|
||||||
write_stats,
|
write_stats,
|
||||||
@@ -80,28 +81,41 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
|||||||
return fps, robot_type, features
|
return fps, robot_type, features
|
||||||
|
|
||||||
|
|
||||||
def update_data_df(df, src_meta, dst_meta):
|
def update_data_table(table: pa.Table, src_meta, dst_meta) -> pa.Table:
|
||||||
"""Updates a data DataFrame with new indices and task mappings for aggregation.
|
"""Updates a pyarrow Table with new indices and task mappings for aggregation.
|
||||||
|
|
||||||
Adjusts episode indices, frame indices, and task indices to account for
|
Adjusts episode indices, frame indices, and task indices to account for
|
||||||
previously aggregated data in the destination dataset.
|
previously aggregated data in the destination dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
df: DataFrame containing the data to be updated.
|
table: pyarrow Table containing the data to be updated.
|
||||||
src_meta: Source dataset metadata.
|
src_meta: Source dataset metadata.
|
||||||
dst_meta: Destination dataset metadata.
|
dst_meta: Destination dataset metadata.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
pd.DataFrame: Updated DataFrame with adjusted indices.
|
pa.Table: Updated Table with adjusted indices.
|
||||||
"""
|
"""
|
||||||
|
ep_offset = dst_meta.info["total_episodes"]
|
||||||
|
idx_offset = dst_meta.info["total_frames"]
|
||||||
|
|
||||||
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
|
ep_col = table.column("episode_index")
|
||||||
df["index"] = df["index"] + dst_meta.info["total_frames"]
|
new_ep = pa.array([v + ep_offset for v in ep_col.to_pylist()], type=ep_col.type)
|
||||||
|
table = table.set_column(table.column_names.index("episode_index"), "episode_index", new_ep)
|
||||||
|
|
||||||
src_task_names = src_meta.tasks.index.take(df["task_index"].to_numpy())
|
idx_col = table.column("index")
|
||||||
df["task_index"] = dst_meta.tasks.loc[src_task_names, "task_index"].to_numpy()
|
new_idx = pa.array([v + idx_offset for v in idx_col.to_pylist()], type=idx_col.type)
|
||||||
|
table = table.set_column(table.column_names.index("index"), "index", new_idx)
|
||||||
|
|
||||||
return df
|
old_task_indices = table.column("task_index").to_pylist()
|
||||||
|
src_task_names = src_meta.tasks.index.take(old_task_indices)
|
||||||
|
new_task_indices = dst_meta.tasks.loc[src_task_names, "task_index"].to_numpy().tolist()
|
||||||
|
table = table.set_column(
|
||||||
|
table.column_names.index("task_index"),
|
||||||
|
"task_index",
|
||||||
|
pa.array(new_task_indices, type=table.schema.field("task_index").type),
|
||||||
|
)
|
||||||
|
|
||||||
|
return table
|
||||||
|
|
||||||
|
|
||||||
def update_meta_data(
|
def update_meta_data(
|
||||||
@@ -468,18 +482,13 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
|||||||
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
|
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
|
||||||
chunk_index=src_chunk_idx, file_index=src_file_idx
|
chunk_index=src_chunk_idx, file_index=src_file_idx
|
||||||
)
|
)
|
||||||
if contains_images:
|
table = pq.read_table(src_path)
|
||||||
# Use HuggingFace datasets to read source data to preserve image format
|
table = update_data_table(table, src_meta, dst_meta)
|
||||||
src_ds = datasets.Dataset.from_parquet(str(src_path))
|
|
||||||
df = src_ds.to_pandas()
|
|
||||||
else:
|
|
||||||
df = pd.read_parquet(src_path)
|
|
||||||
df = update_data_df(df, src_meta, dst_meta)
|
|
||||||
|
|
||||||
# Write data and get the actual destination file it was written to
|
# Write data and get the actual destination file it was written to
|
||||||
# This avoids duplicating the rotation logic here
|
# This avoids duplicating the rotation logic here
|
||||||
data_idx, (dst_chunk, dst_file) = append_or_create_parquet_file(
|
data_idx, (dst_chunk, dst_file) = append_or_create_parquet_file(
|
||||||
df,
|
table,
|
||||||
src_path,
|
src_path,
|
||||||
data_idx,
|
data_idx,
|
||||||
data_files_size_in_mb,
|
data_files_size_in_mb,
|
||||||
@@ -554,8 +563,16 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
|||||||
return meta_idx
|
return meta_idx
|
||||||
|
|
||||||
|
|
||||||
|
def _write_table_with_hf_images(
|
||||||
|
table: pa.Table, path: Path, features: datasets.Features | None = None
|
||||||
|
) -> None:
|
||||||
|
"""Write a pyarrow Table to parquet with proper HF image encoding."""
|
||||||
|
ds = datasets.Dataset.from_dict(table.to_pydict(), features=features)
|
||||||
|
ds.to_parquet(path)
|
||||||
|
|
||||||
|
|
||||||
def append_or_create_parquet_file(
|
def append_or_create_parquet_file(
|
||||||
df: pd.DataFrame,
|
data: pd.DataFrame | pa.Table,
|
||||||
src_path: Path,
|
src_path: Path,
|
||||||
idx: dict[str, int],
|
idx: dict[str, int],
|
||||||
max_mb: float,
|
max_mb: float,
|
||||||
@@ -571,7 +588,7 @@ def append_or_create_parquet_file(
|
|||||||
from becoming too large. Handles both regular parquet files and those containing images.
|
from becoming too large. Handles both regular parquet files and those containing images.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
df: DataFrame to write to the parquet file.
|
data: Data to write, as a pandas DataFrame or pyarrow Table.
|
||||||
src_path: Path to the source file (used for size estimation).
|
src_path: Path to the source file (used for size estimation).
|
||||||
idx: Dictionary containing current 'chunk' and 'file' indices.
|
idx: Dictionary containing current 'chunk' and 'file' indices.
|
||||||
max_mb: Maximum allowed file size in MB before rotation.
|
max_mb: Maximum allowed file size in MB before rotation.
|
||||||
@@ -585,15 +602,17 @@ def append_or_create_parquet_file(
|
|||||||
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
|
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
|
||||||
and (dst_chunk, dst_file) is the actual destination file the data was written to.
|
and (dst_chunk, dst_file) is the actual destination file the data was written to.
|
||||||
"""
|
"""
|
||||||
|
table = data if isinstance(data, pa.Table) else pa.Table.from_pandas(data)
|
||||||
|
|
||||||
dst_chunk, dst_file = idx["chunk"], idx["file"]
|
dst_chunk, dst_file = idx["chunk"], idx["file"]
|
||||||
dst_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
|
dst_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
|
||||||
|
|
||||||
if not dst_path.exists():
|
if not dst_path.exists():
|
||||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
if contains_images:
|
if contains_images:
|
||||||
to_parquet_with_hf_images(df, dst_path, features=hf_features)
|
_write_table_with_hf_images(table, dst_path, features=hf_features)
|
||||||
else:
|
else:
|
||||||
df.to_parquet(dst_path)
|
pq.write_table(table, dst_path)
|
||||||
return idx, (dst_chunk, dst_file)
|
return idx, (dst_chunk, dst_file)
|
||||||
|
|
||||||
src_size = get_parquet_file_size_in_mb(src_path)
|
src_size = get_parquet_file_size_in_mb(src_path)
|
||||||
@@ -604,22 +623,17 @@ def append_or_create_parquet_file(
|
|||||||
dst_chunk, dst_file = idx["chunk"], idx["file"]
|
dst_chunk, dst_file = idx["chunk"], idx["file"]
|
||||||
new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
|
new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
|
||||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
final_df = df
|
final_table = table
|
||||||
target_path = new_path
|
target_path = new_path
|
||||||
else:
|
else:
|
||||||
if contains_images:
|
existing_table = pq.read_table(dst_path)
|
||||||
# Use HuggingFace datasets to read existing data to preserve image format
|
final_table = pa.concat_tables([existing_table, table], promote_options="permissive")
|
||||||
existing_ds = datasets.Dataset.from_parquet(str(dst_path))
|
|
||||||
existing_df = existing_ds.to_pandas()
|
|
||||||
else:
|
|
||||||
existing_df = pd.read_parquet(dst_path)
|
|
||||||
final_df = pd.concat([existing_df, df], ignore_index=True)
|
|
||||||
target_path = dst_path
|
target_path = dst_path
|
||||||
|
|
||||||
if contains_images:
|
if contains_images:
|
||||||
to_parquet_with_hf_images(final_df, target_path, features=hf_features)
|
_write_table_with_hf_images(final_table, target_path, features=hf_features)
|
||||||
else:
|
else:
|
||||||
final_df.to_parquet(target_path)
|
pq.write_table(final_table, target_path)
|
||||||
|
|
||||||
return idx, (dst_chunk, dst_file)
|
return idx, (dst_chunk, dst_file)
|
||||||
|
|
||||||
|
|||||||
@@ -32,6 +32,9 @@ from pathlib import Path
|
|||||||
import datasets
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import pyarrow as pa
|
||||||
|
import pyarrow.compute as pc
|
||||||
|
import pyarrow.dataset as pa_ds
|
||||||
import pyarrow.parquet as pq
|
import pyarrow.parquet as pq
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@@ -496,13 +499,16 @@ def _copy_and_reindex_data(
|
|||||||
global_index = 0
|
global_index = 0
|
||||||
episode_data_metadata: dict[int, dict] = {}
|
episode_data_metadata: dict[int, dict] = {}
|
||||||
|
|
||||||
|
episode_keys = list(episode_mapping.keys())
|
||||||
|
ep_filter = pa_ds.field("episode_index").isin(episode_keys)
|
||||||
|
|
||||||
if dst_meta.tasks is None:
|
if dst_meta.tasks is None:
|
||||||
all_task_indices = set()
|
all_task_indices: set[int] = set()
|
||||||
for src_path in file_to_episodes:
|
for src_path in file_to_episodes:
|
||||||
df = pd.read_parquet(src_dataset.root / src_path)
|
table = pq.read_table(
|
||||||
mask = df["episode_index"].isin(list(episode_mapping.keys()))
|
src_dataset.root / src_path, columns=["episode_index", "task_index"], filters=ep_filter
|
||||||
task_series: pd.Series = df[mask]["task_index"]
|
)
|
||||||
all_task_indices.update(task_series.unique().tolist())
|
all_task_indices.update(pc.unique(table.column("task_index")).to_pylist())
|
||||||
tasks = [src_dataset.meta.tasks.iloc[idx].name for idx in all_task_indices]
|
tasks = [src_dataset.meta.tasks.iloc[idx].name for idx in all_task_indices]
|
||||||
dst_meta.save_episode_tasks(list(set(tasks)))
|
dst_meta.save_episode_tasks(list(set(tasks)))
|
||||||
|
|
||||||
@@ -514,52 +520,41 @@ def _copy_and_reindex_data(
|
|||||||
task_mapping[old_task_idx] = new_task_idx
|
task_mapping[old_task_idx] = new_task_idx
|
||||||
|
|
||||||
for src_path in tqdm(sorted(file_to_episodes.keys()), desc="Processing data files"):
|
for src_path in tqdm(sorted(file_to_episodes.keys()), desc="Processing data files"):
|
||||||
df = pd.read_parquet(src_dataset.root / src_path)
|
table = pq.read_table(src_dataset.root / src_path, filters=ep_filter)
|
||||||
|
|
||||||
all_episodes_in_file = set(df["episode_index"].unique())
|
|
||||||
episodes_to_keep = file_to_episodes[src_path]
|
episodes_to_keep = file_to_episodes[src_path]
|
||||||
|
|
||||||
if all_episodes_in_file == episodes_to_keep:
|
if table.num_rows == 0:
|
||||||
df["episode_index"] = df["episode_index"].replace(episode_mapping)
|
continue
|
||||||
df["index"] = range(global_index, global_index + len(df))
|
|
||||||
df["task_index"] = df["task_index"].replace(task_mapping)
|
|
||||||
|
|
||||||
first_ep_old_idx = min(episodes_to_keep)
|
table = _replace_column_values(table, "episode_index", episode_mapping)
|
||||||
src_ep = src_dataset.meta.episodes[first_ep_old_idx]
|
col_pos = table.column_names.index("index")
|
||||||
chunk_idx = src_ep["data/chunk_index"]
|
new_indices = pa.array(range(global_index, global_index + table.num_rows), type=pa.int64())
|
||||||
file_idx = src_ep["data/file_index"]
|
table = table.set_column(col_pos, "index", new_indices)
|
||||||
else:
|
table = _replace_column_values(table, "task_index", task_mapping)
|
||||||
mask = df["episode_index"].isin(list(episode_mapping.keys()))
|
|
||||||
df = df[mask].copy().reset_index(drop=True)
|
|
||||||
|
|
||||||
if len(df) == 0:
|
first_ep_old_idx = min(episodes_to_keep)
|
||||||
continue
|
src_ep = src_dataset.meta.episodes[first_ep_old_idx]
|
||||||
|
chunk_idx = src_ep["data/chunk_index"]
|
||||||
df["episode_index"] = df["episode_index"].replace(episode_mapping)
|
file_idx = src_ep["data/file_index"]
|
||||||
df["index"] = range(global_index, global_index + len(df))
|
|
||||||
df["task_index"] = df["task_index"].replace(task_mapping)
|
|
||||||
|
|
||||||
first_ep_old_idx = min(episodes_to_keep)
|
|
||||||
src_ep = src_dataset.meta.episodes[first_ep_old_idx]
|
|
||||||
chunk_idx = src_ep["data/chunk_index"]
|
|
||||||
file_idx = src_ep["data/file_index"]
|
|
||||||
|
|
||||||
dst_path = dst_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
dst_path = dst_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
_write_parquet(df, dst_path, dst_meta)
|
_write_parquet(table, dst_path, dst_meta)
|
||||||
|
|
||||||
|
ep_col = table.column("episode_index").to_pylist()
|
||||||
|
idx_col = table.column("index").to_pylist()
|
||||||
for ep_old_idx in episodes_to_keep:
|
for ep_old_idx in episodes_to_keep:
|
||||||
ep_new_idx = episode_mapping[ep_old_idx]
|
ep_new_idx = episode_mapping[ep_old_idx]
|
||||||
ep_df = df[df["episode_index"] == ep_new_idx]
|
ep_indices = [idx_col[i] for i, e in enumerate(ep_col) if e == ep_new_idx]
|
||||||
episode_data_metadata[ep_new_idx] = {
|
episode_data_metadata[ep_new_idx] = {
|
||||||
"data/chunk_index": chunk_idx,
|
"data/chunk_index": chunk_idx,
|
||||||
"data/file_index": file_idx,
|
"data/file_index": file_idx,
|
||||||
"dataset_from_index": int(ep_df["index"].min()),
|
"dataset_from_index": min(ep_indices),
|
||||||
"dataset_to_index": int(ep_df["index"].max() + 1),
|
"dataset_to_index": max(ep_indices) + 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
global_index += len(df)
|
global_index += table.num_rows
|
||||||
|
|
||||||
return episode_data_metadata
|
return episode_data_metadata
|
||||||
|
|
||||||
@@ -910,15 +905,39 @@ def _copy_and_reindex_episodes_metadata(
|
|||||||
write_stats(filtered_stats, dst_meta.root)
|
write_stats(filtered_stats, dst_meta.root)
|
||||||
|
|
||||||
|
|
||||||
def _write_parquet(df: pd.DataFrame, path: Path, meta: LeRobotDatasetMetadata) -> None:
|
def _replace_column_values(table: pa.Table, column: str, mapping: dict) -> pa.Table:
|
||||||
"""Write DataFrame to parquet
|
"""Replace values in a pyarrow Table column using a mapping dict."""
|
||||||
|
old_values = table.column(column).to_pylist()
|
||||||
|
new_values = [mapping.get(v, v) for v in old_values]
|
||||||
|
col_pos = table.column_names.index(column)
|
||||||
|
return table.set_column(col_pos, column, pa.array(new_values, type=table.schema.field(column).type))
|
||||||
|
|
||||||
|
|
||||||
|
def _write_parquet(
|
||||||
|
data: pd.DataFrame | pa.Table | dict, path: Path, meta: LeRobotDatasetMetadata
|
||||||
|
) -> None:
|
||||||
|
"""Write data to parquet.
|
||||||
|
|
||||||
This ensures images are properly embedded and the file can be loaded correctly by HF datasets.
|
This ensures images are properly embedded and the file can be loaded correctly by HF datasets.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Input data as a pandas DataFrame, pyarrow Table, or dict of lists.
|
||||||
|
path: Destination parquet file path.
|
||||||
|
meta: Dataset metadata for feature schema.
|
||||||
"""
|
"""
|
||||||
from lerobot.datasets.utils import embed_images, get_hf_features_from_features
|
from lerobot.datasets.utils import embed_images, get_hf_features_from_features
|
||||||
|
|
||||||
|
if isinstance(data, pd.DataFrame):
|
||||||
|
data_dict = data.to_dict(orient="list")
|
||||||
|
elif isinstance(data, pa.Table):
|
||||||
|
data_dict = data.to_pydict()
|
||||||
|
elif isinstance(data, dict):
|
||||||
|
data_dict = data
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Unsupported data type: {type(data)}")
|
||||||
|
|
||||||
hf_features = get_hf_features_from_features(meta.features)
|
hf_features = get_hf_features_from_features(meta.features)
|
||||||
ep_dataset = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=hf_features, split="train")
|
ep_dataset = datasets.Dataset.from_dict(data_dict, features=hf_features, split="train")
|
||||||
|
|
||||||
if len(meta.image_keys) > 0:
|
if len(meta.image_keys) > 0:
|
||||||
ep_dataset = embed_images(ep_dataset)
|
ep_dataset = embed_images(ep_dataset)
|
||||||
|
|||||||
Reference in New Issue
Block a user