mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
fix: debug aggregation code
This commit is contained in:
committed by
Michel Aractingi
parent
d4fbf6ef39
commit
378c147be6
@@ -5,9 +5,8 @@ from pathlib import Path
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
from lerobot.common.constants import HF_LEROBOT_HOME
|
|
||||||
from lerobot.common.datasets.compute_stats import aggregate_stats
|
from lerobot.common.datasets.compute_stats import aggregate_stats
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
DEFAULT_CHUNK_SIZE,
|
DEFAULT_CHUNK_SIZE,
|
||||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||||
@@ -23,7 +22,6 @@ from lerobot.common.datasets.utils import (
|
|||||||
write_stats,
|
write_stats,
|
||||||
write_tasks,
|
write_tasks,
|
||||||
)
|
)
|
||||||
from lerobot.common.utils.utils import init_logging
|
|
||||||
|
|
||||||
|
|
||||||
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
||||||
@@ -50,8 +48,8 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
|||||||
|
|
||||||
def update_data_df(df, src_meta, dst_meta):
|
def update_data_df(df, src_meta, dst_meta):
|
||||||
def _update(row):
|
def _update(row):
|
||||||
row["episode_index"] = row["episode_index"] + dst_meta["total_episodes"]
|
row["episode_index"] = row["episode_index"] + dst_meta.info["total_episodes"]
|
||||||
row["index"] = row["index"] + dst_meta["total_frames"]
|
row["index"] = row["index"] + dst_meta.info["total_frames"]
|
||||||
task = src_meta.tasks.iloc[row["task_index"]].name
|
task = src_meta.tasks.iloc[row["task_index"]].name
|
||||||
row["task_index"] = dst_meta.tasks.loc[task].task_index.item()
|
row["task_index"] = dst_meta.tasks.loc[task].task_index.item()
|
||||||
return row
|
return row
|
||||||
@@ -67,13 +65,13 @@ def update_meta_data(
|
|||||||
videos_idx,
|
videos_idx,
|
||||||
):
|
):
|
||||||
def _update(row):
|
def _update(row):
|
||||||
row["meta/episodes/chunk_index"] = row["meta/episodes/chunk_index"] + meta_idx["chunk_index"]
|
row["meta/episodes/chunk_index"] = row["meta/episodes/chunk_index"] + meta_idx["chunk"]
|
||||||
row["meta/episodes/file_index"] = row["meta/episodes/file_index"] + meta_idx["file_index"]
|
row["meta/episodes/file_index"] = row["meta/episodes/file_index"] + meta_idx["file"]
|
||||||
row["data/chunk_index"] = row["data/chunk_index"] + data_idx["chunk_index"]
|
row["data/chunk_index"] = row["data/chunk_index"] + data_idx["chunk"]
|
||||||
row["data/file_index"] = row["data/file_index"] + data_idx["file_index"]
|
row["data/file_index"] = row["data/file_index"] + data_idx["file"]
|
||||||
for key, video_idx in videos_idx.items():
|
for key, video_idx in videos_idx.items():
|
||||||
row[f"videos/{key}/chunk_index"] = row[f"videos/{key}/chunk_index"] + video_idx["chunk_index"]
|
row[f"videos/{key}/chunk_index"] = row[f"videos/{key}/chunk_index"] + video_idx["chunk"]
|
||||||
row[f"videos/{key}/file_index"] = row[f"videos/{key}/file_index"] + video_idx["file_index"]
|
row[f"videos/{key}/file_index"] = row[f"videos/{key}/file_index"] + video_idx["file"]
|
||||||
row[f"videos/{key}/from_timestamp"] = (
|
row[f"videos/{key}/from_timestamp"] = (
|
||||||
row[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
|
row[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
|
||||||
)
|
)
|
||||||
@@ -100,7 +98,11 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
fps, robot_type, features = validate_all_metadata(all_metadata)
|
fps, robot_type, features = validate_all_metadata(all_metadata)
|
||||||
|
<<<<<<< HEAD:src/lerobot/datasets/aggregate.py
|
||||||
video_keys = [k for k, v in features.items() if v["dtype"] == "video"]
|
video_keys = [k for k, v in features.items() if v["dtype"] == "video"]
|
||||||
|
=======
|
||||||
|
video_keys = [key for key in features if features[key]["dtype"] == "video"]
|
||||||
|
>>>>>>> aa07b858 (fix: debug aggregation code):lerobot/common/datasets/aggregate.py
|
||||||
|
|
||||||
# Initialize output dataset metadata
|
# Initialize output dataset metadata
|
||||||
dst_meta = LeRobotDatasetMetadata.create(
|
dst_meta = LeRobotDatasetMetadata.create(
|
||||||
@@ -127,7 +129,12 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
|
|||||||
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
|
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
|
||||||
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx)
|
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx)
|
||||||
data_idx = aggregate_data(src_meta, dst_meta, data_idx)
|
data_idx = aggregate_data(src_meta, dst_meta, data_idx)
|
||||||
|
<<<<<<< HEAD:src/lerobot/datasets/aggregate.py
|
||||||
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, video_keys)
|
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, video_keys)
|
||||||
|
=======
|
||||||
|
|
||||||
|
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
|
||||||
|
>>>>>>> aa07b858 (fix: debug aggregation code):lerobot/common/datasets/aggregate.py
|
||||||
|
|
||||||
dst_meta.info["total_episodes"] += src_meta.total_episodes
|
dst_meta.info["total_episodes"] += src_meta.total_episodes
|
||||||
dst_meta.info["total_frames"] += src_meta.total_frames
|
dst_meta.info["total_frames"] += src_meta.total_frames
|
||||||
@@ -157,8 +164,8 @@ def aggregate_videos(src_meta, dst_meta, videos_idx):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Current target chunk/file index
|
# Current target chunk/file index
|
||||||
chunk_idx = video_idx["chunk_idx"]
|
chunk_idx = video_idx["chunk"]
|
||||||
file_idx = video_idx["file_idx"]
|
file_idx = video_idx["file"]
|
||||||
|
|
||||||
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
|
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
|
||||||
src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
|
src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
|
||||||
@@ -203,11 +210,17 @@ def aggregate_videos(src_meta, dst_meta, videos_idx):
|
|||||||
file_idx,
|
file_idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
<<<<<<< HEAD:src/lerobot/datasets/aggregate.py
|
||||||
# Update the video index tracking
|
# Update the video index tracking
|
||||||
video_idx["chunk_idx"] = chunk_idx
|
video_idx["chunk_idx"] = chunk_idx
|
||||||
video_idx["file_idx"] = file_idx
|
video_idx["file_idx"] = file_idx
|
||||||
|
=======
|
||||||
|
# Update the videos_idx with the final chunk and file indices for this key
|
||||||
|
videos_idx[key]["chunk"] = chunk_idx
|
||||||
|
videos_idx[key]["file"] = file_idx
|
||||||
|
>>>>>>> aa07b858 (fix: debug aggregation code):lerobot/common/datasets/aggregate.py
|
||||||
|
|
||||||
return videos_idx
|
return videos_idx
|
||||||
|
|
||||||
|
|
||||||
def aggregate_data(src_meta, dst_meta, data_idx):
|
def aggregate_data(src_meta, dst_meta, data_idx):
|
||||||
@@ -235,6 +248,11 @@ def aggregate_data(src_meta, dst_meta, data_idx):
|
|||||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||||
DEFAULT_CHUNK_SIZE,
|
DEFAULT_CHUNK_SIZE,
|
||||||
DEFAULT_DATA_PATH,
|
DEFAULT_DATA_PATH,
|
||||||
|
<<<<<<< HEAD:src/lerobot/datasets/aggregate.py
|
||||||
|
=======
|
||||||
|
contains_images=len(dst_meta.image_keys) > 0,
|
||||||
|
aggr_root=dst_meta.root,
|
||||||
|
>>>>>>> aa07b858 (fix: debug aggregation code):lerobot/common/datasets/aggregate.py
|
||||||
)
|
)
|
||||||
|
|
||||||
return data_idx
|
return data_idx
|
||||||
@@ -261,13 +279,17 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
|||||||
videos_idx,
|
videos_idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
# for k in video_keys:
|
for k in videos_idx:
|
||||||
# video_idx[k]["latest_duration"] += video_idx[k]["episode_duration"]
|
videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]
|
||||||
|
|
||||||
|
<<<<<<< HEAD:src/lerobot/datasets/aggregate.py
|
||||||
dst_path = dst_meta.root / DEFAULT_EPISODES_PATH.format(
|
dst_path = dst_meta.root / DEFAULT_EPISODES_PATH.format(
|
||||||
chunk_index=meta_idx["chunk"], file_index=meta_idx["file"]
|
chunk_index=meta_idx["chunk"], file_index=meta_idx["file"]
|
||||||
)
|
)
|
||||||
write_parquet_safely(
|
write_parquet_safely(
|
||||||
|
=======
|
||||||
|
meta_idx = append_or_create_parquet_file(
|
||||||
|
>>>>>>> aa07b858 (fix: debug aggregation code):lerobot/common/datasets/aggregate.py
|
||||||
df,
|
df,
|
||||||
src_path,
|
src_path,
|
||||||
dst_path,
|
dst_path,
|
||||||
@@ -275,6 +297,8 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
|||||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||||
DEFAULT_CHUNK_SIZE,
|
DEFAULT_CHUNK_SIZE,
|
||||||
DEFAULT_EPISODES_PATH,
|
DEFAULT_EPISODES_PATH,
|
||||||
|
contains_images=False,
|
||||||
|
aggr_root=dst_meta.root,
|
||||||
)
|
)
|
||||||
|
|
||||||
return meta_idx
|
return meta_idx
|
||||||
@@ -288,6 +312,11 @@ def write_parquet_safely(
|
|||||||
max_mb: float,
|
max_mb: float,
|
||||||
chunk_size: int,
|
chunk_size: int,
|
||||||
default_path: str,
|
default_path: str,
|
||||||
|
<<<<<<< HEAD:src/lerobot/datasets/aggregate.py
|
||||||
|
=======
|
||||||
|
contains_images: bool = False,
|
||||||
|
aggr_root: Path = None,
|
||||||
|
>>>>>>> aa07b858 (fix: debug aggregation code):lerobot/common/datasets/aggregate.py
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Safely appends or creates a Parquet file at dst_path based on size constraints.
|
Safely appends or creates a Parquet file at dst_path based on size constraints.
|
||||||
@@ -304,11 +333,19 @@ def write_parquet_safely(
|
|||||||
Returns:
|
Returns:
|
||||||
dict: Updated index dictionary.
|
dict: Updated index dictionary.
|
||||||
"""
|
"""
|
||||||
|
<<<<<<< HEAD:src/lerobot/datasets/aggregate.py
|
||||||
|
=======
|
||||||
|
# Initial destination path - use the correct default_path parameter
|
||||||
|
dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
|
||||||
|
>>>>>>> aa07b858 (fix: debug aggregation code):lerobot/common/datasets/aggregate.py
|
||||||
|
|
||||||
# If destination file doesn't exist, just write the new one
|
# If destination file doesn't exist, just write the new one
|
||||||
if not dst_path.exists():
|
if not dst_path.exists():
|
||||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
df.to_parquet(dst_path)
|
if contains_images:
|
||||||
|
to_parquet_with_hf_images(df, dst_path)
|
||||||
|
else:
|
||||||
|
df.to_parquet(dst_path)
|
||||||
return idx
|
return idx
|
||||||
|
|
||||||
# Otherwise, check if we exceed the size limit
|
# Otherwise, check if we exceed the size limit
|
||||||
@@ -318,14 +355,29 @@ def write_parquet_safely(
|
|||||||
if dst_size + src_size >= max_mb:
|
if dst_size + src_size >= max_mb:
|
||||||
# File is too large, move to a new one
|
# File is too large, move to a new one
|
||||||
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
|
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
|
||||||
new_path = dst_path.parent / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
|
new_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
|
||||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
<<<<<<< HEAD:src/lerobot/datasets/aggregate.py
|
||||||
df.to_parquet(new_path)
|
df.to_parquet(new_path)
|
||||||
else:
|
else:
|
||||||
# Append to existing file
|
# Append to existing file
|
||||||
existing_df = pd.read_parquet(dst_path)
|
existing_df = pd.read_parquet(dst_path)
|
||||||
combined_df = pd.concat([existing_df, df], ignore_index=True)
|
combined_df = pd.concat([existing_df, df], ignore_index=True)
|
||||||
combined_df.to_parquet(dst_path)
|
combined_df.to_parquet(dst_path)
|
||||||
|
=======
|
||||||
|
final_df = df
|
||||||
|
target_path = new_path
|
||||||
|
else:
|
||||||
|
# Append to existing file
|
||||||
|
existing_df = pd.read_parquet(dst_path)
|
||||||
|
final_df = pd.concat([existing_df, df], ignore_index=True)
|
||||||
|
target_path = dst_path
|
||||||
|
|
||||||
|
if contains_images:
|
||||||
|
to_parquet_with_hf_images(final_df, target_path)
|
||||||
|
else:
|
||||||
|
final_df.to_parquet(target_path)
|
||||||
|
>>>>>>> aa07b858 (fix: debug aggregation code):lerobot/common/datasets/aggregate.py
|
||||||
|
|
||||||
return idx
|
return idx
|
||||||
|
|
||||||
@@ -348,52 +400,3 @@ def finalize_aggregation(aggr_meta, all_metadata):
|
|||||||
logging.info("write stats")
|
logging.info("write stats")
|
||||||
aggr_meta.stats = aggregate_stats([m.stats for m in all_metadata])
|
aggr_meta.stats = aggregate_stats([m.stats for m in all_metadata])
|
||||||
write_stats(aggr_meta.stats, aggr_meta.root)
|
write_stats(aggr_meta.stats, aggr_meta.root)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
init_logging()
|
|
||||||
|
|
||||||
num_shards = 2048
|
|
||||||
repo_id = "cadene/droid_1.0.1_v30"
|
|
||||||
aggr_repo_id = f"{repo_id}_compact_6"
|
|
||||||
tags = ["openx"]
|
|
||||||
|
|
||||||
# num_shards = 210
|
|
||||||
# repo_id = "cadene/agibot_alpha_v30"
|
|
||||||
# aggr_repo_id = f"{repo_id}"
|
|
||||||
# tags = None
|
|
||||||
|
|
||||||
# aggr_root = Path(f"/tmp/{aggr_repo_id}")
|
|
||||||
aggr_root = HF_LEROBOT_HOME / aggr_repo_id
|
|
||||||
if aggr_root.exists():
|
|
||||||
shutil.rmtree(aggr_root)
|
|
||||||
|
|
||||||
repo_ids = []
|
|
||||||
roots = []
|
|
||||||
for rank in range(num_shards):
|
|
||||||
shard_repo_id = f"{repo_id}_world_{num_shards}_rank_{rank}"
|
|
||||||
shard_root = HF_LEROBOT_HOME / shard_repo_id
|
|
||||||
try:
|
|
||||||
meta = LeRobotDatasetMetadata(shard_repo_id, root=shard_root)
|
|
||||||
if len(meta.video_keys) == 0:
|
|
||||||
continue
|
|
||||||
repo_ids.append(shard_repo_id)
|
|
||||||
roots.append(shard_root)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if rank == 1:
|
|
||||||
break
|
|
||||||
|
|
||||||
aggregate_datasets(
|
|
||||||
repo_ids,
|
|
||||||
aggr_repo_id,
|
|
||||||
roots=roots,
|
|
||||||
aggr_root=aggr_root,
|
|
||||||
)
|
|
||||||
|
|
||||||
aggr_dataset = LeRobotDataset(repo_id=aggr_repo_id, root=aggr_root)
|
|
||||||
# for i in tqdm.tqdm(range(len(aggr_dataset))):
|
|
||||||
# aggr_dataset[i]
|
|
||||||
# pass
|
|
||||||
aggr_dataset.push_to_hub(tags=tags, upload_large_folder=True)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user