mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
docs(lingbot_va): trim provenance comments; default wan path to base repo
- configuration_lingbot_va.py: drop the "──" decorations and the "(from transformer/config.json)" note; default wan_pretrained_path to robbyant/lingbot-va-base (has the frozen vae/text_encoder/tokenizer subfolders).
- modeling_lingbot_va.py: remove the vendored-code banner and the "(upstream wan_va/...)" section-header provenance/dash decorations; condense the transformer-dtype comment to one line.

No logic changes; the only runtime-visible difference is the new wan_pretrained_path default.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -37,7 +37,7 @@ from lerobot.utils.constants import ACTION
|
||||
class LingBotVAConfig(PreTrainedConfig):
|
||||
"""Configuration for the native LingBot-VA policy integration in LeRobot."""
|
||||
|
||||
# ── Wan transformer architecture (from transformer/config.json) ──
|
||||
# Wan transformer architecture
|
||||
patch_size: tuple[int, int, int] = (1, 2, 2)
|
||||
num_attention_heads: int = 24
|
||||
attention_head_dim: int = 128
|
||||
@@ -54,15 +54,15 @@ class LingBotVAConfig(PreTrainedConfig):
|
||||
# "flex" = training only (needs recent torch); inference uses "torch" SDPA or "flashattn".
|
||||
attn_mode: str = "torch"
|
||||
|
||||
# ── Frozen sub-models (VAE + UMT5 text encoder + tokenizer) ──
|
||||
# Frozen sub-models (VAE + UMT5 text encoder + tokenizer)
|
||||
# ~20 GB of frozen weights, NOT bundled in the checkpoint; lazily pulled from this HF repo /
|
||||
# local dir (must hold diffusers-style ``vae/``, ``text_encoder/``, ``tokenizer/`` sub-folders).
|
||||
wan_pretrained_path: str = "robbyant/lingbot-va-posttrain-libero-long"
|
||||
wan_pretrained_path: str = "robbyant/lingbot-va-base"
|
||||
dtype: str = "bfloat16" # transformer / VAE / text-encoder dtype: "bfloat16", "float16", "float32"
|
||||
# Frozen UMT5-XXL encoder device; "cpu" frees ~11 GB VRAM (it runs once per episode).
|
||||
text_encoder_device: str = "cpu"
|
||||
|
||||
# ── Observation cameras (order matters: latents are concatenated on width; LIBERO defaults) ──
|
||||
# Observation cameras (order matters: latents are concatenated on width; LIBERO defaults)
|
||||
obs_cam_keys: list[str] = field(
|
||||
default_factory=lambda: ["observation.images.image", "observation.images.image2"]
|
||||
)
|
||||
@@ -72,7 +72,7 @@ class LingBotVAConfig(PreTrainedConfig):
|
||||
# "robotwin_tshape" (full-res head + half-res wrists in a "T"; RoboTwin).
|
||||
camera_layout: str = "width_concat"
|
||||
|
||||
# ── Inference hyperparameters (LIBERO defaults) ──
|
||||
# Inference hyperparameters (LIBERO defaults)
|
||||
n_obs_steps: int = 1
|
||||
height: int = 128
|
||||
width: int = 128
|
||||
@@ -95,8 +95,8 @@ class LingBotVAConfig(PreTrainedConfig):
|
||||
# Opt-in: VAE-decode predicted video latents to ``self.last_predicted_frames`` for saving MP4s.
|
||||
save_predicted_video: bool = False
|
||||
|
||||
# ── Normalization: IDENTITY here; images are scaled + VAE-encoded and actions are
|
||||
# quantile-(un)normalized inside the policy / dedicated processor steps. ──
|
||||
# Normalization: IDENTITY here; images are scaled + VAE-encoded and actions are
|
||||
# quantile-(un)normalized inside the policy / dedicated processor steps.
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
@@ -105,7 +105,7 @@ class LingBotVAConfig(PreTrainedConfig):
|
||||
}
|
||||
)
|
||||
|
||||
# ── Optimizer / scheduler (training; AdamW + warmup-constant per upstream train.py) ──
|
||||
# Optimizer / scheduler (training; AdamW + warmup-constant per upstream train.py)
|
||||
optimizer_lr: float = 1e-5
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
|
||||
@@ -56,18 +56,8 @@ from lerobot.utils.import_utils import require_package
|
||||
|
||||
from .configuration_lingbot_va import LingBotVAConfig
|
||||
|
||||
# ======================================================================================
|
||||
# Vendored Wan2.2 video-action model code (transformer + attention + VAE helpers +
|
||||
# flow-matching scheduler + grid utilities). Adapted from the upstream LingBot-VA repo
|
||||
# (https://github.com/Robbyant/lingbot-va, ``wan_va/``). Per LeRobot convention all model
|
||||
# code for a policy lives in this single ``modeling_*.py`` file. State-dict parameter names
|
||||
# are preserved verbatim so conversion from the original diffusers-style checkpoint is
|
||||
# near-identity. The ``torch`` SDPA attention backend is the default and is always
|
||||
# available; ``flashattn`` and ``flex`` are imported lazily only when selected.
|
||||
# ======================================================================================
|
||||
|
||||
|
||||
# ---- Grid-id / patch utilities (upstream ``wan_va/utils/utils.py``) ------------------
|
||||
# Grid-id / patch utilities
|
||||
def data_seq_to_patch(patch_size, data_seq, latent_num_frames, latent_height, latent_width, batch_size=1):
|
||||
"""Reshape a flattened patch sequence back into a ``(B, C, F, H, W)`` latent grid."""
|
||||
p_t, p_h, p_w = patch_size
|
||||
@@ -100,7 +90,7 @@ def get_mesh_id(f, h, w, t, f_w=1, f_shift=0, action=False):
|
||||
return grid_id
|
||||
|
||||
|
||||
# ---- Flow-matching scheduler (upstream ``wan_va/utils/scheduler.py``) ----------------
|
||||
# Flow-matching scheduler
|
||||
# LingBot-VA uses two independent instances at inference (one for the video-latent stream,
|
||||
# one for the action stream), each with its own ``shift`` and number of denoising steps.
|
||||
class FlowMatchScheduler:
|
||||
@@ -229,7 +219,7 @@ class FlowMatchScheduler:
|
||||
return mu
|
||||
|
||||
|
||||
# ---- Attention backends (upstream ``wan_va/modules/model.py``) -----------------------
|
||||
# Attention backends
|
||||
def custom_sdpa(q, k, v):
|
||||
"""Scaled-dot-product attention operating on ``(B, S, H, D)`` tensors."""
|
||||
out = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
|
||||
@@ -624,7 +614,7 @@ class WanAttention(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# ---- Dual-stream Wan2.2 transformer (upstream ``wan_va/modules/model.py``) ------------
|
||||
# Dual-stream Wan2.2 transformer
|
||||
class WanTimeTextImageEmbedding(nn.Module):
|
||||
def __init__(self, dim, time_freq_dim, time_proj_dim, text_embed_dim, pos_embed_seq_len):
|
||||
super().__init__()
|
||||
@@ -1054,7 +1044,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
return latent_hidden_states
|
||||
|
||||
|
||||
# ---- Wan2.2 VAE helpers (stock diffusers ``AutoencoderKLWan``; upstream wan_va/modules/utils.py) ----
|
||||
# Wan2.2 VAE helpers (stock diffusers ``AutoencoderKLWan``)
|
||||
def _vae_patchify(x, patch_size):
|
||||
if patch_size is None or patch_size == 1:
|
||||
return x
|
||||
@@ -1176,11 +1166,7 @@ class LingBotVAPolicy(PreTrainedPolicy):
|
||||
rope_max_seq_len=config.rope_max_seq_len,
|
||||
attn_mode=config.attn_mode,
|
||||
)
|
||||
# Run the transformer in ``config.dtype`` (bf16 by default): the norm / modulation paths
|
||||
# upcast to fp32 internally (see ``FP32LayerNorm`` + ``.float()`` in ``WanTransformerBlock``),
|
||||
# so a uniform low-precision parameter dtype matches the bf16 video/action latents without
|
||||
# losing numerical stability. Casting here (before ``from_pretrained`` copies the bf16
|
||||
# checkpoint in) keeps weights and activations in the same dtype.
|
||||
# Run the transformer in config.dtype (bf16 by default); norm/modulation paths upcast to fp32 internally.
|
||||
self.transformer = self.transformer.to(self.dtype)
|
||||
|
||||
# Frozen modules are stored OUTSIDE the nn.Module registry (plain dict) so they are
|
||||
@@ -1430,16 +1416,16 @@ class LingBotVAPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
cfg = self.config
|
||||
device = cfg.device
|
||||
# ---- text embeddings ----
|
||||
# text embeddings
|
||||
task = batch.get("task")
|
||||
if isinstance(task, str):
|
||||
task = [task]
|
||||
text_emb = self._get_t5_prompt_embeds(list(task), cfg.max_sequence_length)
|
||||
|
||||
# ---- video latents (VAE-encode the camera clips) ----
|
||||
# video latents (VAE-encode the camera clips)
|
||||
latents = self._encode_training_latents(batch)
|
||||
|
||||
# ---- actions -> [B, action_dim, F, action_per_frame, 1] ----
|
||||
# actions -> [B, action_dim, F, action_per_frame, 1]
|
||||
act = batch[ACTION].to(device) # [B, F*apf, n_used]
|
||||
B = act.shape[0]
|
||||
used = cfg.used_action_channel_ids
|
||||
|
||||
Reference in New Issue
Block a user