mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 21:50:03 +00:00
more changres
This commit is contained in:
@@ -0,0 +1,266 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Aggregate video dataset shards into a single dataset.
|
||||
|
||||
After parallel conversion using slurm_convert_to_video.py, this script merges
|
||||
all the shard datasets into one final dataset.
|
||||
|
||||
Example usage:
|
||||
python slurm_aggregate_video_shards.py \
|
||||
--shards-dir /fsx/jade_choghari/libero_video \
|
||||
--output-dir /fsx/jade_choghari/libero_video_final \
|
||||
--output-repo-id lerobot_video \
|
||||
--num-workers 100 \
|
||||
--partition cpu_partition \
|
||||
--cpus-per-task 16
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
|
||||
|
||||
class AggregateVideoShards(PipelineStep):
|
||||
"""Pipeline step that aggregates video dataset shards."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shards_dir: str | Path,
|
||||
output_dir: str | Path,
|
||||
output_repo_id: str,
|
||||
num_shards: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.shards_dir = Path(shards_dir)
|
||||
self.output_dir = Path(output_dir)
|
||||
self.output_repo_id = output_repo_id
|
||||
self.num_shards = num_shards
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
"""Aggregate all shards into a single dataset."""
|
||||
from lerobot.datasets.dataset_tools import merge_datasets
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
init_logging()
|
||||
|
||||
# Only worker 0 performs aggregation
|
||||
if rank != 0:
|
||||
logging.info(f"Worker {rank} skipping - only worker 0 performs aggregation")
|
||||
return
|
||||
|
||||
logging.info(f"Starting aggregation of {self.num_shards} shards")
|
||||
|
||||
# Collect all shard datasets
|
||||
shard_datasets = []
|
||||
for shard_idx in range(self.num_shards):
|
||||
shard_dir = self.shards_dir / f"shard_{shard_idx:04d}"
|
||||
if not shard_dir.exists():
|
||||
logging.warning(f"Shard directory not found: {shard_dir}")
|
||||
continue
|
||||
|
||||
# Find the repo_id for this shard
|
||||
shard_repo_id = f"{self.output_repo_id}_shard_{shard_idx:04d}"
|
||||
try:
|
||||
shard_dataset = LeRobotDataset(shard_repo_id, root=shard_dir)
|
||||
shard_datasets.append(shard_dataset)
|
||||
logging.info(
|
||||
f"Loaded shard {shard_idx}: {shard_dataset.meta.total_episodes} episodes, "
|
||||
f"{shard_dataset.meta.total_frames} frames"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to load shard {shard_idx}: {e}")
|
||||
continue
|
||||
|
||||
if len(shard_datasets) == 0:
|
||||
raise ValueError(f"No valid shards found in {self.shards_dir}")
|
||||
|
||||
logging.info(f"Successfully loaded {len(shard_datasets)} shards, starting merge")
|
||||
|
||||
# Merge all shards
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
merged_dataset = merge_datasets(
|
||||
shard_datasets,
|
||||
output_repo_id=self.output_repo_id,
|
||||
output_dir=self.output_dir,
|
||||
)
|
||||
|
||||
logging.info("✓ Aggregation complete!")
|
||||
logging.info(f"Merged dataset saved to: {self.output_dir}")
|
||||
logging.info(f"Total episodes: {merged_dataset.meta.total_episodes}")
|
||||
logging.info(f"Total frames: {merged_dataset.meta.total_frames}")
|
||||
|
||||
|
||||
def make_aggregate_executor(
|
||||
shards_dir,
|
||||
output_dir,
|
||||
output_repo_id,
|
||||
num_shards,
|
||||
job_name,
|
||||
logs_dir,
|
||||
partition,
|
||||
cpus_per_task,
|
||||
mem_per_cpu,
|
||||
time_limit,
|
||||
slurm=True,
|
||||
):
|
||||
"""Create executor for shard aggregation."""
|
||||
kwargs = {
|
||||
"pipeline": [
|
||||
AggregateVideoShards(
|
||||
shards_dir=shards_dir,
|
||||
output_dir=output_dir,
|
||||
output_repo_id=output_repo_id,
|
||||
num_shards=num_shards,
|
||||
),
|
||||
],
|
||||
"logging_dir": str(logs_dir / job_name),
|
||||
}
|
||||
|
||||
if slurm:
|
||||
# Only need 1 worker for aggregation
|
||||
kwargs.update(
|
||||
{
|
||||
"job_name": job_name,
|
||||
"tasks": 1,
|
||||
"workers": 1,
|
||||
"time": time_limit,
|
||||
"partition": partition,
|
||||
"cpus_per_task": cpus_per_task,
|
||||
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
|
||||
}
|
||||
)
|
||||
executor = SlurmPipelineExecutor(**kwargs)
|
||||
else:
|
||||
kwargs.update(
|
||||
{
|
||||
"tasks": 1,
|
||||
"workers": 1,
|
||||
}
|
||||
)
|
||||
executor = LocalPipelineExecutor(**kwargs)
|
||||
|
||||
return executor
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Aggregate video dataset shards into a single dataset",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--shards-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Directory containing shard_XXXX subdirectories",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Output directory for the aggregated dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository ID for the aggregated dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-shards",
|
||||
type=int,
|
||||
required=True,
|
||||
help="Number of shards to aggregate (should match --workers from conversion)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logs-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Path to logs directory for datatrove",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--job-name",
|
||||
type=str,
|
||||
default="aggregate_video_shards",
|
||||
help="Job name for SLURM",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--slurm",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Launch over SLURM (1) or locally (0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--partition",
|
||||
type=str,
|
||||
required=True,
|
||||
help="SLURM partition (use CPU partition)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cpus-per-task",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Number of CPUs per task (aggregation can use more CPUs)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mem-per-cpu",
|
||||
type=str,
|
||||
default="8G",
|
||||
help="Memory per CPU",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--time-limit",
|
||||
type=str,
|
||||
default="08:00:00",
|
||||
help="Time limit for SLURM job",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Convert slurm flag to boolean
|
||||
slurm = args.slurm == 1
|
||||
|
||||
# Create and run executor
|
||||
executor = make_aggregate_executor(
|
||||
shards_dir=args.shards_dir,
|
||||
output_dir=args.output_dir,
|
||||
output_repo_id=args.output_repo_id,
|
||||
num_shards=args.num_shards,
|
||||
job_name=args.job_name,
|
||||
logs_dir=args.logs_dir,
|
||||
partition=args.partition,
|
||||
cpus_per_task=args.cpus_per_task,
|
||||
mem_per_cpu=args.mem_per_cpu,
|
||||
time_limit=args.time_limit,
|
||||
slurm=slurm,
|
||||
)
|
||||
|
||||
logging.info("Starting shard aggregation")
|
||||
executor.run()
|
||||
logging.info("Aggregation job submitted/completed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user