mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 05:59:52 +00:00
add port rlds script
This commit is contained in:
@@ -61,13 +61,71 @@ class PortDroidShards(PipelineStep):
|
||||
validate_dataset(shard_repo_id)
|
||||
|
||||
|
||||
class PortRLDSShards(PipelineStep):
|
||||
def __init__(
|
||||
self,
|
||||
raw_dir: Path | str,
|
||||
repo_id: str = None,
|
||||
num_shards: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.raw_dir = Path(raw_dir)
|
||||
self.repo_id = repo_id
|
||||
self.num_shards = num_shards
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
from datasets.utils.tqdm import disable_progress_bars
|
||||
from port_datasets.port_rlds import port_rlds, validate_dataset
|
||||
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
init_logging()
|
||||
disable_progress_bars()
|
||||
|
||||
shard_repo_id = f"{self.repo_id}_world_{world_size}_rank_{rank}"
|
||||
|
||||
try:
|
||||
validate_dataset(shard_repo_id)
|
||||
return
|
||||
except Exception:
|
||||
pass # nosec B110 - Dataset doesn't exist yet, continue with porting
|
||||
|
||||
port_rlds(
|
||||
self.raw_dir,
|
||||
shard_repo_id,
|
||||
push_to_hub=False,
|
||||
num_shards=world_size,
|
||||
shard_index=rank,
|
||||
)
|
||||
|
||||
validate_dataset(shard_repo_id)
|
||||
|
||||
|
||||
def make_port_executor(
|
||||
raw_dir, repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
|
||||
raw_dir,
|
||||
repo_id,
|
||||
job_name,
|
||||
logs_dir,
|
||||
workers,
|
||||
partition,
|
||||
cpus_per_task,
|
||||
mem_per_cpu,
|
||||
slurm=True,
|
||||
dataset_type="droid",
|
||||
num_shards=None,
|
||||
):
|
||||
# Select appropriate pipeline step based on dataset type
|
||||
if dataset_type.lower() == "droid":
|
||||
pipeline_step = PortDroidShards(raw_dir, repo_id)
|
||||
default_shards = DROID_SHARDS
|
||||
elif dataset_type.lower() == "rlds":
|
||||
pipeline_step = PortRLDSShards(raw_dir, repo_id, num_shards)
|
||||
default_shards = num_shards or workers # Use num_shards or fallback to workers
|
||||
else:
|
||||
raise ValueError(f"Unsupported dataset type: {dataset_type}")
|
||||
|
||||
kwargs = {
|
||||
"pipeline": [
|
||||
PortDroidShards(raw_dir, repo_id),
|
||||
],
|
||||
"pipeline": [pipeline_step],
|
||||
"logging_dir": str(logs_dir / job_name),
|
||||
}
|
||||
|
||||
@@ -75,7 +133,7 @@ def make_port_executor(
|
||||
kwargs.update(
|
||||
{
|
||||
"job_name": job_name,
|
||||
"tasks": DROID_SHARDS,
|
||||
"tasks": default_shards,
|
||||
"workers": workers,
|
||||
"time": "08:00:00",
|
||||
"partition": partition,
|
||||
@@ -115,11 +173,18 @@ def main():
|
||||
type=Path,
|
||||
help="Path to logs directory for `datatrove`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-type",
|
||||
type=str,
|
||||
choices=["droid", "rlds"],
|
||||
default="droid",
|
||||
help="Type of dataset to process: 'droid' for DROID datasets or 'rlds' for RLDS/OpenX datasets.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--job-name",
|
||||
type=str,
|
||||
default="port_droid",
|
||||
help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
|
||||
default=None,
|
||||
help="Job name used in slurm, and name of the directory created inside the provided logs directory. Defaults to 'port_{dataset_type}'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--slurm",
|
||||
@@ -130,8 +195,14 @@ def main():
|
||||
parser.add_argument(
|
||||
"--workers",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Number of slurm workers. It should be less than the maximum number of shards.",
|
||||
default=None,
|
||||
help="Number of slurm workers. Defaults: 2048 for DROID, 64 for RLDS datasets.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-shards",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of shards to split the dataset into. For DROID datasets, this is fixed at 2048. For RLDS datasets, defaults to number of workers.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--partition",
|
||||
@@ -152,8 +223,21 @@ def main():
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set defaults based on dataset type
|
||||
if args.job_name is None:
|
||||
args.job_name = f"port_{args.dataset_type}"
|
||||
|
||||
if args.workers is None:
|
||||
if args.dataset_type == "droid":
|
||||
args.workers = 2048
|
||||
else: # rlds
|
||||
args.workers = 64
|
||||
|
||||
# Convert args to kwargs and process
|
||||
kwargs = vars(args)
|
||||
kwargs["slurm"] = kwargs.pop("slurm") == 1
|
||||
|
||||
port_executor = make_port_executor(**kwargs)
|
||||
port_executor.run()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user