mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 04:30:10 +00:00
modift in place
This commit is contained in:
@@ -32,7 +32,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import shutil
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@@ -79,71 +78,45 @@ def unify_dataset_tasks(
|
|||||||
logging.info(f"Source dataset: {src_meta.total_episodes} episodes, {src_meta.total_frames} frames")
|
logging.info(f"Source dataset: {src_meta.total_episodes} episodes, {src_meta.total_frames} frames")
|
||||||
logging.info(f"Original tasks: {len(src_meta.tasks)}")
|
logging.info(f"Original tasks: {len(src_meta.tasks)}")
|
||||||
|
|
||||||
# Create output directory
|
# Modify in-place (input_root == output_root supported)
|
||||||
if output_root.exists():
|
data_dir = input_root / DATA_DIR
|
||||||
logging.warning(f"Output directory {output_root} exists, removing it")
|
|
||||||
shutil.rmtree(output_root)
|
|
||||||
|
|
||||||
output_root.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Copy videos directory (no changes needed)
|
|
||||||
src_videos = input_root / "videos"
|
|
||||||
if src_videos.exists():
|
|
||||||
logging.info("Copying videos...")
|
|
||||||
shutil.copytree(src_videos, output_root / "videos")
|
|
||||||
|
|
||||||
# Process data files - set all task_index to 0
|
# Process data files - set all task_index to 0
|
||||||
logging.info("Processing data files...")
|
logging.info("Processing data files (in-place)...")
|
||||||
src_data_dir = input_root / DATA_DIR
|
for parquet_file in tqdm(sorted(data_dir.rglob("*.parquet")), desc="Processing data"):
|
||||||
dst_data_dir = output_root / DATA_DIR
|
df = pd.read_parquet(parquet_file)
|
||||||
dst_data_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
for src_parquet in tqdm(sorted(src_data_dir.rglob("*.parquet")), desc="Processing data"):
|
|
||||||
rel_path = src_parquet.relative_to(input_root)
|
|
||||||
dst_parquet = output_root / rel_path
|
|
||||||
dst_parquet.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
df = pd.read_parquet(src_parquet)
|
|
||||||
df["task_index"] = 0 # All tasks unified to index 0
|
df["task_index"] = 0 # All tasks unified to index 0
|
||||||
df.to_parquet(dst_parquet)
|
df.to_parquet(parquet_file)
|
||||||
|
|
||||||
# Process episodes metadata - set all tasks to unified task
|
# Process episodes metadata - set all tasks to unified task
|
||||||
logging.info("Processing episodes metadata...")
|
logging.info("Processing episodes metadata (in-place)...")
|
||||||
src_episodes_dir = input_root / "meta" / "episodes"
|
episodes_dir = input_root / "meta" / "episodes"
|
||||||
dst_episodes_dir = output_root / "meta" / "episodes"
|
if episodes_dir.exists():
|
||||||
dst_episodes_dir.mkdir(parents=True, exist_ok=True)
|
for parquet_file in tqdm(sorted(episodes_dir.rglob("*.parquet")), desc="Processing episodes"):
|
||||||
|
df = pd.read_parquet(parquet_file)
|
||||||
|
df["tasks"] = [[UNIFIED_TASK]] * len(df) # All episodes get the unified task
|
||||||
|
df.to_parquet(parquet_file)
|
||||||
|
else:
|
||||||
|
logging.warning(f"No episodes directory found at {episodes_dir}, skipping")
|
||||||
|
|
||||||
for src_parquet in tqdm(sorted(src_episodes_dir.rglob("*.parquet")), desc="Processing episodes"):
|
# Update tasks.parquet with single task
|
||||||
rel_path = src_parquet.relative_to(src_episodes_dir)
|
|
||||||
dst_parquet = dst_episodes_dir / rel_path
|
|
||||||
dst_parquet.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
df = pd.read_parquet(src_parquet)
|
|
||||||
df["tasks"] = [[UNIFIED_TASK]] * len(df) # All episodes get the unified task
|
|
||||||
df.to_parquet(dst_parquet)
|
|
||||||
|
|
||||||
# Create new tasks.parquet with single task
|
|
||||||
logging.info(f"Creating single task: {UNIFIED_TASK}")
|
logging.info(f"Creating single task: {UNIFIED_TASK}")
|
||||||
new_tasks = pd.DataFrame({"task_index": [0]}, index=[UNIFIED_TASK])
|
new_tasks = pd.DataFrame({"task_index": [0]}, index=[UNIFIED_TASK])
|
||||||
write_tasks(new_tasks, output_root)
|
write_tasks(new_tasks, input_root)
|
||||||
|
|
||||||
# Update info.json
|
# Update info.json
|
||||||
new_info = src_meta.info.copy()
|
new_info = src_meta.info.copy()
|
||||||
new_info["total_tasks"] = 1
|
new_info["total_tasks"] = 1
|
||||||
write_info(new_info, output_root)
|
write_info(new_info, input_root)
|
||||||
|
|
||||||
# Copy stats.json (unchanged)
|
logging.info(f"Dataset modified in-place at {input_root}")
|
||||||
if src_meta.stats:
|
|
||||||
write_stats(src_meta.stats, output_root)
|
|
||||||
|
|
||||||
logging.info(f"Dataset saved to {output_root}")
|
|
||||||
logging.info(f"Task: {UNIFIED_TASK}")
|
logging.info(f"Task: {UNIFIED_TASK}")
|
||||||
|
|
||||||
if push_to_hub:
|
if push_to_hub:
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
|
||||||
logging.info(f"Pushing {output_repo_id} to hub")
|
logging.info(f"Pushing {input_repo_id} to hub")
|
||||||
dataset = LeRobotDataset(output_repo_id, root=output_root)
|
dataset = LeRobotDataset(input_repo_id, root=input_root)
|
||||||
dataset.push_to_hub(private=True)
|
dataset.push_to_hub(private=True)
|
||||||
logging.info("Push complete!")
|
logging.info("Push complete!")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user