From cee80daa886b8152e084efefb9b615aaae2604c3 Mon Sep 17 00:00:00 2001 From: Maximellerbach Date: Fri, 15 May 2026 15:59:43 +0200 Subject: [PATCH] add one-shot script to convert ginwind/VLA-JEPA checkpoints to safetensors (will remove once migrated) --- .../vla_jepa/convert_vla_jepa_checkpoints.py | 279 ++++++++++++++++++ 1 file changed, 279 insertions(+) create mode 100644 src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py diff --git a/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py b/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py new file mode 100644 index 000000000..5f60c97b1 --- /dev/null +++ b/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py @@ -0,0 +1,279 @@ +#!/usr/bin/env python +""" +Convert all VLA-JEPA .pt checkpoints (ginwind/VLA-JEPA) to LeRobot safetensors +format and upload them to maximellerbach org inside a HF collection. + +Usage: + uv run python convert_vla_jepa_checkpoints.py + +For each variant the script: + 1. Downloads the .pt checkpoint. + 2. Extracts the state dict. + 3. Instantiates VLAJEPAPolicy with the variant's confirmed config. + 4. Loads the state dict (strict=False — mismatches printed to stdout). + 5. push_to_hub → writes model.safetensors + config.json in LeRobot format. + 6. Adds the new repo to a shared HF collection. + +Config sources +-------------- +Numeric hyper-params : ginwind/VLA-JEPA//config.json +Image keys LIBERO : lerobot/libero_10 meta/info.json ✓ confirmed +Image keys Pretrain : lerobot/droid_1.0.1 meta/info.json ✓ confirmed +Image keys SimplerEnv: OXE Bridge/RT1 are single-camera ✓ confirmed +""" + +from __future__ import annotations + +import logging +from pathlib import Path + +import torch +from huggingface_hub import HfApi + +# --------------------------------------------------------------------------- +# Top-level settings +# --------------------------------------------------------------------------- +SOURCE_REPO_ID = "ginwind/VLA-JEPA" +TARGET_ORG = "lerobot" +COLLECTION_TITLE = "VLA-JEPA" +COLLECTION_DESCRIPTION = ( + "VLA-JEPA model checkpoints (LIBERO, Pretrain, SimplerEnv) converted from .pt to safetensors via LeRobot." +) + +# Remap state-dict key prefixes before loading into the LeRobot policy. +# E.g. {"": "model."} prepends "model." to every key. +# Leave empty if keys already match — the first run's log will tell you. +KEY_PREFIX_REMAP: dict[str, str] = { + # Specific rules must come before the "" catch-all (dict order is preserved). + "qwen_vl_interface.": "model.qwen.", + "vj_encoder.": "model.video_encoder.", + "vj_predictor.": "model.video_predictor.", + # Everything else (action_model.*) just needs the "model." wrapper. + "": "model.", +} + +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") +log = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Architecture — identical across all 4 variants (from config.json) +# --------------------------------------------------------------------------- +_ARCH = { + "qwen_model_name": "Qwen/Qwen3-VL-2B-Instruct", # 2B, NOT the default 4B + "chunk_size": 7, + "n_action_steps": 7, + "future_action_window_size": 6, + "num_video_frames": 8, + "jepa_tubelet_size": 2, + "num_action_tokens_per_timestep": 8, + "num_embodied_action_tokens_per_instruction": 32, + "num_inference_timesteps": 4, + "action_hidden_size": 1024, + "action_model_type": "DiT-B", + "action_num_layers": 16, + "action_dropout": 0.2, + "repeated_diffusion_steps": 8, + "action_noise_beta_alpha": 1.5, + "action_noise_beta_beta": 1.0, + "action_noise_s": 0.999, + "action_num_timestep_buckets": 1000, + # Action head embedding params (from original config.json) + "num_target_vision_tokens": 32, + "action_max_seq_len": 1024, + # World model predictor (12 blocks, confirmed from checkpoint) + "predictor_depth": 12, +} + +# --------------------------------------------------------------------------- +# Image-key sets (confirmed sources in module docstring) +# --------------------------------------------------------------------------- +# LIBERO — confirmed from lerobot/libero_10 meta/info.json +_LIBERO_CAMS = [ + "observation.images.image", # agentview camera + "observation.images.wrist_image", # eye-in-hand camera +] + +# DROID pretrain — 2 views match the predictor embed_dim=2 × 1024=2048 in checkpoint +_DROID_CAMS = [ + "observation.images.exterior_1_left", + "observation.images.exterior_2_left", +] + +# OXE Bridge + RT1 — single-camera; world model disabled (predictor embed_dim mismatch) +_OXE_CAMS = [ + "observation.images.image", +] + + +# --------------------------------------------------------------------------- +# Config factories +# --------------------------------------------------------------------------- + + +def _build_config(camera_keys: list[str], with_state: bool, enable_world_model: bool = True): + from lerobot.configs.types import FeatureType, PolicyFeature + from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig + + input_features = {k: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)) for k in camera_keys} + if with_state: + input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(8,)) + + cfg = VLAJEPAConfig( + input_features=input_features, + output_features={ + "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), + }, + enable_world_model=enable_world_model, + **_ARCH, + ) + cfg.validate_features() + return cfg + + +# Maps each subfolder in SOURCE_REPO_ID to (camera_keys, with_state, enable_world_model, repo_suffix) +VARIANTS: dict[str, tuple] = { + "LIBERO": (_LIBERO_CAMS, True, True, "LIBERO"), + "Pretrain": (_DROID_CAMS, False, True, "Pretrain"), + # SimplerEnv uses a single camera; the predictor embed_dim (2048) would mismatch, so + # disable the world model — only qwen + action_model weights are needed for inference. + "SimplerEnv": (_OXE_CAMS, False, False, "SimplerEnv"), +} + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def extract_state_dict(ckpt: object) -> dict[str, torch.Tensor]: + if isinstance(ckpt, dict): + sd = ckpt.get("state_dict") or ckpt.get("model_state_dict") or ckpt.get("model") + if sd is None: + sd = ckpt + else: + sd = ckpt + return {k: v for k, v in sd.items() if isinstance(v, torch.Tensor)} + + +def remap_keys(sd: dict[str, torch.Tensor], remap: dict[str, str]) -> dict[str, torch.Tensor]: + if not remap: + return sd + out = {} + for k, v in sd.items(): + new_k = k + for old, new in remap.items(): + if k.startswith(old): + new_k = new + k[len(old) :] + break + out[new_k] = v + return out + + +def subfolder_of(pt_path: str) -> str | None: + for part in Path(pt_path).parts: + if part in VARIANTS: + return part + return None + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main() -> None: + api = HfApi() + + log.info("Listing .pt files in %s …", SOURCE_REPO_ID) + pt_files = [f for f in api.list_repo_files(SOURCE_REPO_ID) if f.endswith(".pt")] + if not pt_files: + log.error("No .pt files found.") + return + for f in pt_files: + log.info(" %s", f) + + # Create / reuse the collection once + collection = api.create_collection( + title=COLLECTION_TITLE, + description=COLLECTION_DESCRIPTION, + namespace=TARGET_ORG, + exists_ok=True, + ) + log.info("Collection: %s", collection.url) + + for pt_filename in pt_files: + log.info("\n=== %s ===", pt_filename) + + subfolder = subfolder_of(pt_filename) + if subfolder is None: + log.warning(" No variant entry for '%s' — skipping.", pt_filename) + continue + + camera_keys, with_state, enable_world_model, repo_suffix = VARIANTS[subfolder] + target_repo_id = f"{TARGET_ORG}/VLA-JEPA-{repo_suffix}" + + log.info( + " cameras=%d with_state=%s wm=%s → %s", + len(camera_keys), + with_state, + enable_world_model, + target_repo_id, + ) + + # 1. Download + local_pt = api.hf_hub_download(SOURCE_REPO_ID, pt_filename) + log.info(" Downloaded → %s", local_pt) + + # 2. Load checkpoint + try: + ckpt = torch.load(local_pt, map_location="cpu", mmap=True, weights_only=False) # nosec B614 + except TypeError: + ckpt = torch.load(local_pt, map_location="cpu") # nosec B614 + + sd = extract_state_dict(ckpt) + sd = remap_keys(sd, KEY_PREFIX_REMAP) + log.info(" %d tensors extracted", len(sd)) + log.info(" First 5 keys: %s", list(sd)[:5]) + + # 3. Build policy + from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy + + config = _build_config(camera_keys, with_state, enable_world_model) + policy = VLAJEPAPolicy(config) + + # 4. Load weights + missing, unexpected = policy.load_state_dict(sd, strict=False) + + def _prefix_summary(keys: list[str]) -> dict[str, int]: + from collections import Counter + + return dict(Counter(".".join(k.split(".")[:3]) for k in keys).most_common()) + + if missing: + log.warning(" Missing (%d) by prefix: %s", len(missing), _prefix_summary(missing)) + if unexpected: + log.warning(" Unexpected (%d) by prefix: %s", len(unexpected), _prefix_summary(unexpected)) + if not missing and not unexpected: + log.info(" State dict loaded cleanly.") + + # 5. Push to hub (writes model.safetensors + config.json) + api.create_repo(target_repo_id, repo_type="model", exist_ok=True) + commit_url = policy.push_to_hub( + repo_id=target_repo_id, + commit_message=f"Convert {Path(pt_filename).name} to safetensors", + ) + log.info(" Uploaded → %s", commit_url) + + # 6. Add to collection + api.add_collection_item( + collection_slug=collection.slug, + item_id=target_repo_id, + item_type="model", + exists_ok=True, + ) + log.info(" Added to collection.") + + log.info("\nAll done. Collection: %s", collection.url) + + +if __name__ == "__main__": + main()