docs(lingbot_va): trim provenance comments; default wan path to base repo

- configuration_lingbot_va.py: drop the "──" decorations and the
  "(from transformer/config.json)" note; default wan_pretrained_path to
  robbyant/lingbot-va-base (has the frozen vae/text_encoder/tokenizer subfolders).
- modeling_lingbot_va.py: remove the vendored-code banner and the
  "(upstream wan_va/...)" section-header provenance/dash decorations; condense the
  transformer-dtype comment to one line.

No code changes.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-06-08 11:47:45 +02:00
parent f617b2c2bf
commit 8e692e365c
2 changed files with 17 additions and 31 deletions
@@ -37,7 +37,7 @@ from lerobot.utils.constants import ACTION
class LingBotVAConfig(PreTrainedConfig):
"""Configuration for the native LingBot-VA policy integration in LeRobot."""
# ── Wan transformer architecture (from transformer/config.json) ──
# Wan transformer architecture
patch_size: tuple[int, int, int] = (1, 2, 2)
num_attention_heads: int = 24
attention_head_dim: int = 128
@@ -54,15 +54,15 @@ class LingBotVAConfig(PreTrainedConfig):
# "flex" = training only (needs recent torch); inference uses "torch" SDPA or "flashattn".
attn_mode: str = "torch"
# ── Frozen sub-models (VAE + UMT5 text encoder + tokenizer) ──
# Frozen sub-models (VAE + UMT5 text encoder + tokenizer)
# ~20 GB of frozen weights, NOT bundled in the checkpoint; lazily pulled from this HF repo /
# local dir (must hold diffusers-style ``vae/``, ``text_encoder/``, ``tokenizer/`` sub-folders).
wan_pretrained_path: str = "robbyant/lingbot-va-posttrain-libero-long"
wan_pretrained_path: str = "robbyant/lingbot-va-base"
dtype: str = "bfloat16" # transformer / VAE / text-encoder dtype: "bfloat16", "float16", "float32"
# Frozen UMT5-XXL encoder device; "cpu" frees ~11 GB VRAM (it runs once per episode).
text_encoder_device: str = "cpu"
# ── Observation cameras (order matters: latents are concatenated on width; LIBERO defaults) ──
# Observation cameras (order matters: latents are concatenated on width; LIBERO defaults)
obs_cam_keys: list[str] = field(
default_factory=lambda: ["observation.images.image", "observation.images.image2"]
)
@@ -72,7 +72,7 @@ class LingBotVAConfig(PreTrainedConfig):
# "robotwin_tshape" (full-res head + half-res wrists in a "T"; RoboTwin).
camera_layout: str = "width_concat"
# ── Inference hyperparameters (LIBERO defaults) ──
# Inference hyperparameters (LIBERO defaults)
n_obs_steps: int = 1
height: int = 128
width: int = 128
@@ -95,8 +95,8 @@ class LingBotVAConfig(PreTrainedConfig):
# Opt-in: VAE-decode predicted video latents to ``self.last_predicted_frames`` for saving MP4s.
save_predicted_video: bool = False
# ── Normalization: IDENTITY here; images are scaled + VAE-encoded and actions are
# quantile-(un)normalized inside the policy / dedicated processor steps. ──
# Normalization: IDENTITY here; images are scaled + VAE-encoded and actions are
# quantile-(un)normalized inside the policy / dedicated processor steps.
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
@@ -105,7 +105,7 @@ class LingBotVAConfig(PreTrainedConfig):
}
)
# ── Optimizer / scheduler (training; AdamW + warmup-constant per upstream train.py) ──
# Optimizer / scheduler (training; AdamW + warmup-constant per upstream train.py)
optimizer_lr: float = 1e-5
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
@@ -56,18 +56,8 @@ from lerobot.utils.import_utils import require_package
from .configuration_lingbot_va import LingBotVAConfig
# ======================================================================================
# Vendored Wan2.2 video-action model code (transformer + attention + VAE helpers +
# flow-matching scheduler + grid utilities). Adapted from the upstream LingBot-VA repo
# (https://github.com/Robbyant/lingbot-va, ``wan_va/``). Per LeRobot convention all model
# code for a policy lives in this single ``modeling_*.py`` file. State-dict parameter names
# are preserved verbatim so conversion from the original diffusers-style checkpoint is
# near-identity. The ``torch`` SDPA attention backend is the default and is always
# available; ``flashattn`` and ``flex`` are imported lazily only when selected.
# ======================================================================================
# ---- Grid-id / patch utilities (upstream ``wan_va/utils/utils.py``) ------------------
# Grid-id / patch utilities
def data_seq_to_patch(patch_size, data_seq, latent_num_frames, latent_height, latent_width, batch_size=1):
"""Reshape a flattened patch sequence back into a ``(B, C, F, H, W)`` latent grid."""
p_t, p_h, p_w = patch_size
@@ -100,7 +90,7 @@ def get_mesh_id(f, h, w, t, f_w=1, f_shift=0, action=False):
return grid_id
# ---- Flow-matching scheduler (upstream ``wan_va/utils/scheduler.py``) ----------------
# Flow-matching scheduler
# LingBot-VA uses two independent instances at inference (one for the video-latent stream,
# one for the action stream), each with its own ``shift`` and number of denoising steps.
class FlowMatchScheduler:
@@ -229,7 +219,7 @@ class FlowMatchScheduler:
return mu
# ---- Attention backends (upstream ``wan_va/modules/model.py``) -----------------------
# Attention backends
def custom_sdpa(q, k, v):
"""Scaled-dot-product attention operating on ``(B, S, H, D)`` tensors."""
out = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
@@ -624,7 +614,7 @@ class WanAttention(nn.Module):
return hidden_states
# ---- Dual-stream Wan2.2 transformer (upstream ``wan_va/modules/model.py``) ------------
# Dual-stream Wan2.2 transformer
class WanTimeTextImageEmbedding(nn.Module):
def __init__(self, dim, time_freq_dim, time_proj_dim, text_embed_dim, pos_embed_seq_len):
super().__init__()
@@ -1054,7 +1044,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin):
return latent_hidden_states
# ---- Wan2.2 VAE helpers (stock diffusers ``AutoencoderKLWan``; upstream wan_va/modules/utils.py) ----
# Wan2.2 VAE helpers (stock diffusers ``AutoencoderKLWan``)
def _vae_patchify(x, patch_size):
if patch_size is None or patch_size == 1:
return x
@@ -1176,11 +1166,7 @@ class LingBotVAPolicy(PreTrainedPolicy):
rope_max_seq_len=config.rope_max_seq_len,
attn_mode=config.attn_mode,
)
# Run the transformer in ``config.dtype`` (bf16 by default): the norm / modulation paths
# upcast to fp32 internally (see ``FP32LayerNorm`` + ``.float()`` in ``WanTransformerBlock``),
# so a uniform low-precision parameter dtype matches the bf16 video/action latents without
# losing numerical stability. Casting here (before ``from_pretrained`` copies the bf16
# checkpoint in) keeps weights and activations in the same dtype.
# Run the transformer in config.dtype (bf16); norm/modulation paths upcast to fp32 internally.
self.transformer = self.transformer.to(self.dtype)
# Frozen modules are stored OUTSIDE the nn.Module registry (plain dict) so they are
@@ -1430,16 +1416,16 @@ class LingBotVAPolicy(PreTrainedPolicy):
"""
cfg = self.config
device = cfg.device
# ---- text embeddings ----
# text embeddings
task = batch.get("task")
if isinstance(task, str):
task = [task]
text_emb = self._get_t5_prompt_embeds(list(task), cfg.max_sequence_length)
# ---- video latents (VAE-encode the camera clips) ----
# video latents (VAE-encode the camera clips)
latents = self._encode_training_latents(batch)
# ---- actions -> [B, action_dim, F, action_per_frame, 1] ----
# actions -> [B, action_dim, F, action_per_frame, 1]
act = batch[ACTION].to(device) # [B, F*apf, n_used]
B = act.shape[0]
used = cfg.used_action_channel_ids