diff --git a/examples/port_datasets/slurm_aggregate_shards.py b/examples/port_datasets/slurm_aggregate_shards.py index 4e1b71a31..af5473c79 100644 --- a/examples/port_datasets/slurm_aggregate_shards.py +++ b/examples/port_datasets/slurm_aggregate_shards.py @@ -15,16 +15,12 @@ # limitations under the License. import argparse -import logging from pathlib import Path from datatrove.executor import LocalPipelineExecutor from datatrove.executor.slurm import SlurmPipelineExecutor from datatrove.pipeline.base import PipelineStep -from port_datasets.droid_rlds.port_droid import DROID_SHARDS - -from lerobot.datasets.aggregate import aggregate_datasets -from lerobot.utils.utils import init_logging +from port_droid import DROID_SHARDS class AggregateDatasets(PipelineStep): @@ -38,6 +34,11 @@ class AggregateDatasets(PipelineStep): self.aggr_repo_id = aggregated_repo_id def run(self, data=None, rank: int = 0, world_size: int = 1): + import logging + + from lerobot.datasets.aggregate import aggregate_datasets + from lerobot.utils.utils import init_logging + init_logging() # Since aggregate_datasets already handles parallel processing internally, diff --git a/examples/port_datasets/slurm_port_shards.py b/examples/port_datasets/slurm_port_shards.py index 3bb4c135c..657ea870c 100644 --- a/examples/port_datasets/slurm_port_shards.py +++ b/examples/port_datasets/slurm_port_shards.py @@ -20,7 +20,7 @@ from pathlib import Path from datatrove.executor import LocalPipelineExecutor from datatrove.executor.slurm import SlurmPipelineExecutor from datatrove.pipeline.base import PipelineStep -from port_datasets.droid_rlds.port_droid import DROID_SHARDS +from port_droid import DROID_SHARDS class PortDroidShards(PipelineStep): @@ -35,7 +35,7 @@ class PortDroidShards(PipelineStep): def run(self, data=None, rank: int = 0, world_size: int = 1): from datasets.utils.tqdm import disable_progress_bars - from port_datasets.droid_rlds.port_droid import port_droid, validate_dataset + from port_droid import port_droid, validate_dataset
from lerobot.utils.utils import init_logging diff --git a/examples/port_datasets/slurm_upload.py b/examples/port_datasets/slurm_upload.py index ade1ef874..55002c0be 100644 --- a/examples/port_datasets/slurm_upload.py +++ b/examples/port_datasets/slurm_upload.py @@ -24,7 +24,7 @@ from datatrove.executor.slurm import SlurmPipelineExecutor from datatrove.pipeline.base import PipelineStep from huggingface_hub import HfApi from huggingface_hub.constants import REPOCARD_NAME -from port_datasets.droid_rlds.port_droid import DROID_SHARDS +from port_droid import DROID_SHARDS from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata from lerobot.datasets.utils import create_lerobot_dataset_card @@ -185,11 +185,11 @@ class UploadDataset(PipelineStep): def make_upload_executor( - repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True + repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, private=False, slurm=True ): kwargs = { "pipeline": [ - UploadDataset(repo_id), + UploadDataset(repo_id, private=private), ], "logging_dir": str(logs_dir / job_name), } @@ -267,6 +267,12 @@ def main(): default="1950M", help="Memory per cpu that each worker will use.", ) + parser.add_argument( + "--private", + action="store_true", + default=False, + help="Whether to create a private repository.", + ) init_logging()