diff --git a/src/lerobot/policies/lingbot_va/configuration_lingbot_va.py b/src/lerobot/policies/lingbot_va/configuration_lingbot_va.py index 5cb3a2341..63ce52b64 100644 --- a/src/lerobot/policies/lingbot_va/configuration_lingbot_va.py +++ b/src/lerobot/policies/lingbot_va/configuration_lingbot_va.py @@ -37,7 +37,7 @@ from lerobot.utils.constants import ACTION class LingBotVAConfig(PreTrainedConfig): """Configuration for the native LingBot-VA policy integration in LeRobot.""" - # ── Wan transformer architecture (from transformer/config.json) ── + # Wan transformer architecture patch_size: tuple[int, int, int] = (1, 2, 2) num_attention_heads: int = 24 attention_head_dim: int = 128 @@ -54,15 +54,15 @@ class LingBotVAConfig(PreTrainedConfig): # "flex" = training only (needs recent torch); inference uses "torch" SDPA or "flashattn". attn_mode: str = "torch" - # ── Frozen sub-models (VAE + UMT5 text encoder + tokenizer) ── + # Frozen sub-models (VAE + UMT5 text encoder + tokenizer) # ~20 GB of frozen weights, NOT bundled in the checkpoint; lazily pulled from this HF repo / # local dir (must hold diffusers-style ``vae/``, ``text_encoder/``, ``tokenizer/`` sub-folders). - wan_pretrained_path: str = "robbyant/lingbot-va-posttrain-libero-long" + wan_pretrained_path: str = "robbyant/lingbot-va-base" dtype: str = "bfloat16" # transformer / VAE / text-encoder dtype: "bfloat16", "float16", "float32" # Frozen UMT5-XXL encoder device; "cpu" frees ~11 GB VRAM (it runs once per episode). text_encoder_device: str = "cpu" - # ── Observation cameras (order matters: latents are concatenated on width; LIBERO defaults) ── + # Observation cameras (order matters: latents are concatenated on width; LIBERO defaults) obs_cam_keys: list[str] = field( default_factory=lambda: ["observation.images.image", "observation.images.image2"] ) @@ -72,7 +72,7 @@ class LingBotVAConfig(PreTrainedConfig): # "robotwin_tshape" (full-res head + half-res wrists in a "T"; RoboTwin). camera_layout: str = "width_concat" - # ── Inference hyperparameters (LIBERO defaults) ── + # Inference hyperparameters (LIBERO defaults) n_obs_steps: int = 1 height: int = 128 width: int = 128 @@ -95,8 +95,8 @@ class LingBotVAConfig(PreTrainedConfig): # Opt-in: VAE-decode predicted video latents to ``self.last_predicted_frames`` for saving MP4s. save_predicted_video: bool = False - # ── Normalization: IDENTITY here; images are scaled + VAE-encoded and actions are - # quantile-(un)normalized inside the policy / dedicated processor steps. ── + # Normalization: IDENTITY here; images are scaled + VAE-encoded and actions are + # quantile-(un)normalized inside the policy / dedicated processor steps. normalization_mapping: dict[str, NormalizationMode] = field( default_factory=lambda: { "VISUAL": NormalizationMode.IDENTITY, @@ -105,7 +105,7 @@ class LingBotVAConfig(PreTrainedConfig): } ) - # ── Optimizer / scheduler (training; AdamW + warmup-constant per upstream train.py) ── + # Optimizer / scheduler (training; AdamW + warmup-constant per upstream train.py) optimizer_lr: float = 1e-5 optimizer_betas: tuple[float, float] = (0.9, 0.95) optimizer_eps: float = 1e-8 diff --git a/src/lerobot/policies/lingbot_va/modeling_lingbot_va.py b/src/lerobot/policies/lingbot_va/modeling_lingbot_va.py index dae563f91..b1040283c 100644 --- a/src/lerobot/policies/lingbot_va/modeling_lingbot_va.py +++ b/src/lerobot/policies/lingbot_va/modeling_lingbot_va.py @@ -56,18 +56,8 @@ from lerobot.utils.import_utils import require_package from .configuration_lingbot_va import LingBotVAConfig -# ====================================================================================== -# Vendored Wan2.2 video-action model code (transformer + attention + VAE helpers + -# flow-matching scheduler + grid utilities). Adapted from the upstream LingBot-VA repo -# (https://github.com/Robbyant/lingbot-va, ``wan_va/``). Per LeRobot convention all model -# code for a policy lives in this single ``modeling_*.py`` file. State-dict parameter names -# are preserved verbatim so conversion from the original diffusers-style checkpoint is -# near-identity. The ``torch`` SDPA attention backend is the default and is always -# available; ``flashattn`` and ``flex`` are imported lazily only when selected. -# ====================================================================================== - -# ---- Grid-id / patch utilities (upstream ``wan_va/utils/utils.py``) ------------------ +# Grid-id / patch utilities def data_seq_to_patch(patch_size, data_seq, latent_num_frames, latent_height, latent_width, batch_size=1): """Reshape a flattened patch sequence back into a ``(B, C, F, H, W)`` latent grid.""" p_t, p_h, p_w = patch_size @@ -100,7 +90,7 @@ def get_mesh_id(f, h, w, t, f_w=1, f_shift=0, action=False): return grid_id -# ---- Flow-matching scheduler (upstream ``wan_va/utils/scheduler.py``) ---------------- +# Flow-matching scheduler # LingBot-VA uses two independent instances at inference (one for the video-latent stream, # one for the action stream), each with its own ``shift`` and number of denoising steps. class FlowMatchScheduler: @@ -229,7 +219,7 @@ class FlowMatchScheduler: return mu -# ---- Attention backends (upstream ``wan_va/modules/model.py``) ----------------------- +# Attention backends def custom_sdpa(q, k, v): """Scaled-dot-product attention operating on ``(B, S, H, D)`` tensors.""" out = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)) @@ -624,7 +614,7 @@ class WanAttention(nn.Module): return hidden_states -# ---- Dual-stream Wan2.2 transformer (upstream ``wan_va/modules/model.py``) ------------ +# Dual-stream Wan2.2 transformer class WanTimeTextImageEmbedding(nn.Module): def __init__(self, dim, time_freq_dim, time_proj_dim, text_embed_dim, pos_embed_seq_len): super().__init__() @@ -1054,7 +1044,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin): return latent_hidden_states -# ---- Wan2.2 VAE helpers (stock diffusers ``AutoencoderKLWan``; upstream wan_va/modules/utils.py) ---- +# Wan2.2 VAE helpers (stock diffusers ``AutoencoderKLWan``) def _vae_patchify(x, patch_size): if patch_size is None or patch_size == 1: return x @@ -1176,11 +1166,7 @@ class LingBotVAPolicy(PreTrainedPolicy): rope_max_seq_len=config.rope_max_seq_len, attn_mode=config.attn_mode, ) - # Run the transformer in ``config.dtype`` (bf16 by default): the norm / modulation paths - # upcast to fp32 internally (see ``FP32LayerNorm`` + ``.float()`` in ``WanTransformerBlock``), - # so a uniform low-precision parameter dtype matches the bf16 video/action latents without - # losing numerical stability. Casting here (before ``from_pretrained`` copies the bf16 - # checkpoint in) keeps weights and activations in the same dtype. + # Run the transformer in config.dtype (bf16); norm/modulation paths upcast to fp32 internally. self.transformer = self.transformer.to(self.dtype) # Frozen modules are stored OUTSIDE the nn.Module registry (plain dict) so they are @@ -1430,16 +1416,16 @@ class LingBotVAPolicy(PreTrainedPolicy): """ cfg = self.config device = cfg.device - # ---- text embeddings ---- + # text embeddings task = batch.get("task") if isinstance(task, str): task = [task] text_emb = self._get_t5_prompt_embeds(list(task), cfg.max_sequence_length) - # ---- video latents (VAE-encode the camera clips) ---- + # video latents (VAE-encode the camera clips) latents = self._encode_training_latents(batch) - # ---- actions -> [B, action_dim, F, action_per_frame, 1] ---- + # actions -> [B, action_dim, F, action_per_frame, 1] act = batch[ACTION].to(device) # [B, F*apf, n_used] B = act.shape[0] used = cfg.used_action_channel_ids