mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
Fixes in port droid scripts (#2455)
* Fixes in port droid scripts
* revert default mem-per-cpu
* style nit
* fix relative imports
* style nit
This commit is contained in:
@@ -15,16 +15,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from datatrove.executor import LocalPipelineExecutor
|
from datatrove.executor import LocalPipelineExecutor
|
||||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||||
from datatrove.pipeline.base import PipelineStep
|
from datatrove.pipeline.base import PipelineStep
|
||||||
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
|
from port_droid import DROID_SHARDS
|
||||||
|
|
||||||
from lerobot.datasets.aggregate import aggregate_datasets
|
|
||||||
from lerobot.utils.utils import init_logging
|
|
||||||
|
|
||||||
|
|
||||||
class AggregateDatasets(PipelineStep):
|
class AggregateDatasets(PipelineStep):
|
||||||
@@ -38,6 +34,11 @@ class AggregateDatasets(PipelineStep):
|
|||||||
self.aggr_repo_id = aggregated_repo_id
|
self.aggr_repo_id = aggregated_repo_id
|
||||||
|
|
||||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from lerobot.datasets.aggregate import aggregate_datasets
|
||||||
|
from lerobot.utils.utils import init_logging
|
||||||
|
|
||||||
init_logging()
|
init_logging()
|
||||||
|
|
||||||
# Since aggregate_datasets already handles parallel processing internally,
|
# Since aggregate_datasets already handles parallel processing internally,
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from pathlib import Path
|
|||||||
from datatrove.executor import LocalPipelineExecutor
|
from datatrove.executor import LocalPipelineExecutor
|
||||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||||
from datatrove.pipeline.base import PipelineStep
|
from datatrove.pipeline.base import PipelineStep
|
||||||
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
|
from port_droid import DROID_SHARDS
|
||||||
|
|
||||||
|
|
||||||
class PortDroidShards(PipelineStep):
|
class PortDroidShards(PipelineStep):
|
||||||
@@ -35,7 +35,7 @@ class PortDroidShards(PipelineStep):
|
|||||||
|
|
||||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||||
from datasets.utils.tqdm import disable_progress_bars
|
from datasets.utils.tqdm import disable_progress_bars
|
||||||
from port_datasets.droid_rlds.port_droid import port_droid, validate_dataset
|
from port_droid import port_droid, validate_dataset
|
||||||
|
|
||||||
from lerobot.utils.utils import init_logging
|
from lerobot.utils.utils import init_logging
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from datatrove.executor.slurm import SlurmPipelineExecutor
|
|||||||
from datatrove.pipeline.base import PipelineStep
|
from datatrove.pipeline.base import PipelineStep
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from huggingface_hub.constants import REPOCARD_NAME
|
from huggingface_hub.constants import REPOCARD_NAME
|
||||||
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
|
from port_droid import DROID_SHARDS
|
||||||
|
|
||||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
|
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||||
from lerobot.datasets.utils import create_lerobot_dataset_card
|
from lerobot.datasets.utils import create_lerobot_dataset_card
|
||||||
@@ -185,11 +185,11 @@ class UploadDataset(PipelineStep):
|
|||||||
|
|
||||||
|
|
||||||
def make_upload_executor(
|
def make_upload_executor(
|
||||||
repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
|
repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, private=False, slurm=True
|
||||||
):
|
):
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"pipeline": [
|
"pipeline": [
|
||||||
UploadDataset(repo_id),
|
UploadDataset(repo_id, private=private),
|
||||||
],
|
],
|
||||||
"logging_dir": str(logs_dir / job_name),
|
"logging_dir": str(logs_dir / job_name),
|
||||||
}
|
}
|
||||||
@@ -267,6 +267,12 @@ def main():
|
|||||||
default="1950M",
|
default="1950M",
|
||||||
help="Memory per cpu that each worker will use.",
|
help="Memory per cpu that each worker will use.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--private",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Whether to create a private repository.",
|
||||||
|
)
|
||||||
|
|
||||||
init_logging()
|
init_logging()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user