mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
267 lines
7.7 KiB
Python
267 lines
7.7 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""
|
|
Aggregate video dataset shards into a single dataset.
|
|
|
|
After parallel conversion using slurm_convert_to_video.py, this script merges
|
|
all the shard datasets into one final dataset.
|
|
|
|
Example usage:
|
|
python slurm_aggregate_video_shards.py \
|
|
--shards-dir /fsx/jade_choghari/libero_video \
|
|
--output-dir /fsx/jade_choghari/libero_video_final \
|
|
--output-repo-id lerobot_video \
|
|
--num-shards 100 \
--logs-dir /fsx/jade_choghari/logs \
|
|
--partition cpu_partition \
|
|
--cpus-per-task 16
|
|
"""
|
|
|
|
import argparse
|
|
import logging
|
|
from pathlib import Path
|
|
|
|
from datatrove.executor import LocalPipelineExecutor
|
|
from datatrove.executor.slurm import SlurmPipelineExecutor
|
|
from datatrove.pipeline.base import PipelineStep
|
|
|
|
|
|
class AggregateVideoShards(PipelineStep):
    """Pipeline step that merges per-shard video datasets into a single dataset.

    Shards are looked up under ``shards_dir`` as ``shard_0000``, ``shard_0001``, ...
    and each one is loaded with the repo id ``{output_repo_id}_shard_XXXX``.

    Args:
        shards_dir: Directory containing the ``shard_XXXX`` subdirectories.
        output_dir: Directory passed to ``merge_datasets`` as the output location.
        output_repo_id: Repo id of the merged dataset; also the prefix of every
            shard repo id.
        num_shards: Number of shard indices (``0 .. num_shards - 1``) to look for.
    """

    def __init__(
        self,
        shards_dir: str | Path,
        output_dir: str | Path,
        output_repo_id: str,
        num_shards: int,
    ):
        super().__init__()
        self.shards_dir = Path(shards_dir)
        self.output_dir = Path(output_dir)
        self.output_repo_id = output_repo_id
        self.num_shards = num_shards

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        """Load every available shard and merge them into one dataset.

        Only rank 0 does any work; every other rank logs a message and returns.
        Shards whose directory is missing or that fail to load are skipped
        (best effort), and a summary warning lists the skipped indices so a
        partial merge does not go unnoticed.

        Args:
            data: Unused.
            rank: Index of the current worker.
            world_size: Unused.

        Raises:
            ValueError: If no shard could be loaded at all.
        """
        # Imported here rather than at module level, so importing this module
        # does not itself import lerobot.
        from lerobot.datasets.dataset_tools import merge_datasets
        from lerobot.datasets.lerobot_dataset import LeRobotDataset
        from lerobot.utils.utils import init_logging

        init_logging()

        # Only worker 0 performs aggregation
        if rank != 0:
            logging.info(f"Worker {rank} skipping - only worker 0 performs aggregation")
            return

        logging.info(f"Starting aggregation of {self.num_shards} shards")

        shard_datasets = []
        skipped_indices = []  # shard indices that were missing or could not be loaded
        for shard_idx in range(self.num_shards):
            shard_dir = self.shards_dir / f"shard_{shard_idx:04d}"
            if not shard_dir.exists():
                logging.warning(f"Shard directory not found: {shard_dir}")
                skipped_indices.append(shard_idx)
                continue

            shard_repo_id = f"{self.output_repo_id}_shard_{shard_idx:04d}"
            try:
                shard_dataset = LeRobotDataset(shard_repo_id, root=shard_dir)
                # Reading the metadata is treated as part of loading: a shard
                # whose metadata cannot be read is skipped like any other
                # unloadable shard.
                shard_summary = (
                    f"Loaded shard {shard_idx}: {shard_dataset.meta.total_episodes} episodes, "
                    f"{shard_dataset.meta.total_frames} frames"
                )
            except Exception as e:
                # Deliberately broad: one bad shard must not abort the whole
                # aggregation. logging.exception keeps the traceback.
                logging.exception(f"Failed to load shard {shard_idx}: {e}")
                skipped_indices.append(shard_idx)
                continue

            shard_datasets.append(shard_dataset)
            logging.info(shard_summary)

        if not shard_datasets:
            raise ValueError(f"No valid shards found in {self.shards_dir}")

        if skipped_indices:
            logging.warning(
                f"Merging only {len(shard_datasets)} of {self.num_shards} shards; "
                f"missing or unreadable shard indices: {skipped_indices}"
            )

        logging.info(f"Successfully loaded {len(shard_datasets)} shards, starting merge")

        # NOTE(review): the output directory is created before merge_datasets runs;
        # confirm that merge_datasets accepts an already-existing output directory.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        merged_dataset = merge_datasets(
            shard_datasets,
            output_repo_id=self.output_repo_id,
            output_dir=self.output_dir,
        )

        logging.info("✓ Aggregation complete!")
        logging.info(f"Merged dataset saved to: {self.output_dir}")
        logging.info(f"Total episodes: {merged_dataset.meta.total_episodes}")
        logging.info(f"Total frames: {merged_dataset.meta.total_frames}")
|
|
|
|
|
|
def make_aggregate_executor(
    shards_dir,
    output_dir,
    output_repo_id,
    num_shards,
    job_name,
    logs_dir,
    partition,
    cpus_per_task,
    mem_per_cpu,
    time_limit,
    slurm=True,
):
    """Create the executor that runs the shard-aggregation step.

    The pipeline consists of a single ``AggregateVideoShards`` step and is always
    configured with one task and one worker.

    Args:
        shards_dir: Forwarded to ``AggregateVideoShards``.
        output_dir: Forwarded to ``AggregateVideoShards``.
        output_repo_id: Forwarded to ``AggregateVideoShards``.
        num_shards: Forwarded to ``AggregateVideoShards``.
        job_name: Name of the job; also the subdirectory of ``logs_dir`` used as
            the executor's logging directory.
        logs_dir: Base logs directory, given as ``str`` or ``Path``.
        partition: SLURM partition (only used when ``slurm`` is true).
        cpus_per_task: CPUs per task (only used when ``slurm`` is true).
        mem_per_cpu: Value of the ``mem-per-cpu`` sbatch argument (only used when
            ``slurm`` is true).
        time_limit: Passed as the executor's ``time`` argument (only used when
            ``slurm`` is true).
        slurm: Build a ``SlurmPipelineExecutor`` when true, otherwise a
            ``LocalPipelineExecutor``.

    Returns:
        The configured executor; the caller is responsible for calling ``run()``.
    """
    aggregation_step = AggregateVideoShards(
        shards_dir=shards_dir,
        output_dir=output_dir,
        output_repo_id=output_repo_id,
        num_shards=num_shards,
    )

    kwargs = {
        "pipeline": [aggregation_step],
        # Coerce to Path so a plain string logs_dir works as well as a Path.
        "logging_dir": str(Path(logs_dir) / job_name),
        # Only need 1 worker for aggregation
        "tasks": 1,
        "workers": 1,
    }

    if not slurm:
        return LocalPipelineExecutor(**kwargs)

    kwargs.update(
        {
            "job_name": job_name,
            "time": time_limit,
            "partition": partition,
            "cpus_per_task": cpus_per_task,
            "sbatch_args": {"mem-per-cpu": mem_per_cpu},
        }
    )
    return SlurmPipelineExecutor(**kwargs)
|
|
|
|
|
|
def main():
    """Parse command-line arguments, build the aggregation executor and run it.

    ``--partition`` is only needed when launching over SLURM (``--slurm 1``, the
    default); local runs (``--slurm 0``) may omit it. A missing partition for a
    SLURM launch is reported as a normal argparse usage error.
    """
    parser = argparse.ArgumentParser(
        description="Aggregate video dataset shards into a single dataset",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--shards-dir",
        type=Path,
        required=True,
        help="Directory containing shard_XXXX subdirectories",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        required=True,
        help="Output directory for the aggregated dataset",
    )
    parser.add_argument(
        "--output-repo-id",
        type=str,
        required=True,
        help="Repository ID for the aggregated dataset",
    )
    parser.add_argument(
        "--num-shards",
        type=int,
        required=True,
        help="Number of shards to aggregate (should match --workers from conversion)",
    )
    parser.add_argument(
        "--logs-dir",
        type=Path,
        required=True,
        help="Path to logs directory for datatrove",
    )
    parser.add_argument(
        "--job-name",
        type=str,
        default="aggregate_video_shards",
        help="Job name for SLURM",
    )
    parser.add_argument(
        "--slurm",
        type=int,
        default=1,
        help="Launch over SLURM (1) or locally (0)",
    )
    parser.add_argument(
        "--partition",
        type=str,
        default=None,
        help="SLURM partition (use CPU partition); required when --slurm is 1",
    )
    parser.add_argument(
        "--cpus-per-task",
        type=int,
        default=16,
        help="Number of CPUs per task (aggregation can use more CPUs)",
    )
    parser.add_argument(
        "--mem-per-cpu",
        type=str,
        default="8G",
        help="Memory per CPU",
    )
    parser.add_argument(
        "--time-limit",
        type=str,
        default="08:00:00",
        help="Time limit for SLURM job",
    )

    args = parser.parse_args()

    # Convert slurm flag to boolean
    slurm = args.slurm == 1

    # The partition is only meaningful for SLURM launches, so it is validated
    # here instead of being unconditionally required by the parser.
    if slurm and args.partition is None:
        parser.error("--partition is required when --slurm is 1")

    # Create and run executor
    executor = make_aggregate_executor(
        shards_dir=args.shards_dir,
        output_dir=args.output_dir,
        output_repo_id=args.output_repo_id,
        num_shards=args.num_shards,
        job_name=args.job_name,
        logs_dir=args.logs_dir,
        partition=args.partition,
        cpus_per_task=args.cpus_per_task,
        mem_per_cpu=args.mem_per_cpu,
        time_limit=args.time_limit,
        slurm=slurm,
    )

    # NOTE(review): nothing in this function configures logging, so these INFO
    # records may be dropped in the launcher process unless an imported library
    # sets up logging - confirm.
    logging.info("Starting shard aggregation")
    executor.run()
    logging.info("Aggregation job submitted/completed")
|
|
|
|
|
|
# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|
|
|
|