add push to hub

This commit is contained in:
Pepijn
2026-02-04 19:17:40 +01:00
parent a9bce4732b
commit d26349c692
+10 -1
View File
@@ -154,6 +154,7 @@ def mirror_dataset(
mirroring_mask: dict[str, int] | None = None,
vcodec: str = "libsvtav1",
num_workers: int | None = None,
push_to_hub: bool = False,
) -> LeRobotDataset:
"""Mirror a bimanual robot dataset."""
logger.info(f"Loading dataset: {repo_id}")
@@ -193,7 +194,13 @@ def mirror_dataset(
_copy_episodes_metadata(dataset, new_meta)
logger.info(f"Mirrored dataset saved to: {output_root}")
return LeRobotDataset(output_repo_id, root=output_root)
mirrored_dataset = LeRobotDataset(output_repo_id, root=output_root)
if push_to_hub:
logger.info(f"Pushing mirrored dataset to hub: {output_repo_id}")
mirrored_dataset.push_to_hub()
return mirrored_dataset
def _mirror_data(
@@ -342,6 +349,7 @@ def main():
parser.add_argument("--output_root", type=str, default=None, help="Output dataset root directory")
parser.add_argument("--vcodec", type=str, default="libsvtav1", help="Video codec (libsvtav1, h264, hevc)")
parser.add_argument("--num_workers", type=int, default=None, help="Number of parallel workers for video processing")
parser.add_argument("--push_to_hub", action="store_true", help="Push mirrored dataset to HuggingFace Hub")
args = parser.parse_args()
mirror_dataset(
@@ -351,6 +359,7 @@ def main():
output_root=args.output_root,
vcodec=args.vcodec,
num_workers=args.num_workers,
push_to_hub=args.push_to_hub,
)