diff --git a/src/lerobot/utils/rabc.py b/src/lerobot/utils/rabc.py index c529f3ccc..dc0c61c69 100644 --- a/src/lerobot/utils/rabc.py +++ b/src/lerobot/utils/rabc.py @@ -20,6 +20,18 @@ from pathlib import Path import numpy as np import pandas as pd import torch +from huggingface_hub import hf_hub_download + + +def resolve_hf_path(path: str | Path) -> Path: + """Resolve a path that may be a HuggingFace URL (hf://datasets/...) to a local path.""" + path_str = str(path) + if path_str.startswith("hf://datasets/"): + parts = path_str.replace("hf://datasets/", "").split("/") + repo_id = "/".join(parts[:2]) + filename = "/".join(parts[2:]) + return Path(hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")) + return Path(path) class RABCWeights: @@ -51,7 +63,7 @@ class RABCWeights: fallback_weight: float = 1.0, device: torch.device = None, ): - self.progress_path = Path(progress_path) + self.progress_path = resolve_hf_path(progress_path) self.chunk_size = chunk_size self.head_mode = head_mode self.kappa = kappa