add port rlds script

This commit is contained in:
Pepijn
2025-09-08 13:40:47 +02:00
parent af79dda8d9
commit 3d31f2ad53
7 changed files with 2450 additions and 9 deletions
+93 -9
View File
@@ -61,13 +61,71 @@ class PortDroidShards(PipelineStep):
validate_dataset(shard_repo_id)
class PortRLDSShards(PipelineStep):
    """Pipeline step that ports a single shard of an RLDS dataset.

    Each invocation of ``run`` handles the shard identified by ``rank`` out of
    ``world_size`` shards, writes it under a shard-specific repo id, and
    validates the result.
    """

    def __init__(
        self,
        raw_dir: Path | str,
        repo_id: str | None = None,
        num_shards: int | None = None,
    ):
        """Store the porting configuration.

        Args:
            raw_dir: Path handed to ``port_rlds`` as its first argument;
                converted to a ``Path``.
            repo_id: Base repo id. ``run`` appends a
                ``_world_{world_size}_rank_{rank}`` suffix to it.
            num_shards: Requested shard count. It is only stored on the
                instance; ``run`` uses the ``world_size`` it is called with
                as the number of shards.
        """
        super().__init__()
        self.raw_dir = Path(raw_dir)
        self.repo_id = repo_id
        self.num_shards = num_shards

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        """Port and validate the shard assigned to this task.

        If the shard-specific dataset already passes ``validate_dataset``,
        nothing is ported and the method returns immediately. Otherwise the
        shard is ported locally (``push_to_hub=False``) and validated; a
        validation failure after porting propagates to the caller.

        Args:
            data: Unused by this step.
            rank: Index of the shard to port (passed as ``shard_index``).
            world_size: Total number of shards (passed as ``num_shards``).
        """
        # Imports are kept local to the method, as in the original step.
        import logging

        from datasets.utils.tqdm import disable_progress_bars
        from port_datasets.port_rlds import port_rlds, validate_dataset

        from lerobot.utils.utils import init_logging

        init_logging()
        disable_progress_bars()

        shard_repo_id = f"{self.repo_id}_world_{world_size}_rank_{rank}"

        try:
            validate_dataset(shard_repo_id)
        except Exception as exc:
            # Best effort: a failed validation is presumably a shard that has not
            # been ported yet (or only partially). Log the reason instead of
            # hiding it, then fall through to (re)port the shard.
            logging.info(
                "Shard %s did not validate (%s: %s); porting it.",
                shard_repo_id,
                type(exc).__name__,
                exc,
            )
        else:
            # Shard already exists and is valid; nothing to do.
            return

        port_rlds(
            self.raw_dir,
            shard_repo_id,
            push_to_hub=False,
            num_shards=world_size,
            shard_index=rank,
        )

        validate_dataset(shard_repo_id)
def make_port_executor(
raw_dir, repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
raw_dir,
repo_id,
job_name,
logs_dir,
workers,
partition,
cpus_per_task,
mem_per_cpu,
slurm=True,
dataset_type="droid",
num_shards=None,
):
# Select appropriate pipeline step based on dataset type
if dataset_type.lower() == "droid":
pipeline_step = PortDroidShards(raw_dir, repo_id)
default_shards = DROID_SHARDS
elif dataset_type.lower() == "rlds":
pipeline_step = PortRLDSShards(raw_dir, repo_id, num_shards)
default_shards = num_shards or workers # Use num_shards or fallback to workers
else:
raise ValueError(f"Unsupported dataset type: {dataset_type}")
kwargs = {
"pipeline": [
PortDroidShards(raw_dir, repo_id),
],
"pipeline": [pipeline_step],
"logging_dir": str(logs_dir / job_name),
}
@@ -75,7 +133,7 @@ def make_port_executor(
kwargs.update(
{
"job_name": job_name,
"tasks": DROID_SHARDS,
"tasks": default_shards,
"workers": workers,
"time": "08:00:00",
"partition": partition,
@@ -115,11 +173,18 @@ def main():
type=Path,
help="Path to logs directory for `datatrove`.",
)
parser.add_argument(
"--dataset-type",
type=str,
choices=["droid", "rlds"],
default="droid",
help="Type of dataset to process: 'droid' for DROID datasets or 'rlds' for RLDS/OpenX datasets.",
)
parser.add_argument(
"--job-name",
type=str,
default="port_droid",
help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
default=None,
help="Job name used in slurm, and name of the directory created inside the provided logs directory. Defaults to 'port_{dataset_type}'.",
)
parser.add_argument(
"--slurm",
@@ -130,8 +195,14 @@ def main():
parser.add_argument(
"--workers",
type=int,
default=2048,
help="Number of slurm workers. It should be less than the maximum number of shards.",
default=None,
help="Number of slurm workers. Defaults: 2048 for DROID, 64 for RLDS datasets.",
)
parser.add_argument(
"--num-shards",
type=int,
default=None,
help="Number of shards to split the dataset into. For DROID datasets, this is fixed at 2048. For RLDS datasets, defaults to number of workers.",
)
parser.add_argument(
"--partition",
@@ -152,8 +223,21 @@ def main():
)
args = parser.parse_args()
# Set defaults based on dataset type
if args.job_name is None:
args.job_name = f"port_{args.dataset_type}"
if args.workers is None:
if args.dataset_type == "droid":
args.workers = 2048
else: # rlds
args.workers = 64
# Convert args to kwargs and process
kwargs = vars(args)
kwargs["slurm"] = kwargs.pop("slurm") == 1
port_executor = make_port_executor(**kwargs)
port_executor.run()