mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
Fixes in port droid scripts (#2455)
* Fixes in port droid scripts * revert default mem-per-cpu * style nit * fix relative imports * style nit
This commit is contained in:
@@ -15,16 +15,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
|
||||
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.utils.utils import init_logging
|
||||
from port_droid import DROID_SHARDS
|
||||
|
||||
|
||||
class AggregateDatasets(PipelineStep):
|
||||
@@ -38,6 +34,11 @@ class AggregateDatasets(PipelineStep):
|
||||
self.aggr_repo_id = aggregated_repo_id
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
import logging
|
||||
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
init_logging()
|
||||
|
||||
# Since aggregate_datasets already handles parallel processing internally,
|
||||
|
||||
@@ -20,7 +20,7 @@ from pathlib import Path
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
|
||||
from port_droid import DROID_SHARDS
|
||||
|
||||
|
||||
class PortDroidShards(PipelineStep):
|
||||
@@ -35,7 +35,7 @@ class PortDroidShards(PipelineStep):
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
from datasets.utils.tqdm import disable_progress_bars
|
||||
from port_datasets.droid_rlds.port_droid import port_droid, validate_dataset
|
||||
from port_droid import port_droid, validate_dataset
|
||||
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.constants import REPOCARD_NAME
|
||||
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
|
||||
from port_droid import DROID_SHARDS
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import create_lerobot_dataset_card
|
||||
@@ -185,11 +185,11 @@ class UploadDataset(PipelineStep):
|
||||
|
||||
|
||||
def make_upload_executor(
|
||||
repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
|
||||
repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, private=False, slurm=True
|
||||
):
|
||||
kwargs = {
|
||||
"pipeline": [
|
||||
UploadDataset(repo_id),
|
||||
UploadDataset(repo_id, private=private),
|
||||
],
|
||||
"logging_dir": str(logs_dir / job_name),
|
||||
}
|
||||
@@ -267,6 +267,12 @@ def main():
|
||||
default="1950M",
|
||||
help="Memory per cpu that each worker will use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--private",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Whether to create a private repository.",
|
||||
)
|
||||
|
||||
init_logging()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user