mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
push to specific repo
This commit is contained in:
@@ -41,12 +41,14 @@ class AggregateEgoDexDatasets(PipelineStep):
|
|||||||
aggregated_repo_id: str,
|
aggregated_repo_id: str,
|
||||||
local_dir: Path | str | None = None,
|
local_dir: Path | str | None = None,
|
||||||
push_to_hub: bool = False,
|
push_to_hub: bool = False,
|
||||||
|
hf_repo_id: str | None = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.repo_ids = repo_ids
|
self.repo_ids = repo_ids
|
||||||
self.aggr_repo_id = aggregated_repo_id
|
self.aggr_repo_id = aggregated_repo_id
|
||||||
self.local_dir = Path(local_dir) if local_dir else None
|
self.local_dir = Path(local_dir) if local_dir else None
|
||||||
self.push_to_hub = push_to_hub
|
self.push_to_hub = push_to_hub
|
||||||
|
self.hf_repo_id = hf_repo_id if hf_repo_id else aggregated_repo_id
|
||||||
|
|
||||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||||
import logging
|
import logging
|
||||||
@@ -95,11 +97,13 @@ class AggregateEgoDexDatasets(PipelineStep):
|
|||||||
|
|
||||||
# Push to Hugging Face Hub if requested
|
# Push to Hugging Face Hub if requested
|
||||||
if self.push_to_hub:
|
if self.push_to_hub:
|
||||||
logging.info(f"Pushing {self.aggr_repo_id} to Hugging Face Hub...")
|
logging.info(f"Pushing to Hugging Face Hub as {self.hf_repo_id}...")
|
||||||
dataset = LeRobotDataset(
|
dataset = LeRobotDataset(
|
||||||
repo_id=self.aggr_repo_id,
|
repo_id=self.aggr_repo_id,
|
||||||
root=aggr_root,
|
root=aggr_root,
|
||||||
)
|
)
|
||||||
|
# Update repo_id for pushing to different HF account if specified
|
||||||
|
dataset.repo_id = self.hf_repo_id
|
||||||
dataset.push_to_hub(
|
dataset.push_to_hub(
|
||||||
tags=["egodex", "hand", "dexterous", "lerobot"],
|
tags=["egodex", "hand", "dexterous", "lerobot"],
|
||||||
license="cc-by-nc-nd-4.0",
|
license="cc-by-nc-nd-4.0",
|
||||||
@@ -119,6 +123,7 @@ def make_aggregate_executor(
|
|||||||
mem_per_cpu,
|
mem_per_cpu,
|
||||||
local_dir,
|
local_dir,
|
||||||
push_to_hub,
|
push_to_hub,
|
||||||
|
hf_repo_id,
|
||||||
slurm=True,
|
slurm=True,
|
||||||
):
|
):
|
||||||
"""Create executor for aggregating EgoDex shards."""
|
"""Create executor for aggregating EgoDex shards."""
|
||||||
@@ -127,7 +132,7 @@ def make_aggregate_executor(
|
|||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"pipeline": [
|
"pipeline": [
|
||||||
AggregateEgoDexDatasets(repo_ids, repo_id, local_dir, push_to_hub),
|
AggregateEgoDexDatasets(repo_ids, repo_id, local_dir, push_to_hub, hf_repo_id),
|
||||||
],
|
],
|
||||||
"logging_dir": str(logs_dir / job_name),
|
"logging_dir": str(logs_dir / job_name),
|
||||||
}
|
}
|
||||||
@@ -220,6 +225,12 @@ def main():
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Push aggregated dataset to Hugging Face Hub after aggregation.",
|
help="Push aggregated dataset to Hugging Face Hub after aggregation.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hf-repo-id",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Hugging Face repo ID for upload (e.g., username/dataset-name). Defaults to --repo-id.",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
kwargs = vars(args)
|
kwargs = vars(args)
|
||||||
|
|||||||
Reference in New Issue
Block a user