mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-03 16:17:15 +00:00
fix(pi052): restore normalizer stats when loading from a hub repo id
_restore_pi052_pretrained_state did `Path(pretrained_path).exists()` and returned early for HF repo ids (only local dirs passed), so pi052 policies loaded via --policy.path=<repo_id> ran with fresh-init (un-normalized) quantile stats — state fed raw and actions never unnormalized, giving ~0% success. Resolve the repo id via snapshot_download (processor files only) so the saved normalizer/unnormalizer safetensors are transplanted as intended. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -92,11 +92,33 @@ def _restore_pi052_pretrained_state(
|
||||
|
||||
from safetensors.torch import load_file # noqa: PLC0415
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
base = Path(pretrained_path)
|
||||
if not base.exists():
|
||||
return
|
||||
# ``pretrained_path`` may be a HF Hub repo id rather than a local dir.
|
||||
# ``from_pretrained`` downloads the model weights, but pi052 builds its
|
||||
# processors fresh (so the generic loader never fetches them), leaving
|
||||
# the processor JSON + normalizer-stat safetensors un-downloaded. Resolve
|
||||
# them from the hub here — otherwise the quantile stats are silently left
|
||||
# at fresh init and the policy runs completely un-normalized.
|
||||
try:
|
||||
from huggingface_hub import snapshot_download # noqa: PLC0415
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
base = Path(
|
||||
snapshot_download(
|
||||
repo_id=str(pretrained_path),
|
||||
allow_patterns=["policy_preprocessor*", "policy_postprocessor*"],
|
||||
)
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"PI052 state restore: %s is not a local dir and could not be resolved "
|
||||
"as a hub repo (%s); normalizer stats left at fresh init",
|
||||
pretrained_path,
|
||||
exc,
|
||||
)
|
||||
return
|
||||
|
||||
for pipeline, config_filename in [
|
||||
(preprocessor, f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"),
|
||||
|
||||
Reference in New Issue
Block a user