mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 02:29:47 +00:00
add one-shot script to convert ginwind/VLA-JEPA checkpoints to safetensors (will remove once migrated)
This commit is contained in:
@@ -0,0 +1,279 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Convert all VLA-JEPA .pt checkpoints (ginwind/VLA-JEPA) to LeRobot safetensors
|
||||
format and upload them to maximellerbach org inside a HF collection.
|
||||
|
||||
Usage:
|
||||
uv run python convert_vla_jepa_checkpoints.py
|
||||
|
||||
For each variant the script:
|
||||
1. Downloads the .pt checkpoint.
|
||||
2. Extracts the state dict.
|
||||
3. Instantiates VLAJEPAPolicy with the variant's confirmed config.
|
||||
4. Loads the state dict (strict=False — mismatches printed to stdout).
|
||||
5. push_to_hub → writes model.safetensors + config.json in LeRobot format.
|
||||
6. Adds the new repo to a shared HF collection.
|
||||
|
||||
Config sources
|
||||
--------------
|
||||
Numeric hyper-params : ginwind/VLA-JEPA/<variant>/config.json
|
||||
Image keys LIBERO : lerobot/libero_10 meta/info.json ✓ confirmed
|
||||
Image keys Pretrain : lerobot/droid_1.0.1 meta/info.json ✓ confirmed
|
||||
Image keys SimplerEnv: OXE Bridge/RT1 are single-camera ✓ confirmed
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Top-level settings
|
||||
# ---------------------------------------------------------------------------
|
||||
SOURCE_REPO_ID = "ginwind/VLA-JEPA"
|
||||
TARGET_ORG = "lerobot"
|
||||
COLLECTION_TITLE = "VLA-JEPA"
|
||||
COLLECTION_DESCRIPTION = (
|
||||
"VLA-JEPA model checkpoints (LIBERO, Pretrain, SimplerEnv) converted from .pt to safetensors via LeRobot."
|
||||
)
|
||||
|
||||
# Remap state-dict key prefixes before loading into the LeRobot policy.
|
||||
# E.g. {"": "model."} prepends "model." to every key.
|
||||
# Leave empty if keys already match — the first run's log will tell you.
|
||||
KEY_PREFIX_REMAP: dict[str, str] = {
|
||||
# Specific rules must come before the "" catch-all (dict order is preserved).
|
||||
"qwen_vl_interface.": "model.qwen.",
|
||||
"vj_encoder.": "model.video_encoder.",
|
||||
"vj_predictor.": "model.video_predictor.",
|
||||
# Everything else (action_model.*) just needs the "model." wrapper.
|
||||
"": "model.",
|
||||
}
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Architecture — identical across all 4 variants (from config.json)
|
||||
# ---------------------------------------------------------------------------
|
||||
_ARCH = {
|
||||
"qwen_model_name": "Qwen/Qwen3-VL-2B-Instruct", # 2B, NOT the default 4B
|
||||
"chunk_size": 7,
|
||||
"n_action_steps": 7,
|
||||
"future_action_window_size": 6,
|
||||
"num_video_frames": 8,
|
||||
"jepa_tubelet_size": 2,
|
||||
"num_action_tokens_per_timestep": 8,
|
||||
"num_embodied_action_tokens_per_instruction": 32,
|
||||
"num_inference_timesteps": 4,
|
||||
"action_hidden_size": 1024,
|
||||
"action_model_type": "DiT-B",
|
||||
"action_num_layers": 16,
|
||||
"action_dropout": 0.2,
|
||||
"repeated_diffusion_steps": 8,
|
||||
"action_noise_beta_alpha": 1.5,
|
||||
"action_noise_beta_beta": 1.0,
|
||||
"action_noise_s": 0.999,
|
||||
"action_num_timestep_buckets": 1000,
|
||||
# Action head embedding params (from original config.json)
|
||||
"num_target_vision_tokens": 32,
|
||||
"action_max_seq_len": 1024,
|
||||
# World model predictor (12 blocks, confirmed from checkpoint)
|
||||
"predictor_depth": 12,
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image-key sets (confirmed sources in module docstring)
|
||||
# ---------------------------------------------------------------------------
|
||||
# LIBERO — confirmed from lerobot/libero_10 meta/info.json
|
||||
_LIBERO_CAMS = [
|
||||
"observation.images.image", # agentview camera
|
||||
"observation.images.wrist_image", # eye-in-hand camera
|
||||
]
|
||||
|
||||
# DROID pretrain — 2 views match the predictor embed_dim=2 × 1024=2048 in checkpoint
|
||||
_DROID_CAMS = [
|
||||
"observation.images.exterior_1_left",
|
||||
"observation.images.exterior_2_left",
|
||||
]
|
||||
|
||||
# OXE Bridge + RT1 — single-camera; world model disabled (predictor embed_dim mismatch)
|
||||
_OXE_CAMS = [
|
||||
"observation.images.image",
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config factories
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_config(camera_keys: list[str], with_state: bool, enable_world_model: bool = True):
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
|
||||
input_features = {k: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)) for k in camera_keys}
|
||||
if with_state:
|
||||
input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(8,))
|
||||
|
||||
cfg = VLAJEPAConfig(
|
||||
input_features=input_features,
|
||||
output_features={
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
},
|
||||
enable_world_model=enable_world_model,
|
||||
**_ARCH,
|
||||
)
|
||||
cfg.validate_features()
|
||||
return cfg
|
||||
|
||||
|
||||
# Maps each subfolder in SOURCE_REPO_ID to (camera_keys, with_state, enable_world_model, repo_suffix)
|
||||
VARIANTS: dict[str, tuple] = {
|
||||
"LIBERO": (_LIBERO_CAMS, True, True, "LIBERO"),
|
||||
"Pretrain": (_DROID_CAMS, False, True, "Pretrain"),
|
||||
# SimplerEnv uses a single camera; the predictor embed_dim (2048) would mismatch, so
|
||||
# disable the world model — only qwen + action_model weights are needed for inference.
|
||||
"SimplerEnv": (_OXE_CAMS, False, False, "SimplerEnv"),
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def extract_state_dict(ckpt: object) -> dict[str, torch.Tensor]:
|
||||
if isinstance(ckpt, dict):
|
||||
sd = ckpt.get("state_dict") or ckpt.get("model_state_dict") or ckpt.get("model")
|
||||
if sd is None:
|
||||
sd = ckpt
|
||||
else:
|
||||
sd = ckpt
|
||||
return {k: v for k, v in sd.items() if isinstance(v, torch.Tensor)}
|
||||
|
||||
|
||||
def remap_keys(sd: dict[str, torch.Tensor], remap: dict[str, str]) -> dict[str, torch.Tensor]:
|
||||
if not remap:
|
||||
return sd
|
||||
out = {}
|
||||
for k, v in sd.items():
|
||||
new_k = k
|
||||
for old, new in remap.items():
|
||||
if k.startswith(old):
|
||||
new_k = new + k[len(old) :]
|
||||
break
|
||||
out[new_k] = v
|
||||
return out
|
||||
|
||||
|
||||
def subfolder_of(pt_path: str) -> str | None:
|
||||
for part in Path(pt_path).parts:
|
||||
if part in VARIANTS:
|
||||
return part
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main() -> None:
|
||||
api = HfApi()
|
||||
|
||||
log.info("Listing .pt files in %s …", SOURCE_REPO_ID)
|
||||
pt_files = [f for f in api.list_repo_files(SOURCE_REPO_ID) if f.endswith(".pt")]
|
||||
if not pt_files:
|
||||
log.error("No .pt files found.")
|
||||
return
|
||||
for f in pt_files:
|
||||
log.info(" %s", f)
|
||||
|
||||
# Create / reuse the collection once
|
||||
collection = api.create_collection(
|
||||
title=COLLECTION_TITLE,
|
||||
description=COLLECTION_DESCRIPTION,
|
||||
namespace=TARGET_ORG,
|
||||
exists_ok=True,
|
||||
)
|
||||
log.info("Collection: %s", collection.url)
|
||||
|
||||
for pt_filename in pt_files:
|
||||
log.info("\n=== %s ===", pt_filename)
|
||||
|
||||
subfolder = subfolder_of(pt_filename)
|
||||
if subfolder is None:
|
||||
log.warning(" No variant entry for '%s' — skipping.", pt_filename)
|
||||
continue
|
||||
|
||||
camera_keys, with_state, enable_world_model, repo_suffix = VARIANTS[subfolder]
|
||||
target_repo_id = f"{TARGET_ORG}/VLA-JEPA-{repo_suffix}"
|
||||
|
||||
log.info(
|
||||
" cameras=%d with_state=%s wm=%s → %s",
|
||||
len(camera_keys),
|
||||
with_state,
|
||||
enable_world_model,
|
||||
target_repo_id,
|
||||
)
|
||||
|
||||
# 1. Download
|
||||
local_pt = api.hf_hub_download(SOURCE_REPO_ID, pt_filename)
|
||||
log.info(" Downloaded → %s", local_pt)
|
||||
|
||||
# 2. Load checkpoint
|
||||
try:
|
||||
ckpt = torch.load(local_pt, map_location="cpu", mmap=True, weights_only=False) # nosec B614
|
||||
except TypeError:
|
||||
ckpt = torch.load(local_pt, map_location="cpu") # nosec B614
|
||||
|
||||
sd = extract_state_dict(ckpt)
|
||||
sd = remap_keys(sd, KEY_PREFIX_REMAP)
|
||||
log.info(" %d tensors extracted", len(sd))
|
||||
log.info(" First 5 keys: %s", list(sd)[:5])
|
||||
|
||||
# 3. Build policy
|
||||
from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
|
||||
|
||||
config = _build_config(camera_keys, with_state, enable_world_model)
|
||||
policy = VLAJEPAPolicy(config)
|
||||
|
||||
# 4. Load weights
|
||||
missing, unexpected = policy.load_state_dict(sd, strict=False)
|
||||
|
||||
def _prefix_summary(keys: list[str]) -> dict[str, int]:
|
||||
from collections import Counter
|
||||
|
||||
return dict(Counter(".".join(k.split(".")[:3]) for k in keys).most_common())
|
||||
|
||||
if missing:
|
||||
log.warning(" Missing (%d) by prefix: %s", len(missing), _prefix_summary(missing))
|
||||
if unexpected:
|
||||
log.warning(" Unexpected (%d) by prefix: %s", len(unexpected), _prefix_summary(unexpected))
|
||||
if not missing and not unexpected:
|
||||
log.info(" State dict loaded cleanly.")
|
||||
|
||||
# 5. Push to hub (writes model.safetensors + config.json)
|
||||
api.create_repo(target_repo_id, repo_type="model", exist_ok=True)
|
||||
commit_url = policy.push_to_hub(
|
||||
repo_id=target_repo_id,
|
||||
commit_message=f"Convert {Path(pt_filename).name} to safetensors",
|
||||
)
|
||||
log.info(" Uploaded → %s", commit_url)
|
||||
|
||||
# 6. Add to collection
|
||||
api.add_collection_item(
|
||||
collection_slug=collection.slug,
|
||||
item_id=target_repo_id,
|
||||
item_type="model",
|
||||
exists_ok=True,
|
||||
)
|
||||
log.info(" Added to collection.")
|
||||
|
||||
log.info("\nAll done. Collection: %s", collection.url)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user