feat(policies): add LingBot-VA autoregressive video-action world model

Port the LingBot-VA policy (Wan2.2 dual-stream video+action world model) into
LeRobot, following the EO-1 / VLA-JEPA conventions. Covers inference, checkpoint
conversion, and predicted-video saving (training is deferred to a follow-up PR).

- Vendored Wan transformer/attention/flex/VAE/scheduler modules (key names preserved
  for near-identity conversion); torch SDPA default, flashattn/flex lazy-guarded.
- LingBotVAConfig (registered "lingbot_va") + processor with fixed-quantile action
  unnormalization; full dual-stream sampling loop with CFG, two flow-matching
  schedulers and KV cache, mapped onto select_action with observed-keyframe feedback.
- convert_lingbot_va_checkpoints.py (libero/robotwin variants): bundles the ~5B
  transformer, lazy-pulls the frozen VAE+UMT5 from the source repo.
- Predicted-video plumbing in lerobot_eval (predicted_frames_callback; opt-in via
  --policy.save_predicted_video) and ConstantWithWarmupSchedulerConfig.
- pyproject: widen diffusers-dep to <0.37, add lingbot_va + imageio-dep extras,
  add lingbot_va and (missing) eo1 to `all`.
- Factory + policies/__init__ wiring, docs page + toctree, and tests.

Note: the LIBERO success-rate correctness gate must be validated on a CUDA GPU
with the converted checkpoint.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-06-05 16:28:19 +02:00
parent 2e9cd87bbd
commit 4dfa8cea65
23 changed files with 3031 additions and 1 deletions
+2
View File
@@ -67,6 +67,8 @@
title: VLA-JEPA
- local: eo1
title: EO-1
- local: lingbot_va
title: LingBot-VA
- local: groot
title: NVIDIA GR00T N1.5
- local: xvla
+120
View File
@@ -0,0 +1,120 @@
# LingBot-VA
LingBot-VA is an **autoregressive video-action world-model policy** built on the **Wan2.2**
video-diffusion stack. It interleaves, in one autoregressive sequence, the prediction of
future **video latents** and **robot actions** ("VA" = Video-Action). The LeRobot
integration wires LingBot-VA into the standard training, evaluation and processor
interfaces.
## Model Overview
LingBot-VA is a **dual-stream "mixture-of-transformers"**: a video/latent stream
(`patch_embedding_mlp → blocks → proj_out`) and an action stream
(`action_embedder → blocks → action_proj_out`) share the same 30 transformer blocks and
text conditioning. Actions are produced by the dedicated `action_proj_out` head — they are
**not** decoded from predicted pixels, though video and action are co-trained.
| Component | Class | Role |
|---|---|---|
| DiT backbone (trainable) | `WanTransformer3DModel` | ~5B-param dual-stream transformer (the only weights stored in the LeRobot checkpoint). |
| VAE (frozen) | `AutoencoderKLWan` | Wan2.2 VAE, `z_dim=48`. Lazy-pulled from the source repo. |
| Text encoder (frozen) | `UMT5EncoderModel` | UMT5-XXL, `d_model=4096`. Lazy-pulled from the source repo. |
At inference the policy runs an autoregressive loop per chunk: it denoises the video-latent
stream (CFG, ~20 steps) and the action stream (~50 steps) with two independent
flow-matching schedulers, maintaining a KV cache across chunks. Real observed keyframes are
fed back into the KV cache as the chunk is executed (closed-loop world modeling).
### What the LeRobot Integration Covers
- Standard `policy.type=lingbot_va` configuration through LeRobot.
- Checkpoint conversion from the released HuggingFace checkpoints.
- Autoregressive dual-stream inference behind the standard `select_action` interface
(single-environment eval, `--eval.batch_size=1`).
- Opt-in saving of the policy's **predicted (imagined) videos** during eval / training.
- Evaluation with `lerobot-eval` on the LIBERO benchmark.
Training (the flow-matching dual-stream loss + latent dataset) is part of a follow-up
training port and is not yet wired into `lerobot-train`.
## Installation
1. Install LeRobot by following the [Installation Guide](./installation).
2. Install the LingBot-VA extra (brings in `diffusers>=0.36` for the Wan2.2 stack):
```bash
pip install -e ".[lingbot_va]"
# For LIBERO evaluation (Linux only):
pip install -e ".[lingbot_va,libero]"
```
## Checkpoint Conversion
The released checkpoints are diffusers-style directories
(`robbyant/lingbot-va-base`, `robbyant/lingbot-va-posttrain-robotwin`,
`robbyant/lingbot-va-posttrain-libero-long`). Convert one to LeRobot format with:
```bash
python -m lerobot.policies.lingbot_va.convert_lingbot_va_checkpoints \
--checkpoint robbyant/lingbot-va-posttrain-libero-long \
--variant libero \
--output_dir outputs/lingbot_va_libero_long
```
**Packaging:** only the trainable ~5B transformer is stored in the LeRobot
`model.safetensors`. The frozen VAE + UMT5 + tokenizer (~20 GB) are **lazily pulled** from
`config.wan_pretrained_path` at load time (defaults to the source repo). Pass
`--bundle-frozen` to copy those sub-folders next to the converted checkpoint instead.
Run conversion on a Linux machine with a CUDA GPU and enough RAM/VRAM to materialize the
transformer.
## Evaluation (LIBERO)
```bash
lerobot-eval \
--policy.path=outputs/lingbot_va_libero_long \
--env.type=libero --env.task=libero_10 \
--eval.n_episodes=50 --eval.batch_size=1 \
--output_dir=outputs/eval/lingbot_va_libero
```
LingBot-VA's streaming inference (KV cache + observed-keyframe feedback) is implemented for
single-environment eval; use `--eval.batch_size=1`.
### Saving predicted (imagined) videos
Set `--policy.save_predicted_video=true` to additionally VAE-decode the predicted video
latents and write `pred_episode_*.mp4` next to the env-rendered `eval_episode_*.mp4` videos.
The same flag works for the periodic eval during `lerobot-train`.
## Inference Hyperparameters (LIBERO)
| Key | Value |
|---|---|
| height × width | 128 × 128 |
| cameras | `observation.images.image` (agentview), `observation.images.image2` (eye-in-hand) |
| action channels used | 06 (7-DoF arm + gripper) |
| action_per_frame / frame_chunk_size | 4 / 4 |
| attn_window | 30 |
| video / action denoising steps | 20 / 50 |
| guidance_scale / action_guidance_scale | 5 / 1 |
| snr_shift / action_snr_shift | 5.0 / 0.05 |
These are the defaults of `LingBotVAConfig`; override any of them via `--policy.<name>=...`.
## Notes & Limitations
- **Correctness gate:** matching the upstream LIBERO success rate requires validating the
converted checkpoint on a GPU and tensor-diffing intermediate activations against the
upstream implementation. The most sensitive parts are the action quantile normalization,
the camera ordering, the `action_per_frame`/`frame_chunk_size` alignment, and `attn_mode`.
- **Attention backend:** inference uses the `torch` SDPA backend (always available). The
`flashattn` and `flex` backends are optional; `flex` is only needed for training.
- **Model size:** the DiT is ~5B params and the frozen VAE+UMT5 add ~20 GB; inference needs
roughly 1824 GB of VRAM.
## License
LingBot-VA is released under Apache-2.0. See the
[upstream repository](https://github.com/Robbyant/lingbot-va).
+11 -1
View File
@@ -146,7 +146,8 @@ grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
can-dep = ["python-can>=4.2.0,<5.0.0"]
peft-dep = ["peft>=0.18.0,<1.0.0"]
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
diffusers-dep = ["diffusers>=0.27.2,<0.36.0"]
diffusers-dep = ["diffusers>=0.27.2,<0.37.0"]
imageio-dep = ["imageio[ffmpeg]>=2.34.0,<3.0.0"]
qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"]
matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0", "contourpy>=1.3.0,<2.0.0"] # NOTE: Explicitly listing contourpy helps the resolver converge faster.
pyserial-dep = ["pyserial>=3.5,<4.0"]
@@ -218,6 +219,10 @@ xvla = ["lerobot[transformers-dep]"]
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
# LingBot-VA needs the Wan2.2 stack (AutoencoderKLWan z_dim=48 + WanTransformer3DModel config schema),
# which only exists in diffusers>=0.36. Pin the floor explicitly so a standalone `lerobot[lingbot_va]`
# install can't resolve to a pre-Wan2.2 diffusers via the looser diffusers-dep floor.
lingbot_va = ["lerobot[transformers-dep]", "diffusers>=0.36.0,<0.37.0", "lerobot[imageio-dep]"]
# Features
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
@@ -284,6 +289,8 @@ all = [
"lerobot[xvla]",
"lerobot[hilserl]",
"lerobot[vla_jepa]",
"lerobot[eo1]",
"lerobot[lingbot_va]",
"lerobot[async]",
"lerobot[dev]",
"lerobot[test]",
@@ -375,6 +382,9 @@ ignore = [
# E402: conditional-import guards (TYPE_CHECKING / is_package_available) must precede the imports they protect
"src/lerobot/scripts/convert_dataset_v21_to_v30.py" = ["E402"]
"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original
# Vendored Wan2.2 / LingBot-VA model code uses tensor-dimension names (B, F, H, W) and `F` for
# torch.nn.functional; keep the upstream naming to make diffing against upstream tractable.
"src/lerobot/policies/lingbot_va/**" = ["N803", "N806", "N812", "SIM102"]
[tool.ruff.lint.isort]
combine-as-imports = true
+22
View File
@@ -83,6 +83,28 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
return LambdaLR(optimizer, lr_lambda, -1)
@LRSchedulerConfig.register_subclass("constant_with_warmup")
@dataclass
class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig):
"""Linear warmup followed by a constant learning rate.
Mirrors the ``warmup_constant_lambda`` used by LingBot-VA (upstream ``wan_va/train.py``):
the LR ramps linearly from 0 to the peak over ``num_warmup_steps`` steps, then stays flat.
"""
num_warmup_steps: int = 1000
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
warmup_steps = self.num_warmup_steps or 0
def lr_lambda(current_step):
if current_step < warmup_steps:
return float(current_step) / float(max(1, warmup_steps))
return 1.0
return LambdaLR(optimizer, lr_lambda, -1)
@LRSchedulerConfig.register_subclass("cosine_decay_with_warmup")
@dataclass
class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
+2
View File
@@ -20,6 +20,7 @@ from .eo1.configuration_eo1 import EO1Config as EO1Config
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
from .groot.configuration_groot import GrootConfig as GrootConfig
from .lingbot_va.configuration_lingbot_va import LingBotVAConfig as LingBotVAConfig
from .molmoact2.configuration_molmoact2 import MolmoAct2Config as MolmoAct2Config
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
@@ -44,6 +45,7 @@ __all__ = [
"EO1Config",
"GaussianActorConfig",
"GrootConfig",
"LingBotVAConfig",
"MolmoAct2Config",
"MultiTaskDiTConfig",
"PI0Config",
+15
View File
@@ -49,6 +49,7 @@ from .diffusion.configuration_diffusion import DiffusionConfig
from .eo1.configuration_eo1 import EO1Config
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from .groot.configuration_groot import GrootConfig
from .lingbot_va.configuration_lingbot_va import LingBotVAConfig
from .molmoact2.configuration_molmoact2 import MolmoAct2Config
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config
@@ -162,6 +163,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
return VLAJEPAPolicy
elif name == "lingbot_va":
from .lingbot_va.modeling_lingbot_va import LingBotVAPolicy
return LingBotVAPolicy
else:
try:
return _get_policy_cls_from_policy_name(name=name)
@@ -218,6 +223,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return MolmoAct2Config(**kwargs)
elif policy_type == "vla_jepa":
return VLAJEPAConfig(**kwargs)
elif policy_type == "lingbot_va":
return LingBotVAConfig(**kwargs)
else:
try:
config_cls = PreTrainedConfig.get_choice_class(policy_type)
@@ -448,6 +455,14 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, LingBotVAConfig):
from .lingbot_va.processor_lingbot_va import make_lingbot_va_pre_post_processors
processors = make_lingbot_va_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
try:
processors = _make_processors_from_policy_config(
+1
View File
@@ -0,0 +1 @@
../../../../docs/source/lingbot_va.mdx
@@ -0,0 +1,33 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: ``LingBotVAPolicy`` (and the Wan transformer it owns) imports ``diffusers`` as a
# hard dependency at class-definition time (it subclasses diffusers' ModelMixin/ConfigMixin).
# To keep base ``import lerobot`` working without the optional ``lingbot_va`` extra, the
# policy is exposed lazily via module ``__getattr__`` — the heavy import only happens when
# ``LingBotVAPolicy`` is actually accessed (mirroring the lazy import in policies/factory.py).
from .configuration_lingbot_va import LingBotVAConfig
from .processor_lingbot_va import make_lingbot_va_pre_post_processors
__all__ = ["LingBotVAConfig", "LingBotVAPolicy", "make_lingbot_va_pre_post_processors"]
def __getattr__(name):
if name == "LingBotVAPolicy":
from .modeling_lingbot_va import LingBotVAPolicy
return LingBotVAPolicy
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
@@ -0,0 +1,198 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration for the LingBot-VA policy.
LingBot-VA is an autoregressive video-action world-model policy built on the Wan2.2
video-diffusion stack. It interleaves prediction of future video latents and robot
actions in a single dual-stream transformer. See ``docs/source/lingbot_va.mdx`` and the
upstream repository (https://github.com/Robbyant/lingbot-va).
Defaults below match the upstream LIBERO configuration (``wan_va/configs/va_libero_cfg.py``)
and the ``transformer/config.json`` of the released checkpoints.
"""
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import LRSchedulerConfig
from lerobot.utils.constants import ACTION
# Upstream LIBERO action-normalization quantiles (single 7-DoF arm + gripper).
# Verbatim from wan_va/configs/va_libero_cfg.py (channels 0-6 of a 30-dim action space).
LIBERO_ACTION_Q01 = [
-0.6589285731315613,
-0.84375,
-0.9375,
-0.12107142806053162,
-0.15964286029338837,
-0.26571428775787354,
-1.0,
]
LIBERO_ACTION_Q99 = [
0.8999999761581421,
0.8544642925262451,
0.9375,
0.17142857611179352,
0.1842857152223587,
0.34392857551574707,
1.0,
]
@PreTrainedConfig.register_subclass("lingbot_va")
@dataclass
class LingBotVAConfig(PreTrainedConfig):
"""Configuration for the native LingBot-VA policy integration in LeRobot."""
# ── Wan transformer architecture (from transformer/config.json) ──
patch_size: tuple[int, int, int] = (1, 2, 2)
num_attention_heads: int = 24
attention_head_dim: int = 128
in_channels: int = 48
out_channels: int = 48
action_dim: int = 30
text_dim: int = 4096
freq_dim: int = 256
ffn_dim: int = 14336
num_layers: int = 30
cross_attn_norm: bool = True
eps: float = 1e-6
rope_max_seq_len: int = 1024
# "flex" is supported for training only and needs a recent torch build. Inference uses
# "torch" SDPA (always available) or, optionally, "flashattn".
attn_mode: str = "torch"
# ── Frozen sub-models (VAE + UMT5 text encoder + tokenizer) ──
# These heavy frozen weights (~20 GB) are NOT bundled into the LeRobot safetensors
# checkpoint (only the trainable ~5B transformer is). They are lazily pulled from this
# HF repo / local directory at policy-init time. The directory must contain the
# diffusers-style ``vae/``, ``text_encoder/`` and ``tokenizer/`` sub-folders.
wan_pretrained_path: str = "robbyant/lingbot-va-posttrain-libero-long"
# dtype used for the transformer / VAE / text-encoder weights at inference.
dtype: str = "bfloat16" # one of "bfloat16", "float16", "float32"
# ── Observation cameras (order matters: latents are concatenated on width) ──
# Defaults match the LIBERO env feature keys (agentview -> image, eye-in-hand -> image2).
obs_cam_keys: list[str] = field(
default_factory=lambda: ["observation.images.image", "observation.images.image2"]
)
# ── Inference hyperparameters (LIBERO defaults) ──
n_obs_steps: int = 1
height: int = 128
width: int = 128
action_per_frame: int = 4
frame_chunk_size: int = 4
attn_window: int = 30
num_inference_steps: int = 20
video_exec_step: int = -1
action_num_inference_steps: int = 50
guidance_scale: float = 5.0
action_guidance_scale: float = 1.0
snr_shift: float = 5.0
action_snr_shift: float = 0.05
max_sequence_length: int = 512 # UMT5 prompt length
# Subset of the 30-d action space actually used by the benchmark (LIBERO = 7-DoF).
used_action_channel_ids: list[int] = field(default_factory=lambda: list(range(7)))
# Fixed quantiles for action (un)normalization on the *used* channels.
action_q01: list[float] = field(default_factory=lambda: list(LIBERO_ACTION_Q01))
action_q99: list[float] = field(default_factory=lambda: list(LIBERO_ACTION_Q99))
# Opt-in: VAE-decode the predicted video latents and stash them on
# ``self.last_predicted_frames`` so eval/train can save predicted-video MP4s.
save_predicted_video: bool = False
# ── Normalization (handled internally / via custom steps, hence IDENTITY here) ──
# Images are scaled to [-1, 1] and VAE-encoded inside the policy; actions are
# quantile-(un)normalized by dedicated processor steps using the fixed quantiles above.
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
)
# ── Optimizer / scheduler (training; AdamW + warmup-constant per upstream train.py) ──
optimizer_lr: float = 1e-5
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-4
optimizer_grad_clip_norm: float = 1.0
scheduler_warmup_steps: int = 1000
def __post_init__(self):
super().__post_init__()
if self.attn_mode not in ("torch", "flashattn", "flex"):
raise ValueError(f"attn_mode must be one of 'torch', 'flashattn', 'flex'; got {self.attn_mode!r}")
if len(self.action_q01) != len(self.used_action_channel_ids) or len(self.action_q99) != len(
self.used_action_channel_ids
):
raise ValueError(
"action_q01 / action_q99 must each have one entry per used_action_channel_ids "
f"({len(self.used_action_channel_ids)}); got {len(self.action_q01)} / {len(self.action_q99)}."
)
@property
def chunk_size(self) -> int:
"""Number of single-step actions produced per autoregressive chunk."""
return self.frame_chunk_size * self.action_per_frame
@property
def n_action_steps(self) -> int:
"""Number of actions executed before refilling (the whole chunk)."""
return self.chunk_size
def validate_features(self) -> None:
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
if not image_features:
raise ValueError(
"LingBot-VA requires at least one visual input feature. "
"No features of type FeatureType.VISUAL found in input_features."
)
if ACTION not in self.output_features:
self.output_features[ACTION] = PolicyFeature(
type=FeatureType.ACTION, shape=(len(self.used_action_channel_ids),)
)
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
# Upstream uses a linear warmup followed by a constant LR (warmup_constant_lambda).
from lerobot.optim.schedulers import ConstantWithWarmupSchedulerConfig
return ConstantWithWarmupSchedulerConfig(num_warmup_steps=self.scheduler_warmup_steps)
@property
def observation_delta_indices(self) -> None:
return None
@property
def action_delta_indices(self) -> list[int]:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None
@@ -0,0 +1,256 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert a released LingBot-VA HuggingFace checkpoint to LeRobot format.
The released checkpoints are diffusers-style directories with ``transformer/``, ``vae/``,
``text_encoder/`` and ``tokenizer/`` sub-folders. This script:
1. loads the (sharded) ``transformer/`` weights with the vendored ``WanTransformer3DModel``;
2. builds a :class:`LingBotVAConfig` for the target benchmark variant;
3. instantiates a :class:`LingBotVAPolicy` and copies the transformer weights into it
(near-identity: the only key change is the ``transformer.`` prefix);
4. saves the LeRobot policy (``model.safetensors`` + ``config.json``) and its processors.
Packaging decision: only the trainable ~5B transformer is bundled into the LeRobot
``model.safetensors``. The frozen VAE + UMT5 text encoder + tokenizer (~20 GB) are NOT
copied; instead ``config.wan_pretrained_path`` records where to lazily pull them from at
load time (defaults to the source repo/dir). Pass ``--bundle-frozen`` to additionally copy
those sub-folders next to the converted checkpoint and point ``wan_pretrained_path`` at it.
Example (LIBERO-Long, the LIBERO eval gate):
python -m lerobot.policies.lingbot_va.convert_lingbot_va_checkpoints \
--checkpoint robbyant/lingbot-va-posttrain-libero-long \
--variant libero \
--output_dir outputs/lingbot_va_libero_long
Requires a CUDA GPU with enough RAM/VRAM to materialize the transformer; run on Linux.
"""
import argparse
import shutil
from pathlib import Path
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
from lerobot.policies.lingbot_va.modeling_lingbot_va import LingBotVAPolicy
from lerobot.policies.lingbot_va.processor_lingbot_va import make_lingbot_va_pre_post_processors
from lerobot.policies.lingbot_va.wan_transformer import WanTransformer3DModel
from lerobot.utils.constants import ACTION, OBS_IMAGES
# Per-benchmark variant presets (camera keys + action layout). Values mirror the upstream
# configs (wan_va/configs/va_*_cfg.py).
VARIANTS = {
"libero": {
"obs_cam_keys": [f"{OBS_IMAGES}.image", f"{OBS_IMAGES}.image2"],
"height": 128,
"width": 128,
"action_per_frame": 4,
"frame_chunk_size": 4,
"attn_window": 30,
"num_inference_steps": 20,
"action_num_inference_steps": 50,
"guidance_scale": 5.0,
"action_guidance_scale": 1.0,
"snr_shift": 5.0,
"action_snr_shift": 0.05,
"used_action_channel_ids": list(range(7)),
# 7-DoF: agentview + eye-in-hand, single arm. Quantiles are the config defaults.
"image_shape": (3, 256, 256),
},
"robotwin": {
"obs_cam_keys": [
f"{OBS_IMAGES}.cam_high",
f"{OBS_IMAGES}.cam_left_wrist",
f"{OBS_IMAGES}.cam_right_wrist",
],
"height": 256,
"width": 320,
"action_per_frame": 16,
"frame_chunk_size": 2,
"attn_window": 72,
"num_inference_steps": 25,
"action_num_inference_steps": 50,
"guidance_scale": 5.0,
"action_guidance_scale": 1.0,
"snr_shift": 5.0,
"action_snr_shift": 1.0,
# RoboTwin is dual-arm; set the used channels / quantiles to match the deployed config.
"used_action_channel_ids": list(range(14)),
"image_shape": (3, 256, 256),
},
}
def _transformer_dir(checkpoint: str) -> str:
"""Return the path/repo that ``WanTransformer3DModel.from_pretrained`` should read."""
p = Path(checkpoint)
if p.is_dir():
return str(p / "transformer")
return checkpoint # HF repo id; use subfolder kwarg below
def load_source_transformer(checkpoint: str, dtype: torch.dtype) -> WanTransformer3DModel:
p = Path(checkpoint)
if p.is_dir():
return WanTransformer3DModel.from_pretrained(
str(p / "transformer"), torch_dtype=dtype, attn_mode="torch"
)
return WanTransformer3DModel.from_pretrained(
checkpoint, subfolder="transformer", torch_dtype=dtype, attn_mode="torch"
)
def build_config(variant: str, wan_pretrained_path: str, dtype: str) -> LingBotVAConfig:
preset = VARIANTS[variant]
n_used = len(preset["used_action_channel_ids"])
kwargs = {
"wan_pretrained_path": wan_pretrained_path,
"dtype": dtype,
"obs_cam_keys": preset["obs_cam_keys"],
"height": preset["height"],
"width": preset["width"],
"action_per_frame": preset["action_per_frame"],
"frame_chunk_size": preset["frame_chunk_size"],
"attn_window": preset["attn_window"],
"num_inference_steps": preset["num_inference_steps"],
"action_num_inference_steps": preset["action_num_inference_steps"],
"guidance_scale": preset["guidance_scale"],
"action_guidance_scale": preset["action_guidance_scale"],
"snr_shift": preset["snr_shift"],
"action_snr_shift": preset["action_snr_shift"],
"used_action_channel_ids": preset["used_action_channel_ids"],
"device": "cpu",
}
if variant != "libero":
# LIBERO keeps the config default quantiles; other variants need their own. Until the
# exact per-channel quantiles are wired in, use a neutral [-1, 1] mapping (no rescale).
kwargs["action_q01"] = [-1.0] * n_used
kwargs["action_q99"] = [1.0] * n_used
cfg = LingBotVAConfig(**kwargs)
# Populate input/output features (cameras + action) so validate_features passes.
img_shape = preset["image_shape"]
cfg.input_features = {
k: PolicyFeature(type=FeatureType.VISUAL, shape=img_shape) for k in preset["obs_cam_keys"]
}
cfg.output_features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(n_used,))}
cfg.validate_features()
return cfg
def convert(
checkpoint: str, variant: str, output_dir: str, dtype: str, bundle_frozen: bool, push_to: str | None
):
torch_dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[dtype]
out = Path(output_dir)
out.mkdir(parents=True, exist_ok=True)
# Decide where frozen modules will be pulled from at load time.
if bundle_frozen:
wan_pretrained_path = str(out)
_copy_frozen_subfolders(checkpoint, out)
else:
wan_pretrained_path = checkpoint
print(f"Building LingBot-VA config for variant '{variant}' (frozen modules from: {wan_pretrained_path})")
cfg = build_config(variant, wan_pretrained_path, dtype)
print("Loading source transformer weights ...")
src = load_source_transformer(checkpoint, torch_dtype)
src_sd = src.state_dict()
print("Instantiating LingBotVAPolicy and copying transformer weights ...")
# Build the policy without triggering frozen-module download by constructing directly.
policy = LingBotVAPolicy(cfg)
# Near-identity remap: source transformer keys -> policy "transformer.*".
remapped = {f"transformer.{k}": v for k, v in src_sd.items()}
missing, unexpected = policy.load_state_dict(remapped, strict=False)
_log_load_keys(missing, unexpected)
policy = policy.to(torch_dtype)
print(f"Saving converted policy to {out}")
policy.save_pretrained(out)
preprocessor, postprocessor = make_lingbot_va_pre_post_processors(cfg, dataset_stats=None)
preprocessor.save_pretrained(out)
postprocessor.save_pretrained(out)
if push_to:
print(f"Pushing to the Hub: {push_to}")
policy.push_to_hub(push_to)
preprocessor.push_to_hub(push_to)
postprocessor.push_to_hub(push_to)
print("Done.")
def _copy_frozen_subfolders(checkpoint: str, out: Path):
p = Path(checkpoint)
if not p.is_dir():
from huggingface_hub import snapshot_download
p = Path(snapshot_download(checkpoint, allow_patterns=["vae/*", "text_encoder/*", "tokenizer/*"]))
for sub in ("vae", "text_encoder", "tokenizer"):
src_sub = p / sub
if src_sub.is_dir():
shutil.copytree(src_sub, out / sub, dirs_exist_ok=True)
print(f" bundled {sub}/")
def _log_load_keys(missing, unexpected):
# The source transformer should account for every "transformer.*" key in the policy.
if missing:
print(
f" [load_state_dict] {len(missing)} missing keys (expected: none for transformer). Sample: {missing[:5]}"
)
if unexpected:
print(f" [load_state_dict] {len(unexpected)} unexpected keys. Sample: {unexpected[:5]}")
if not missing and not unexpected:
print(" [load_state_dict] perfect match (near-identity remap).")
def main():
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument("--checkpoint", required=True, help="HF repo id or local diffusers-style directory.")
parser.add_argument("--variant", required=True, choices=sorted(VARIANTS.keys()))
parser.add_argument("--output_dir", required=True)
parser.add_argument("--dtype", default="bfloat16", choices=["bfloat16", "float16", "float32"])
parser.add_argument(
"--bundle-frozen",
action="store_true",
help="Copy the frozen vae/text_encoder/tokenizer next to the checkpoint instead of lazy-pulling.",
)
parser.add_argument(
"--push_to_hub", default=None, help="Optional HF repo id to push the converted policy to."
)
args = parser.parse_args()
convert(
checkpoint=args.checkpoint,
variant=args.variant,
output_dir=args.output_dir,
dtype=args.dtype,
bundle_frozen=args.bundle_frozen,
push_to=args.push_to_hub,
)
if __name__ == "__main__":
main()
@@ -0,0 +1,582 @@
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LingBot-VA policy: an autoregressive video-action world model on the Wan2.2 stack.
The sampling loop is a faithful re-implementation of the upstream streaming server
(``wan_va/wan_va_server.py``) and LIBERO client (``evaluation/libero/client.py``), adapted
to LeRobot's ``select_action`` interface:
* the trainable dual-stream transformer is owned as a sub-module and round-trips in the
single ``model.safetensors`` checkpoint;
* the frozen Wan VAE + UMT5 text encoder + tokenizer are *lazily pulled* from
``config.wan_pretrained_path`` (not bundled), so the LeRobot checkpoint stays small;
* ``predict_action_chunk`` runs one autoregressive chunk (video stream then action
stream, each with CFG and its own flow-matching scheduler) and updates the KV cache;
* ``select_action`` drains a per-step action queue and records the real observed
keyframes that are fed back into the KV cache when the queue is refilled.
NOTE: matching the upstream LIBERO success rate is the Phase-5 correctness gate and must be
validated on a CUDA GPU with the converted checkpoint (tensor-diff against upstream on
identical inputs). The streaming path is written for single-environment eval
(``--eval.batch_size=1``).
"""
from collections import deque
import torch
import torch.nn.functional as F
from torch import Tensor
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.import_utils import require_package
from .configuration_lingbot_va import LingBotVAConfig
from .schedulers import FlowMatchScheduler
from .wan_transformer import WanTransformer3DModel
from .wan_utils import data_seq_to_patch, get_mesh_id
from .wan_vae import WanVAEStreamingWrapper, denormalize_latents, load_text_encoder, load_tokenizer, load_vae
def _torch_dtype(name: str) -> torch.dtype:
return {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[name]
class LingBotVAPolicy(PreTrainedPolicy):
"""LeRobot wrapper for the LingBot-VA autoregressive video-action world model."""
config_class = LingBotVAConfig
name = "lingbot_va"
def __init__(self, config: LingBotVAConfig, **kwargs):
require_package("diffusers", extra="lingbot_va")
require_package("transformers", extra="lingbot_va")
super().__init__(config)
config.validate_features()
self.config = config
self.dtype = _torch_dtype(config.dtype)
# Trainable dual-stream transformer (the only sub-module saved in the LeRobot checkpoint).
self.transformer = WanTransformer3DModel(
patch_size=tuple(config.patch_size),
num_attention_heads=config.num_attention_heads,
attention_head_dim=config.attention_head_dim,
in_channels=config.in_channels,
out_channels=config.out_channels,
action_dim=config.action_dim,
text_dim=config.text_dim,
freq_dim=config.freq_dim,
ffn_dim=config.ffn_dim,
num_layers=config.num_layers,
cross_attn_norm=config.cross_attn_norm,
eps=config.eps,
rope_max_seq_len=config.rope_max_seq_len,
attn_mode=config.attn_mode,
)
# Frozen modules are stored OUTSIDE the nn.Module registry (plain dict) so they are
# neither saved into model.safetensors nor moved by ``.to()``. They are lazily loaded
# from ``config.wan_pretrained_path`` the first time inference runs.
self._frozen: dict = {}
self.last_predicted_frames: Tensor | None = None
self.reset()
# ------------------------------------------------------------------
# Frozen-module lazy loading (VAE + UMT5 + tokenizer)
# ------------------------------------------------------------------
def _ensure_frozen_modules(self):
if self._frozen:
return
import os
path = self.config.wan_pretrained_path
device = self.config.device
# Support both local diffusers-style dirs (with vae/ text_encoder/ tokenizer/ sub-folders)
# and HF repo ids (loaders accept a subfolder kwarg, omitted here = repo root layout).
if os.path.isdir(path):
vae_path, te_path, tok_path = (
os.path.join(path, n) for n in ("vae", "text_encoder", "tokenizer")
)
else:
vae_path = te_path = tok_path = path
vae = load_vae(vae_path, torch_dtype=self.dtype, torch_device=device)
text_encoder = load_text_encoder(te_path, torch_dtype=self.dtype, torch_device=device)
tokenizer = load_tokenizer(tok_path)
self._frozen = {
"vae": vae.eval(),
"streaming_vae": WanVAEStreamingWrapper(vae),
"text_encoder": text_encoder.eval(),
"tokenizer": tokenizer,
}
@property
def _vae(self):
return self._frozen["vae"]
@property
def _streaming_vae(self):
return self._frozen["streaming_vae"]
# ------------------------------------------------------------------
# PreTrainedPolicy API
# ------------------------------------------------------------------
def get_optim_params(self) -> dict:
# Only the transformer is trainable; the VAE / text encoder stay frozen.
return self.transformer.parameters()
def reset(self):
"""Reset all per-episode streaming state (KV cache, queues, frame counter)."""
cfg = self.config
self._action_queue: deque = deque(maxlen=cfg.n_action_steps)
self._obs_buffer: list = [] # keyframe camera tensors observed during the current chunk
self._executed_actions: Tensor | None = (
None # last chunk's actions (model-normalized) for KV feedback
)
self._steps_since_refill = 0
self._frame_st_id = 0
self._first_chunk = True
self._prompt: str | None = None
self._prompt_embeds = None
self._negative_prompt_embeds = None
self.last_predicted_frames = None
self._use_cfg = (cfg.guidance_scale > 1) or (cfg.action_guidance_scale > 1)
# Two independent flow-matching schedulers (video latent + action streams).
self._scheduler = FlowMatchScheduler(shift=cfg.snr_shift, sigma_min=0.0, extra_one_step=True)
self._action_scheduler = FlowMatchScheduler(
shift=cfg.action_snr_shift, sigma_min=0.0, extra_one_step=True
)
self._scheduler.set_timesteps(1000, training=True)
self._action_scheduler.set_timesteps(1000, training=True)
self._cache_initialised = False
# Clear KV cache on the (already-built) transformer, if present.
if hasattr(self, "transformer"):
self.transformer.clear_cache("pos")
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
"""Training loss. Implemented in the LingBot-VA training PR (Phase 7).
The flow-matching dual-stream loss needs the pre-extracted latent dataset
(see ``LatentLeRobotDataset`` upstream) and ``attn_mode='flex'``; it is intentionally
not part of this inference-focused integration.
"""
raise NotImplementedError(
"LingBot-VA training (flow-matching dual-stream loss) is part of the training port "
"(Phase 7 / PR #2) and is not implemented in this inference integration. "
"Use this policy for evaluation / inference."
)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
"""Return one action, refilling the chunk (and feeding back observed keyframes) as needed."""
self.eval()
self._ensure_frozen_modules()
self._maybe_init_prompt(batch)
# Record the current observation as a keyframe at every frame boundary so that, when the
# queue empties, ``predict_action_chunk`` can feed the real observed frames back into the
# KV cache (mirroring the upstream ``compute_kv_cache`` call in the LIBERO client loop).
# We skip ``steps_since_refill == 0`` (the obs that conditioned the current chunk): only
# frames observed *after* executing each frame's actions are fed back.
if self._steps_since_refill > 0 and self._steps_since_refill % self.config.action_per_frame == 0:
self._obs_buffer.append(self._encode_obs(batch))
if len(self._action_queue) == 0:
actions = self.predict_action_chunk(batch) # [B, chunk_size, n_used]
# queue holds per-step actions: shape [chunk_size, B, n_used]
self._action_queue.extend(actions.transpose(0, 1))
self._steps_since_refill = 0
self._steps_since_refill += 1
return self._action_queue.popleft()
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
"""Run one autoregressive chunk and return actions ``[B, chunk_size, n_used]`` (normalized)."""
self.eval()
self._ensure_frozen_modules()
self._maybe_init_prompt(batch)
is_first = self._first_chunk
if is_first:
init_latent = self._encode_obs(batch)
self._init_latent = init_latent
self._init_streaming_cache(init_latent)
self._obs_buffer = [] # frame 0 (the init obs) conditions the chunk; it is not fed back
actions, latents = self._infer(init_latent, frame_st_id=0)
self._first_chunk = False
else:
# Feed the real observed keyframes + the executed actions back into the KV cache.
self._compute_kv_cache(self._obs_buffer, self._executed_actions)
self._obs_buffer = []
actions, latents = self._infer(None, frame_st_id=self._frame_st_id)
# actions: [B, action_dim, F, action_per_frame, 1] (model-normalized). Keep for KV feedback.
self._executed_actions = actions
if self.config.save_predicted_video:
self.last_predicted_frames = self._decode_predicted_video(latents)
# On the first chunk, frame 0 is the conditioning frame (already "known"): the upstream
# LIBERO client skips it (start_idx=1), so we drop the first frame's actions here.
used = self.config.used_action_channel_ids
a = actions[:, used] # [B, n_used, F, action_per_frame, 1]
if is_first:
a = a[:, :, 1:] # drop frame 0 -> (F-1) frames of actions
a = a.squeeze(-1).flatten(2) # [B, n_used, n_steps]
a = a.transpose(1, 2).contiguous() # [B, n_steps, n_used]
return a.to(torch.float32)
# ------------------------------------------------------------------
# Prompt / text encoding
# ------------------------------------------------------------------
def _maybe_init_prompt(self, batch):
if self._prompt_embeds is not None:
return
task = batch.get("task")
prompt = task[0] if isinstance(task, list | tuple) else task
self._prompt = prompt or ""
self._prompt_embeds, self._negative_prompt_embeds = self._encode_prompt(self._prompt)
def _get_t5_prompt_embeds(self, prompt, max_sequence_length):
from diffusers.pipelines.wan.pipeline_wan import prompt_clean
tokenizer = self._frozen["tokenizer"]
text_encoder = self._frozen["text_encoder"]
device = self.config.device
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt = [prompt_clean(u) for u in prompt]
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_attention_mask=True,
return_tensors="pt",
)
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
seq_lens = mask.gt(0).sum(dim=1).long()
te_device = next(text_encoder.parameters()).device
prompt_embeds = text_encoder(text_input_ids.to(te_device), mask.to(te_device)).last_hidden_state
prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=device)
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens, strict=False)]
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds],
dim=0,
)
return prompt_embeds.to(device)
def _encode_prompt(self, prompt):
max_len = self.config.max_sequence_length
prompt_embeds = self._get_t5_prompt_embeds(prompt, max_len)
negative_prompt_embeds = None
if self._use_cfg:
negative_prompt_embeds = self._get_t5_prompt_embeds("", max_len)
return prompt_embeds, negative_prompt_embeds
# ------------------------------------------------------------------
# Observation (image) encoding -> normalized video latents
# ------------------------------------------------------------------
def _camera_tensor(self, batch, key):
"""Return a single-frame camera tensor [B, C, 1, H, W] resized + scaled to [-1, 1]."""
img = batch[key]
if img.dim() == 3: # [C, H, W]
img = img.unsqueeze(0)
# LeRobot images arrive as float in [0, 1], shape [B, C, H, W].
img = img.to(self.config.device, torch.float32)
img = F.interpolate(
img, size=(self.config.height, self.config.width), mode="bilinear", align_corners=False
)
img = img * 2.0 - 1.0
return img.unsqueeze(2).to(self.dtype) # [B, C, F=1, H, W]
@torch.no_grad()
def _encode_obs(self, batch) -> Tensor:
"""VAE-encode all configured cameras of the current obs and concat latents on width."""
videos = [self._camera_tensor(batch, k) for k in self.config.obs_cam_keys]
videos = torch.cat(videos, dim=0) # [num_cam, C, F, H, W]
vae_device = next(self._vae.parameters()).device
enc_out = self._streaming_vae.encode_chunk(videos.to(vae_device).to(self.dtype))
mu, _logvar = torch.chunk(enc_out, 2, dim=1)
latents_mean = torch.tensor(self._vae.config.latents_mean).to(mu.device)
latents_std = torch.tensor(self._vae.config.latents_std).to(mu.device)
# Note: upstream passes 1/std so the op is (x - mean) * (1/std).
mean = latents_mean.view(1, -1, 1, 1, 1)
inv_std = (1.0 / latents_std).view(1, -1, 1, 1, 1)
mu_norm = ((mu.float() - mean) * inv_std).to(mu)
# Concatenate the per-camera latents along width.
video_latent = torch.cat(mu_norm.split(1, dim=0), dim=-1)
return video_latent.to(self.config.device)
# ------------------------------------------------------------------
# KV cache management
# ------------------------------------------------------------------
@property
def _latent_hw(self):
h = self.config.height // 16
w = (self.config.width // 16) * len(self.config.obs_cam_keys)
return h, w
def _init_streaming_cache(self, init_latent):
cfg = self.config
latent_h, latent_w = self._latent_hw
p = cfg.patch_size
latent_token_per_chunk = (cfg.frame_chunk_size * latent_h * latent_w) // (p[0] * p[1] * p[2])
action_token_per_chunk = cfg.frame_chunk_size * cfg.action_per_frame
self.transformer.create_empty_cache(
"pos",
cfg.attn_window,
latent_token_per_chunk,
action_token_per_chunk,
device=self.config.device,
dtype=self.dtype,
batch_size=2 if self._use_cfg else 1,
)
self._cache_initialised = True
def _repeat_input_for_cfg(self, input_dict):
if self._use_cfg:
input_dict["noisy_latents"] = input_dict["noisy_latents"].repeat(2, 1, 1, 1, 1)
input_dict["text_emb"] = torch.cat(
[
self._prompt_embeds.to(self.dtype).clone(),
self._negative_prompt_embeds.to(self.dtype).clone(),
],
dim=0,
)
input_dict["grid_id"] = input_dict["grid_id"][None].repeat(2, 1, 1)
input_dict["timesteps"] = input_dict["timesteps"][None].repeat(2, 1)
else:
input_dict["grid_id"] = input_dict["grid_id"][None]
input_dict["timesteps"] = input_dict["timesteps"][None]
return input_dict
def _prepare_latent_input(
self,
latent_model_input,
action_model_input,
latent_t=0,
action_t=0,
latent_cond=None,
action_cond=None,
frame_st_id=0,
):
cfg = self.config
device = self.config.device
p = cfg.patch_size
out = {}
if latent_model_input is not None:
out["latent_res_lst"] = {
"noisy_latents": latent_model_input,
"timesteps": torch.ones([latent_model_input.shape[2]], dtype=torch.float32, device=device)
* latent_t,
"grid_id": get_mesh_id(
latent_model_input.shape[-3] // p[0],
latent_model_input.shape[-2] // p[1],
latent_model_input.shape[-1] // p[2],
0,
1,
frame_st_id,
).to(device),
"text_emb": self._prompt_embeds.to(self.dtype).clone(),
}
if latent_cond is not None:
out["latent_res_lst"]["noisy_latents"][:, :, 0:1] = latent_cond[:, :, 0:1]
out["latent_res_lst"]["timesteps"][0:1] *= 0
if action_model_input is not None:
out["action_res_lst"] = {
"noisy_latents": action_model_input,
"timesteps": torch.ones([action_model_input.shape[2]], dtype=torch.float32, device=device)
* action_t,
"grid_id": get_mesh_id(
action_model_input.shape[-3],
action_model_input.shape[-2],
action_model_input.shape[-1],
1,
1,
frame_st_id,
action=True,
).to(device),
"text_emb": self._prompt_embeds.to(self.dtype).clone(),
}
if action_cond is not None:
out["action_res_lst"]["noisy_latents"][:, :, 0:1] = action_cond[:, :, 0:1]
out["action_res_lst"]["timesteps"][0:1] *= 0
out["action_res_lst"]["noisy_latents"][:, ~self._action_mask] *= 0
return out
@property
def _action_mask(self):
mask = torch.zeros([self.config.action_dim], dtype=torch.bool)
mask[self.config.used_action_channel_ids] = True
return mask
# ------------------------------------------------------------------
# Action conditioning (executed action history) (de)normalization
# ------------------------------------------------------------------
def _preprocess_action_state(self, action_norm: Tensor) -> Tensor:
"""Build the action-conditioning tensor from the already-normalized executed actions.
``action_norm`` is the model-space action chunk ``[B, action_dim, F, action_per_frame, 1]``.
Upstream re-derives the conditioning from the raw executed action via quantile norm; here
the executed actions are already in the model-normalized space, so we pass them through.
"""
return action_norm.to(self.config.device, self.dtype)
def _compute_kv_cache(self, obs_buffer, executed_actions):
"""Feed real observed keyframes + executed actions back into the KV cache."""
if not obs_buffer or executed_actions is None:
return
self.transformer.clear_pred_cache("pos")
# Concatenate the observed keyframe latents along the frame axis.
latent_model_input = torch.cat(obs_buffer, dim=2)
# On the first feedback, prepend the init latent so the latent/action frame counts align
# (upstream prepends ``init_latent`` to the observed keyframes when frame_st_id == 0).
if self._frame_st_id == 0 and getattr(self, "_init_latent", None) is not None:
latent_model_input = torch.cat([self._init_latent, latent_model_input], dim=2)
action_model_input = self._preprocess_action_state(executed_actions)
action_model_input = action_model_input.to(latent_model_input)
input_dict = self._prepare_latent_input(
latent_model_input, action_model_input, frame_st_id=self._frame_st_id
)
with torch.no_grad():
self.transformer(
self._repeat_input_for_cfg(input_dict["latent_res_lst"]),
update_cache=2,
cache_name="pos",
action_mode=False,
)
self.transformer(
self._repeat_input_for_cfg(input_dict["action_res_lst"]),
update_cache=2,
cache_name="pos",
action_mode=True,
)
self._frame_st_id += latent_model_input.shape[2]
# ------------------------------------------------------------------
# The core dual-stream denoising loop (one chunk)
# ------------------------------------------------------------------
@torch.no_grad()
def _infer(self, init_latent, frame_st_id=0):
cfg = self.config
device = self.config.device
latent_h, latent_w = self._latent_hw
frame_chunk_size = cfg.frame_chunk_size
latents = torch.randn(1, 48, frame_chunk_size, latent_h, latent_w, device=device, dtype=self.dtype)
actions = torch.randn(
1, cfg.action_dim, frame_chunk_size, cfg.action_per_frame, 1, device=device, dtype=self.dtype
)
self._scheduler.set_timesteps(cfg.num_inference_steps)
self._action_scheduler.set_timesteps(cfg.action_num_inference_steps)
timesteps = F.pad(self._scheduler.timesteps, (0, 1), mode="constant", value=0)
if cfg.video_exec_step != -1:
timesteps = timesteps[: cfg.video_exec_step]
action_timesteps = F.pad(self._action_scheduler.timesteps, (0, 1), mode="constant", value=0)
# 1. Video-latent denoising loop
for i, t in enumerate(timesteps):
last_step = i == len(timesteps) - 1
latent_cond = (
init_latent[:, :, 0:1].to(self.dtype)
if frame_st_id == 0 and init_latent is not None
else None
)
input_dict = self._prepare_latent_input(
latents, None, t, t, latent_cond, None, frame_st_id=frame_st_id
)
video_noise_pred = self.transformer(
self._repeat_input_for_cfg(input_dict["latent_res_lst"]),
update_cache=1 if last_step else 0,
cache_name="pos",
action_mode=False,
)
if not last_step or cfg.video_exec_step != -1:
video_noise_pred = data_seq_to_patch(
cfg.patch_size,
video_noise_pred,
frame_chunk_size,
latent_h,
latent_w,
batch_size=2 if self._use_cfg else 1,
)
if cfg.guidance_scale > 1:
video_noise_pred = video_noise_pred[1:] + cfg.guidance_scale * (
video_noise_pred[:1] - video_noise_pred[1:]
)
else:
video_noise_pred = video_noise_pred[:1]
latents = self._scheduler.step(video_noise_pred, t, latents, return_dict=False)
if frame_st_id == 0 and latent_cond is not None:
latents[:, :, 0:1] = latent_cond
# 2. Action denoising loop
for i, t in enumerate(action_timesteps):
last_step = i == len(action_timesteps) - 1
action_cond = (
torch.zeros([1, cfg.action_dim, 1, cfg.action_per_frame, 1], device=device, dtype=self.dtype)
if frame_st_id == 0
else None
)
input_dict = self._prepare_latent_input(
None, actions, t, t, None, action_cond, frame_st_id=frame_st_id
)
action_noise_pred = self.transformer(
self._repeat_input_for_cfg(input_dict["action_res_lst"]),
update_cache=1 if last_step else 0,
cache_name="pos",
action_mode=True,
)
if not last_step:
from einops import rearrange
action_noise_pred = rearrange(action_noise_pred, "b (f n) c -> b c f n 1", f=frame_chunk_size)
if cfg.action_guidance_scale > 1:
action_noise_pred = action_noise_pred[1:] + cfg.action_guidance_scale * (
action_noise_pred[:1] - action_noise_pred[1:]
)
else:
action_noise_pred = action_noise_pred[:1]
actions = self._action_scheduler.step(action_noise_pred, t, actions, return_dict=False)
if frame_st_id == 0 and action_cond is not None:
actions[:, :, 0:1] = action_cond
actions[:, ~self._action_mask] *= 0
return actions, latents
# ------------------------------------------------------------------
# Predicted-video decoding (opt-in)
# ------------------------------------------------------------------
@torch.no_grad()
def _decode_predicted_video(self, latents) -> Tensor:
"""VAE-decode predicted latents into a uint8 frame stack ``[T, H, W, 3]`` on CPU."""
vae = self._vae
z_dim = vae.config.z_dim
latents = denormalize_latents(
latents.to(vae.dtype), vae.config.latents_mean, vae.config.latents_std, z_dim
)
video = vae.decode(latents, return_dict=False)[0] # [B, C, F, H, W] in [-1, 1]
video = (video.float().clamp(-1, 1) + 1.0) / 2.0
video = (video[0].permute(1, 2, 3, 0) * 255.0).round().to(torch.uint8) # [F, H, W, C]
return video.cpu()
@@ -0,0 +1,113 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pre/post-processor pipelines for the LingBot-VA policy.
The policy itself handles image resizing, scaling to [-1, 1] and VAE encoding (the VAE
lives inside the policy), so the preprocessor only renames, batches, normalizes (IDENTITY)
and moves to device. The postprocessor reverses the *fixed* action quantile normalization
(``(action + 1) / 2 * (q99 - q01 + 1e-6) + q01``) baked into the released checkpoints this
is a fixed transform, not a dataset-stats one, so it cannot use the standard
``UnnormalizerProcessorStep`` and is implemented as a dedicated step below.
"""
from dataclasses import dataclass, field
from typing import Any
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyActionProcessorStep,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import (
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
from .configuration_lingbot_va import LingBotVAConfig
@dataclass
@ProcessorStepRegistry.register(name="lingbot_va_action_unnormalize")
class LingBotVAActionUnnormalizeStep(PolicyActionProcessorStep):
"""Reverse LingBot-VA's fixed per-channel quantile normalization on predicted actions.
The policy emits actions in the normalized ``[-1, 1]`` space of the used action channels.
This step maps them back to physical units via the fixed quantiles stored in the config.
"""
action_q01: list[float] = field(default_factory=list)
action_q99: list[float] = field(default_factory=list)
def action(self, action: PolicyAction) -> PolicyAction:
q01 = torch.as_tensor(self.action_q01, dtype=action.dtype, device=action.device)
q99 = torch.as_tensor(self.action_q99, dtype=action.dtype, device=action.device)
return (action + 1.0) / 2.0 * (q99 - q01 + 1e-6) + q01
def get_config(self) -> dict[str, Any]:
return {"action_q01": list(self.action_q01), "action_q99": list(self.action_q99)}
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
def make_lingbot_va_pre_post_processors(
config: LingBotVAConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Build the pre/post processor pipelines for LingBot-VA."""
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeviceProcessorStep(device=config.device),
]
output_steps: list[ProcessorStep] = [
LingBotVAActionUnnormalizeStep(action_q01=config.action_q01, action_q99=config.action_q99),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
@@ -0,0 +1,155 @@
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flow-matching scheduler for LingBot-VA.
Vendored verbatim from the upstream LingBot-VA repository
(https://github.com/Robbyant/lingbot-va, ``wan_va/utils/scheduler.py``). LingBot-VA uses
two independent instances of this scheduler at inference time one for the video-latent
stream and one for the action stream each with its own ``shift`` (signal-to-noise ratio
shift) and number of denoising steps.
"""
import math
import torch
__all__ = ["FlowMatchScheduler"]
class FlowMatchScheduler:
def __init__(
self,
num_inference_steps=100,
num_train_timesteps=1000,
shift=3.0,
sigma_max=1.0,
sigma_min=0.003 / 1.002,
inverse_timesteps=False,
extra_one_step=False,
reverse_sigmas=False,
exponential_shift=False,
exponential_shift_mu=None,
shift_terminal=None,
):
self.num_train_timesteps = num_train_timesteps
self.shift = shift
self.sigma_max = sigma_max
self.sigma_min = sigma_min
self.inverse_timesteps = inverse_timesteps
self.extra_one_step = extra_one_step
self.reverse_sigmas = reverse_sigmas
self.exponential_shift = exponential_shift
self.exponential_shift_mu = exponential_shift_mu
self.shift_terminal = shift_terminal
self.set_timesteps(num_inference_steps)
def set_timesteps(
self,
num_inference_steps=100,
denoising_strength=1.0,
training=False,
shift=None,
dynamic_shift_len=None,
):
if shift is not None:
self.shift = shift
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
if self.extra_one_step:
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
else:
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
if self.inverse_timesteps:
self.sigmas = torch.flip(self.sigmas, dims=[0])
if self.exponential_shift:
mu = (
self.calculate_shift(dynamic_shift_len)
if dynamic_shift_len is not None
else self.exponential_shift_mu
)
self.sigmas = math.exp(mu) / (math.exp(mu) + (1 / self.sigmas - 1))
else:
self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
if self.shift_terminal is not None:
one_minus_z = 1 - self.sigmas
scale_factor = one_minus_z[-1] / (1 - self.shift_terminal)
self.sigmas = 1 - (one_minus_z / scale_factor)
if self.reverse_sigmas:
self.sigmas = 1 - self.sigmas
self.timesteps = self.sigmas * self.num_train_timesteps
if training:
x = self.timesteps
y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
y_shifted = y - y.min()
bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
self.linear_timesteps_weights = bsmntw_weighing
self.training = True
else:
self.training = False
def step(self, model_output, timestep, sample, to_final=False, **kwargs):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
if to_final or timestep_id + 1 >= len(self.timesteps):
sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
else:
sigma_ = self.sigmas[timestep_id + 1]
prev_sample = sample + model_output * (sigma_ - sigma)
return prev_sample
def return_to_timestep(self, timestep, sample, sample_stablized):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
model_output = (sample - sample_stablized) / sigma
return model_output
def add_noise(self, original_samples, noise, timestep, t_dim=2):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep = timestep[None]
timestep_id = torch.argmin((self.timesteps[:, None] - timestep).abs(), dim=0)
shape = [1] * noise.ndim
shape[t_dim] = timestep_id.shape[0]
sigma = self.sigmas[timestep_id].to(original_samples).view(shape)
sample = (1 - sigma) * original_samples + sigma * noise
return sample
def training_target(self, sample, noise, timestep):
target = noise - sample
return target
def training_weight(self, timestep):
timestep_id = torch.argmin(
(self.timesteps[:, None].to(timestep.device) - timestep[None]).abs(), dim=0
)
weights = self.linear_timesteps_weights.to(timestep.device)[timestep_id].to(timestep.device)
return weights
def calculate_shift(
self,
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 8192,
base_shift: float = 0.5,
max_shift: float = 0.9,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
@@ -0,0 +1,286 @@
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Attention and rotary-position-embedding modules for the LingBot-VA Wan transformer.
Vendored and lightly adapted from the upstream LingBot-VA repository
(https://github.com/Robbyant/lingbot-va, ``wan_va/modules/model.py``). The ``torch``
SDPA backend is the default and is always available; the ``flashattn`` and ``flex``
backends are imported lazily and only required when the corresponding ``attn_mode`` is
selected. State-dict parameter names are preserved verbatim so that conversion from the
original diffusers-style checkpoint is near-identity.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
# ``flash_attn`` and the flex-attention APIs are optional. We import them lazily inside the
# backends that need them so that the (default) ``torch`` SDPA path works on any platform,
# including CPU-only and macOS where neither package is available.
def custom_sdpa(q, k, v):
"""Scaled-dot-product attention operating on ``(B, S, H, D)`` tensors."""
out = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
return out.transpose(1, 2)
def _load_flash_attn_func():
try:
from flash_attn_interface import flash_attn_func
except ImportError:
try:
from flash_attn import flash_attn_func
except ImportError as e:
raise ImportError(
"attn_mode='flashattn' requires the `flash_attn` package, which is not installed. "
"Install it, or use attn_mode='torch' (the default)."
) from e
return flash_attn_func
class WanRotaryPosEmbed(nn.Module):
"""Rotary position embedding with separate frequency bases for frame / height / width."""
def __init__(
self,
attention_head_dim: int,
patch_size,
max_seq_len: int,
theta: float = 10000.0,
):
super().__init__()
self.attention_head_dim = attention_head_dim
self.patch_size = patch_size
self.max_seq_len = max_seq_len
self.theta = theta
self.f_dim = self.attention_head_dim - 2 * (self.attention_head_dim // 3)
self.h_dim = self.attention_head_dim // 3
self.w_dim = self.attention_head_dim // 3
f_freqs_base, h_freqs_base, w_freqs_base = self._precompute_freqs_base()
self.f_freqs_base = f_freqs_base
self.h_freqs_base = h_freqs_base
self.w_freqs_base = w_freqs_base
def _precompute_freqs_base(self):
# freqs_base = 1.0 / (theta ** (2k / dim))
f_freqs_base = 1.0 / (
self.theta ** (torch.arange(0, self.f_dim, 2)[: (self.f_dim // 2)].double() / self.f_dim)
)
h_freqs_base = 1.0 / (
self.theta ** (torch.arange(0, self.h_dim, 2)[: (self.h_dim // 2)].double() / self.h_dim)
)
w_freqs_base = 1.0 / (
self.theta ** (torch.arange(0, self.w_dim, 2)[: (self.w_dim // 2)].double() / self.w_dim)
)
return f_freqs_base, h_freqs_base, w_freqs_base
def forward(self, grid_ids):
with torch.no_grad():
f_freqs = grid_ids[:, 0, :].unsqueeze(-1) * self.f_freqs_base.to(grid_ids.device)
h_freqs = grid_ids[:, 1, :].unsqueeze(-1) * self.h_freqs_base.to(grid_ids.device)
w_freqs = grid_ids[:, 2, :].unsqueeze(-1) * self.w_freqs_base.to(grid_ids.device)
freqs = torch.cat([f_freqs, h_freqs, w_freqs], dim=-1).float()
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
class WanAttention(nn.Module):
"""Self/cross attention with KV-caching for autoregressive streaming inference.
Backends:
* ``torch`` (default): standard SDPA, available everywhere.
* ``flashattn``: FlashAttention kernels (optional dependency).
* ``flex``: PyTorch flex-attention (optional, used for block-causal training masks).
"""
def __init__(
self,
dim,
heads=8,
dim_head=64,
eps=1e-5,
dropout=0.0,
cross_attention_dim_head=None,
attn_mode="torch",
):
super().__init__()
if attn_mode == "torch":
self.attn_op = custom_sdpa
elif attn_mode == "flashattn":
self.attn_op = _load_flash_attn_func()
elif attn_mode == "flex":
# Imported lazily to avoid a hard dependency on torch flex-attention at import time.
from .wan_flex_attention import FlexAttnFunc
self.attn_op = FlexAttnFunc(cross_attention_dim_head is not None)
else:
raise ValueError(
f"Unsupported attention mode: {attn_mode}, only support 'torch', 'flashattn' and 'flex'"
)
self.inner_dim = dim_head * heads
self.heads = heads
self.cross_attention_dim_head = cross_attention_dim_head
self.kv_inner_dim = (
self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
)
self.to_q = nn.Linear(dim, self.inner_dim, bias=True)
self.to_k = nn.Linear(dim, self.kv_inner_dim, bias=True)
self.to_v = nn.Linear(dim, self.kv_inner_dim, bias=True)
self.to_out = nn.ModuleList(
[
nn.Linear(self.inner_dim, dim, bias=True),
nn.Dropout(dropout),
]
)
self.norm_q = nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
self.norm_k = nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
# KV cache only lives on self-attention modules (cross_attention_dim_head is None).
self.attn_caches = {} if cross_attention_dim_head is None else None
def clear_pred_cache(self, cache_name):
if self.attn_caches is None:
return
cache = self.attn_caches[cache_name]
is_pred = cache["is_pred"]
cache["mask"][is_pred] = False
def clear_cache(self, cache_name):
if self.attn_caches is None:
return
self.attn_caches[cache_name] = None
def init_kv_cache(self, cache_name, total_tolen, num_head, head_dim, device, dtype, batch_size):
if self.attn_caches is None:
return
self.attn_caches[cache_name] = {
"k": torch.empty([batch_size, total_tolen, num_head, head_dim], device=device, dtype=dtype),
"v": torch.empty([batch_size, total_tolen, num_head, head_dim], device=device, dtype=dtype),
"id": torch.full((total_tolen,), -1, device=device),
"mask": torch.zeros((total_tolen,), dtype=torch.bool, device=device),
"is_pred": torch.zeros((total_tolen,), dtype=torch.bool, device=device),
}
def allocate_slots(self, cache_name, key_size):
cache = self.attn_caches[cache_name]
mask = cache["mask"]
ids = cache["id"]
free = (~mask).nonzero(as_tuple=False).squeeze(-1)
if free.numel() < key_size:
used = mask.nonzero(as_tuple=False).squeeze(-1)
used_ids = ids[used]
order = torch.argsort(used_ids)
need = key_size - free.numel()
to_free = used[order[:need]]
mask[to_free] = False
ids[to_free] = -1
free = (~mask).nonzero(as_tuple=False).squeeze(-1)
assert free.numel() >= key_size
return free[:key_size]
def _next_cache_id(self, cache_name):
ids = self.attn_caches[cache_name]["id"]
mask = self.attn_caches[cache_name]["mask"]
if mask.any():
return ids[mask].max() + 1
else:
return torch.tensor(0, device=ids.device, dtype=ids.dtype)
def update_cache(self, cache_name, key, value, is_pred):
cache = self.attn_caches[cache_name]
key_size = key.shape[1]
slots = self.allocate_slots(cache_name, key_size)
new_id = self._next_cache_id(cache_name)
cache["k"][:, slots] = key
cache["v"][:, slots] = value
cache["mask"][slots] = True
cache["id"][slots] = new_id
cache["is_pred"][slots] = is_pred
return slots
def restore_cache(self, cache_name, slots):
self.attn_caches[cache_name]["mask"][slots] = False
def forward(
self,
q,
k,
v,
rotary_emb,
update_cache=0,
cache_name="pos",
):
kv_cache = (
self.attn_caches[cache_name]
if (self.attn_caches is not None) and (cache_name in self.attn_caches)
else None
)
query, key, value = self.to_q(q), self.to_k(k), self.to_v(v)
query = self.norm_q(query)
query = query.unflatten(2, (self.heads, -1))
key = self.norm_k(key)
key = key.unflatten(2, (self.heads, -1))
value = value.unflatten(2, (self.heads, -1))
if rotary_emb is not None:
def apply_rotary_emb(x, freqs):
x_out = torch.view_as_complex(
x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
)
x_out = torch.view_as_real(x_out * freqs).flatten(3)
return x_out.to(x.dtype)
query = apply_rotary_emb(query, rotary_emb)
key = apply_rotary_emb(key, rotary_emb)
slots = None
if kv_cache is not None and kv_cache["k"] is not None:
slots = self.update_cache(cache_name, key, value, is_pred=(update_cache == 1))
key_pool = self.attn_caches[cache_name]["k"]
value_pool = self.attn_caches[cache_name]["v"]
mask = self.attn_caches[cache_name]["mask"]
valid = mask.nonzero(as_tuple=False).squeeze(-1)
key = key_pool[:, valid]
value = value_pool[:, valid]
hidden_states = self.attn_op(query, key, value)
if update_cache == 0:
if kv_cache is not None and kv_cache["k"] is not None:
self.restore_cache(cache_name, slots)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.type_as(query)
hidden_states = self.to_out[0](hidden_states)
hidden_states = self.to_out[1](hidden_states)
return hidden_states
__all__ = ["WanAttention", "WanRotaryPosEmbed", "custom_sdpa"]
@@ -0,0 +1,207 @@
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flex-attention backend for the LingBot-VA Wan transformer (training only).
This module is imported lazily and ONLY when ``attn_mode='flex'`` is requested. It builds
the block-causal / window / noise-vs-clean attention masks used during the dual-stream
flow-matching training described in the LingBot-VA paper. Inference uses the ``torch``
SDPA backend (see :mod:`wan_attention`) which does not need flex-attention.
``torch.nn.attention.flex_attention`` requires a recent PyTorch build with the relevant
inductor support; importing this module on an unsupported build raises ``ImportError``.
"""
from collections.abc import Callable
from functools import partial
from typing import ClassVar
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.attention.flex_attention import (
BlockMask,
and_masks,
create_block_mask,
flex_attention,
or_masks,
)
class FlexAttnFunc(nn.Module):
flex_attn: ClassVar[Callable] = torch.compile(flex_attention, dynamic=True)
compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask)
attention_mask: ClassVar[BlockMask] = None
cross_attention_mask: ClassVar[BlockMask] = None
def __init__(self, is_cross=False) -> None:
super().__init__()
self.is_cross = is_cross
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dtype=torch.bfloat16,
) -> torch.Tensor:
q_varlen = rearrange(query[0], "s n d -> 1 n s d")
k_varlen = rearrange(key[0], "s n d -> 1 n s d")
v_varlen = rearrange(value[0], "s n d -> 1 n s d")
half_dtypes = (torch.float16, torch.bfloat16)
assert dtype in half_dtypes
def half(x):
return x if x.dtype in half_dtypes else x.to(dtype)
q_varlen = half(q_varlen)
k_varlen = half(k_varlen)
v_varlen = half(v_varlen)
q_varlen = q_varlen.to(v_varlen.dtype)
k_varlen = k_varlen.to(v_varlen.dtype)
block_mask = FlexAttnFunc.cross_attention_mask if self.is_cross else FlexAttnFunc.attention_mask
x_out = FlexAttnFunc.flex_attn(
q_varlen,
k_varlen,
v_varlen,
block_mask=block_mask,
kernel_options={
"BLOCK_M": 64,
"BLOCK_N": 64,
"BLOCK_M1": 32,
"BLOCK_N1": 64,
"BLOCK_M2": 64,
"BLOCK_N2": 32,
},
)
x_out = rearrange(x_out, "b n s d -> b s n d")
return x_out
@staticmethod
@torch.no_grad()
def init_mask(
latent_shape,
action_shape,
padded_length,
chunk_size,
window_size,
patch_size,
device,
):
torch._inductor.config.realize_opcount_threshold = 100
B, _, L_F, L_H, L_W = latent_shape
_, _, A_F, A_H, A_W = action_shape
latent_seq_id = (
torch.arange(B)[:, None, None, None]
.expand(-1, L_F // patch_size[0], L_H // patch_size[1], L_W // patch_size[2])
.flatten()
)
action_seq_id = torch.arange(B)[:, None, None, None].expand(-1, A_F, A_H, A_W).flatten()
seq_ids = torch.cat([latent_seq_id] * 2 + [action_seq_id] * 2)
latent_frame_id = (
torch.arange(L_F)[None, :, None, None]
.expand(B, -1, L_H // patch_size[1], L_W // patch_size[2])[None]
.flatten()
)
action_frame_id = torch.arange(A_F)[None, :, None, None].expand(B, -1, A_H, A_W)[None].flatten()
frame_ids = torch.cat(
[latent_frame_id // chunk_size * 2] * 2 + [action_frame_id // chunk_size * 2 + 1] * 2
)
noise_ids = torch.cat(
[
torch.zeros_like(latent_frame_id),
torch.ones_like(latent_frame_id),
torch.zeros_like(action_frame_id),
torch.ones_like(action_frame_id),
]
)
seq_ids = F.pad(seq_ids, (0, padded_length), value=-1)
frame_ids = F.pad(frame_ids, (0, padded_length), value=-1)
noise_ids = F.pad(noise_ids, (0, padded_length), value=-1)
mask_mod = FlexAttnFunc._get_mask_mod(
seq_ids.long().to(device), frame_ids.long().to(device), noise_ids.long().to(device), window_size
)
block_mask = FlexAttnFunc.compiled_create_block_mask(
mask_mod, 1, 1, len(seq_ids), len(seq_ids), device=device, _compile=True
)
FlexAttnFunc.attention_mask = block_mask
text_seq_ids = torch.arange(B)[:, None].expand(-1, 512).flatten()
mask_mod_cross = FlexAttnFunc._get_cross_mask_mod(
seq_ids.long().to(device), text_seq_ids.long().to(device)
)
block_mask_cross = FlexAttnFunc.compiled_create_block_mask(
mask_mod_cross, 1, 1, len(seq_ids), len(text_seq_ids), device=device, _compile=True
)
FlexAttnFunc.cross_attention_mask = block_mask_cross
@staticmethod
@torch.no_grad()
def _get_cross_mask_mod(seq_ids, text_seq_ids):
def seq_mask(b, h, q_idx, kv_idx):
return (
(seq_ids[q_idx] == text_seq_ids[kv_idx]) & (seq_ids[q_idx] >= 0) & (text_seq_ids[kv_idx] >= 0)
)
return seq_mask
@staticmethod
@torch.no_grad()
def _get_mask_mod(seq_ids, frame_ids, noise_ids, window_size):
def seq_mask(b, h, q_idx, kv_idx):
return (seq_ids[q_idx] == seq_ids[kv_idx]) & (seq_ids[q_idx] >= 0) & (seq_ids[kv_idx] >= 0)
def block_causal_mask(b, h, q_idx, kv_idx):
return frame_ids[kv_idx] <= frame_ids[q_idx]
def block_causal_mask_exclude_self(b, h, q_idx, kv_idx):
return frame_ids[kv_idx] < frame_ids[q_idx]
def block_self_mask(b, h, q_idx, kv_idx):
return frame_ids[kv_idx] == frame_ids[q_idx]
def clean2clean_mask(b, h, q_idx, kv_idx):
return (noise_ids[q_idx] == 1) & (noise_ids[kv_idx] == 1)
def noise2clean_mask(b, h, q_idx, kv_idx):
return (noise_ids[q_idx] == 0) & (noise_ids[kv_idx] == 1)
def noise2noise_mask(b, h, q_idx, kv_idx):
return (noise_ids[q_idx] == 0) & (noise_ids[kv_idx] == 0)
def block_window_mask(b, h, q_idx, kv_idx, window_size: int):
return (frame_ids[q_idx] - frame_ids[kv_idx]).abs() <= window_size
mask_list = []
mask_list.append(and_masks(clean2clean_mask, block_causal_mask))
mask_list.append(and_masks(noise2clean_mask, block_causal_mask_exclude_self))
mask_list.append(and_masks(noise2noise_mask, block_self_mask))
mask = or_masks(*mask_list)
mask = and_masks(mask, seq_mask)
mask = and_masks(mask, partial(block_window_mask, window_size=window_size))
return mask
__all__ = ["FlexAttnFunc"]
@@ -0,0 +1,514 @@
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The dual-stream Wan2.2 video-action transformer backbone for LingBot-VA.
Vendored and lightly adapted from the upstream LingBot-VA repository
(https://github.com/Robbyant/lingbot-va, ``wan_va/modules/model.py``).
The model keeps the diffusers ``ModelMixin``/``ConfigMixin`` mixins so the original
sharded ``transformer/`` checkpoint can be loaded with ``from_pretrained`` during
conversion, but in LeRobot it is owned as a plain ``nn.Module`` sub-component of
:class:`~lerobot.policies.lingbot_va.modeling_lingbot_va.LingBotVAPolicy`. State-dict
parameter names are preserved verbatim so conversion is near-identity.
"""
import math
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention import FeedForward
from diffusers.models.embeddings import (
PixArtAlphaTextProjection,
TimestepEmbedding,
Timesteps,
)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import FP32LayerNorm
from einops import rearrange
from .wan_attention import WanAttention, WanRotaryPosEmbed
__all__ = ["WanTransformer3DModel", "WanTransformerBlock", "WanTimeTextImageEmbedding"]
class WanTimeTextImageEmbedding(nn.Module):
def __init__(
self,
dim,
time_freq_dim,
time_proj_dim,
text_embed_dim,
pos_embed_seq_len,
):
super().__init__()
self.timesteps_proj = Timesteps(
num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0
)
self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
self.act_fn = nn.SiLU()
self.time_proj = nn.Linear(dim, time_proj_dim)
self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
def forward(self, timestep: torch.Tensor, dtype=None):
B, L = timestep.shape
timestep = timestep.reshape(-1)
timestep = self.timesteps_proj(timestep)
time_embedder_dtype = self.time_embedder.linear_1.weight.dtype
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
timestep = timestep.to(time_embedder_dtype)
temb = self.time_embedder(timestep).to(dtype=dtype)
timestep_proj = self.time_proj(self.act_fn(temb))
return temb.reshape(B, L, -1), timestep_proj.reshape(B, L, -1)
class WanTransformerBlock(nn.Module):
def __init__(
self,
dim,
ffn_dim,
num_heads,
cross_attn_norm=False,
eps=1e-6,
attn_mode: str = "torch",
):
super().__init__()
self.attn_mode = attn_mode
# 1. Self-attention
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
self.attn1 = WanAttention(
dim=dim,
heads=num_heads,
dim_head=dim // num_heads,
eps=eps,
cross_attention_dim_head=None,
attn_mode=attn_mode,
)
# 2. Cross-attention
self.attn2 = WanAttention(
dim=dim,
heads=num_heads,
dim_head=dim // num_heads,
eps=eps,
cross_attention_dim_head=dim // num_heads,
attn_mode=attn_mode,
)
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
# 3. Feed-forward
self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
def forward(
self,
hidden_states,
encoder_hidden_states,
temb,
rotary_emb,
update_cache=0,
cache_name="pos",
) -> torch.Tensor:
temb_scale_shift_table = self.scale_shift_table[None] + temb.float()
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = rearrange(
temb_scale_shift_table, "b l n c -> b n l c"
).chunk(6, dim=1)
shift_msa = shift_msa.squeeze(1)
scale_msa = scale_msa.squeeze(1)
gate_msa = gate_msa.squeeze(1)
c_shift_msa = c_shift_msa.squeeze(1)
c_scale_msa = c_scale_msa.squeeze(1)
c_gate_msa = c_gate_msa.squeeze(1)
# 1. Self-attention
norm_hidden_states = (self.norm1(hidden_states.float()) * (1.0 + scale_msa) + shift_msa).type_as(
hidden_states
)
attn_output = self.attn1(
norm_hidden_states,
norm_hidden_states,
norm_hidden_states,
rotary_emb,
update_cache=update_cache,
cache_name=cache_name,
)
hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
# 2. Cross-attention
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states,
encoder_hidden_states,
None,
update_cache=0,
cache_name=cache_name,
)
hidden_states = hidden_states + attn_output
# 3. Feed-forward
norm_hidden_states = (self.norm3(hidden_states.float()) * (1.0 + c_scale_msa) + c_shift_msa).type_as(
hidden_states
)
ff_output = self.ffn(norm_hidden_states)
hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
return hidden_states
class WanTransformer3DModel(ModelMixin, ConfigMixin):
"""Dual-stream (video + action) Wan2.2 DiT backbone with autoregressive KV caching."""
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = [
"patch_embedding_mlp",
"condition_embedder",
"condition_embedder_action",
"norm",
]
_no_split_modules = ["WanTransformerBlock"]
_keep_in_fp32_modules = [
"time_embedder",
"scale_shift_table",
"scale_shift_table_action",
"norm1",
"action_norm1",
"text_norm1",
"norm2",
"action_norm2",
"text_norm2",
"norm3",
"action_norm3",
"text_norm3",
]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
_repeated_blocks = ["WanTransformerBlock"]
@register_to_config
def __init__(
self,
patch_size=(1, 2, 2),
num_attention_heads=24,
attention_head_dim=128,
in_channels=48,
out_channels=48,
action_dim=30,
text_dim=4096,
freq_dim=256,
ffn_dim=14336,
num_layers=30,
cross_attn_norm=True,
eps=1e-06,
rope_max_seq_len=1024,
pos_embed_seq_len=None,
attn_mode="torch",
):
super().__init__()
self.patch_size = patch_size
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
self.patch_embedding_mlp = nn.Linear(
in_channels * patch_size[0] * patch_size[1] * patch_size[2], inner_dim
)
self.action_embedder = nn.Linear(action_dim, inner_dim)
self.condition_embedder = WanTimeTextImageEmbedding(
dim=inner_dim,
time_freq_dim=freq_dim,
time_proj_dim=inner_dim * 6,
text_embed_dim=text_dim,
pos_embed_seq_len=pos_embed_seq_len,
)
self.condition_embedder_action = deepcopy(self.condition_embedder)
self.blocks = nn.ModuleList(
[
WanTransformerBlock(
inner_dim, ffn_dim, num_attention_heads, cross_attn_norm, eps, attn_mode=attn_mode
)
for _ in range(num_layers)
]
)
self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
self.action_proj_out = nn.Linear(inner_dim, action_dim)
self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
# ------------------------------------------------------------------
# KV-cache management for autoregressive streaming inference
# ------------------------------------------------------------------
def clear_cache(self, cache_name):
for block in self.blocks:
block.attn1.clear_cache(cache_name)
def clear_pred_cache(self, cache_name):
for block in self.blocks:
block.attn1.clear_pred_cache(cache_name)
def create_empty_cache(
self,
cache_name,
attn_window,
latent_token_per_chunk,
action_token_per_chunk,
device,
dtype,
batch_size,
):
total_tolen = (attn_window // 2) * latent_token_per_chunk + (
attn_window // 2
) * action_token_per_chunk
for block in self.blocks:
block.attn1.init_kv_cache(
cache_name,
total_tolen,
self.num_attention_heads,
self.attention_head_dim,
device,
dtype,
batch_size,
)
# ------------------------------------------------------------------
# Embedding helpers (shared by train + inference paths)
# ------------------------------------------------------------------
def _input_embed(self, latents, input_type="latent"):
if input_type == "latent":
hidden_states = rearrange(
latents,
"b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
p1=self.patch_size[0],
p2=self.patch_size[1],
p3=self.patch_size[2],
)
hidden_states = self.patch_embedding_mlp(hidden_states)
elif input_type == "action":
hidden_states = rearrange(latents, "b c f h w -> b (f h w) c")
hidden_states = self.action_embedder(hidden_states)
elif input_type == "text":
hidden_states = self.condition_embedder.text_embedder(latents)
else:
raise ValueError(f"Unsupported input type: {input_type}")
return hidden_states
def _time_embed(self, timesteps, H, W, dtype, action_mode=False):
pach_scale_h, pach_scale_w = (1, 1) if action_mode else (self.patch_size[1], self.patch_size[2])
latent_time_steps = torch.repeat_interleave(
timesteps, (H // pach_scale_h) * (W // pach_scale_w), dim=1
)
current_condition_embedder = (
self.condition_embedder_action if action_mode else self.condition_embedder
)
temb, timestep_proj = current_condition_embedder(latent_time_steps, dtype=dtype)
timestep_proj = timestep_proj.unflatten(2, (6, -1)) # B L 6 C
return temb, timestep_proj
# ------------------------------------------------------------------
# Dual-stream training forward (flow matching). Requires attn_mode='flex'.
# ------------------------------------------------------------------
def forward_train(self, input_dict):
from .wan_flex_attention import FlexAttnFunc
input_dict["latent_dict"]["noisy_latents"] = input_dict["latent_dict"]["noisy_latents"].to(
torch.bfloat16
)
input_dict["latent_dict"]["latent"] = input_dict["latent_dict"]["latent"].to(torch.bfloat16)
input_dict["action_dict"]["noisy_latents"] = input_dict["action_dict"]["noisy_latents"].to(
torch.bfloat16
)
input_dict["action_dict"]["latent"] = input_dict["action_dict"]["latent"].to(torch.bfloat16)
latent_dict = input_dict["latent_dict"]
action_dict = input_dict["action_dict"]
batch_size = latent_dict["noisy_latents"].shape[0]
latent_hidden_states = self._input_embed(latent_dict["noisy_latents"], input_type="latent").flatten(
0, 1
)[None]
action_hidden_states = self._input_embed(action_dict["noisy_latents"], input_type="action").flatten(
0, 1
)[None]
text_hidden_states = self._input_embed(latent_dict["text_emb"], input_type="text")
text_hidden_states = text_hidden_states.flatten(0, 1)[None]
condition_latent_hidden_states = self._input_embed(
latent_dict["latent"], input_type="latent"
).flatten(0, 1)[None]
condition_action_hidden_states = self._input_embed(
action_dict["latent"], input_type="action"
).flatten(0, 1)[None]
hidden_states = torch.cat(
[
latent_hidden_states,
condition_latent_hidden_states,
action_hidden_states,
condition_action_hidden_states,
],
dim=1,
)
latent_grid_id = latent_dict["grid_id"].permute(1, 0, 2).flatten(1)[None]
action_grid_id = action_dict["grid_id"].permute(1, 0, 2).flatten(1)[None]
full_grid_id = torch.cat([latent_grid_id] * 2 + [action_grid_id] * 2, dim=2)
rotary_emb = self.rope(full_grid_id)[:, :, None]
latent_time_steps = torch.cat(
[latent_dict["timesteps"].flatten(0, 1), latent_dict["cond_timesteps"].flatten(0, 1)]
)[None]
action_time_steps = torch.cat(
[action_dict["timesteps"].flatten(0, 1), action_dict["cond_timesteps"].flatten(0, 1)]
)[None]
latent_temb, latent_timestep_proj = self._time_embed(
latent_time_steps,
latent_dict["noisy_latents"].shape[-2],
latent_dict["noisy_latents"].shape[-1],
dtype=hidden_states.dtype,
action_mode=False,
)
action_temb, action_timestep_proj = self._time_embed(
action_time_steps,
action_dict["noisy_latents"].shape[-2],
action_dict["noisy_latents"].shape[-1],
dtype=hidden_states.dtype,
action_mode=True,
)
temb = torch.cat([latent_temb, action_temb], dim=1)
timestep_proj = torch.cat([latent_timestep_proj, action_timestep_proj], dim=1)
total_length = hidden_states.shape[1]
padded_length = (128 - total_length % 128) % 128
hidden_states = F.pad(hidden_states, (0, 0, 0, padded_length))
rotary_emb = F.pad(rotary_emb, (0, 0, 0, 0, 0, padded_length))
temb = F.pad(temb, (0, 0, 0, padded_length))
timestep_proj = F.pad(timestep_proj, (0, 0, 0, 0, 0, padded_length))
split_list = [
latent_hidden_states.shape[1],
condition_latent_hidden_states.shape[1],
action_hidden_states.shape[1],
condition_action_hidden_states.shape[1],
padded_length,
]
FlexAttnFunc.init_mask(
latent_dict["noisy_latents"].shape,
action_dict["noisy_latents"].shape,
padded_length,
input_dict["chunk_size"],
window_size=input_dict["window_size"],
patch_size=self.patch_size,
device=hidden_states.device,
)
for block in self.blocks:
hidden_states = block(
hidden_states, text_hidden_states, timestep_proj, rotary_emb, update_cache=False
)
temb_scale_shift_table = self.scale_shift_table[None] + temb[:, :, None, ...]
shift, scale = rearrange(temb_scale_shift_table, "b l n c -> b n l c").chunk(2, dim=1)
shift = shift.to(hidden_states.device).squeeze(1)
scale = scale.to(hidden_states.device).squeeze(1)
hidden_states = (self.norm_out(hidden_states.float()) * (1.0 + scale) + shift).type_as(hidden_states)
latent_hidden_states, _, action_hidden_states, _, _ = torch.split(hidden_states, split_list, dim=1)
latent_hidden_states = self.proj_out(latent_hidden_states)
latent_hidden_states = rearrange(
latent_hidden_states, "1 (b l) (n c) -> b (l n) c", n=math.prod(self.patch_size), b=batch_size
)
action_hidden_states = self.action_proj_out(action_hidden_states)
action_hidden_states = rearrange(action_hidden_states, "1 (b l) c -> b l c", b=batch_size)
return latent_hidden_states, action_hidden_states
# ------------------------------------------------------------------
# Single-stream inference forward (one denoising step for one stream)
# ------------------------------------------------------------------
def forward(
self,
input_dict,
update_cache=0,
cache_name="pos",
action_mode=False,
train_mode=False,
):
if train_mode:
return self.forward_train(input_dict)
if action_mode: # action input emb
latent_hidden_states = rearrange(input_dict["noisy_latents"], "b c f h w -> b (f h w) c")
latent_hidden_states = self.action_embedder(latent_hidden_states) # B L1 C
else: # latent input emb
latent_hidden_states = rearrange(
input_dict["noisy_latents"],
"b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
p1=self.patch_size[0],
p2=self.patch_size[1],
p3=self.patch_size[2],
)
latent_hidden_states = self.patch_embedding_mlp(latent_hidden_states)
text_hidden_states = self.condition_embedder.text_embedder(input_dict["text_emb"]) # B L2 C
latent_grid_id = input_dict["grid_id"]
rotary_emb = self.rope(latent_grid_id)[:, :, None] # 1 L 1 C
pach_scale_h, pach_scale_w = (1, 1) if action_mode else (self.patch_size[1], self.patch_size[2])
latent_time_steps = torch.repeat_interleave(
input_dict["timesteps"],
(input_dict["noisy_latents"].shape[-2] // pach_scale_h)
* (input_dict["noisy_latents"].shape[-1] // pach_scale_w),
dim=1,
) # L
current_condition_embedder = (
self.condition_embedder_action if action_mode else self.condition_embedder
)
temb, timestep_proj = current_condition_embedder(latent_time_steps, dtype=latent_hidden_states.dtype)
timestep_proj = timestep_proj.unflatten(2, (6, -1)) # B L 6 C
for block in self.blocks:
latent_hidden_states = block(
latent_hidden_states,
text_hidden_states,
timestep_proj,
rotary_emb,
update_cache=update_cache,
cache_name=cache_name,
)
temb_scale_shift_table = self.scale_shift_table[None] + temb[:, :, None, ...]
shift, scale = rearrange(temb_scale_shift_table, "b l n c -> b n l c").chunk(2, dim=1)
shift = shift.to(latent_hidden_states.device).squeeze(1)
scale = scale.to(latent_hidden_states.device).squeeze(1)
latent_hidden_states = (self.norm_out(latent_hidden_states.float()) * (1.0 + scale) + shift).type_as(
latent_hidden_states
)
if action_mode:
latent_hidden_states = self.action_proj_out(latent_hidden_states)
else:
latent_hidden_states = self.proj_out(latent_hidden_states)
latent_hidden_states = rearrange(
latent_hidden_states, "b l (n c) -> b (l n) c", n=math.prod(self.patch_size)
)
return latent_hidden_states
@@ -0,0 +1,56 @@
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Grid-id / patch utilities for the LingBot-VA autoregressive inference loop.
Vendored verbatim from the upstream LingBot-VA repository
(https://github.com/Robbyant/lingbot-va, ``wan_va/utils/utils.py``).
"""
import torch
__all__ = ["get_mesh_id", "data_seq_to_patch"]
def data_seq_to_patch(patch_size, data_seq, latent_num_frames, latent_height, latent_width, batch_size=1):
"""Reshape a flattened patch sequence back into a ``(B, C, F, H, W)`` latent grid."""
p_t, p_h, p_w = patch_size
post_patch_num_frames = latent_num_frames // p_t
post_patch_height = latent_height // p_h
post_patch_width = latent_width // p_w
data_patch = data_seq.reshape(
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
)
data_patch = data_patch.permute(0, 7, 1, 4, 2, 5, 3, 6)
data_patch = data_patch.flatten(6, 7).flatten(4, 5).flatten(2, 3)
return data_patch
def get_mesh_id(f, h, w, t, f_w=1, f_shift=0, action=False):
"""Build the (frame, height, width, stream) grid ids used to index the rotary embedding."""
f_idx = torch.arange(f_shift, f + f_shift) * f_w
h_idx = torch.arange(h)
w_idx = torch.arange(w)
ff, hh, ww = torch.meshgrid(f_idx, h_idx, w_idx, indexing="ij")
if action:
ff_offset = (torch.ones([h]).cumsum(0) / (h + 1)).view(1, -1, 1)
ff = ff + ff_offset
hh = torch.ones_like(hh) * -1
ww = torch.ones_like(ww) * -1
grid_id = torch.cat([ff.unsqueeze(0), hh.unsqueeze(0), ww.unsqueeze(0)], dim=0).flatten(1)
grid_id = torch.cat([grid_id, torch.full_like(grid_id[:1], t)], dim=0)
return grid_id
+120
View File
@@ -0,0 +1,120 @@
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Thin helpers around the stock diffusers ``AutoencoderKLWan`` (Wan2.2, ``z_dim=48``).
The VAE class itself is NOT vendored it lives in ``diffusers>=0.36``. This module
provides:
* loaders for the VAE / text encoder / tokenizer / transformer sub-checkpoints,
* the streaming-encoder wrapper used for autoregressive frame-by-frame VAE encoding
(it caches the causal-conv state across chunks),
* latent (de)normalization helpers using the VAE's ``latents_mean`` / ``latents_std``.
Vendored and adapted from ``wan_va/modules/utils.py`` upstream.
"""
import torch
__all__ = [
"WanVAEStreamingWrapper",
"load_vae",
"load_text_encoder",
"load_tokenizer",
"normalize_latents",
"denormalize_latents",
"patchify",
]
def load_vae(vae_path, torch_dtype, torch_device):
from diffusers import AutoencoderKLWan
vae = AutoencoderKLWan.from_pretrained(vae_path, torch_dtype=torch_dtype)
return vae.to(torch_device)
def load_text_encoder(text_encoder_path, torch_dtype, torch_device):
from transformers import UMT5EncoderModel
text_encoder = UMT5EncoderModel.from_pretrained(text_encoder_path, torch_dtype=torch_dtype)
return text_encoder.to(torch_device)
def load_tokenizer(tokenizer_path):
from transformers import T5TokenizerFast
return T5TokenizerFast.from_pretrained(tokenizer_path)
def patchify(x, patch_size):
if patch_size is None or patch_size == 1:
return x
batch_size, channels, frames, height, width = x.shape
x = x.view(
batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size
)
x = x.permute(0, 1, 6, 4, 2, 3, 5).contiguous()
x = x.view(
batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size
)
return x
def normalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
) -> torch.Tensor:
"""Apply ``(x - mean) * std`` channel-wise (note: upstream passes ``1/std`` as ``latents_std``)."""
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device)
latents = ((latents.float() - latents_mean) * latents_std).to(latents)
return latents
def denormalize_latents(latents: torch.Tensor, latents_mean, latents_std, z_dim) -> torch.Tensor:
"""Inverse of the normalization applied at encode time, for VAE decoding of predicted latents."""
mean = torch.tensor(latents_mean).view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
inv_std = 1.0 / torch.tensor(latents_std).view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
return latents / inv_std + mean
class WanVAEStreamingWrapper:
"""Wraps an ``AutoencoderKLWan`` encoder to support causal streaming encoding across chunks."""
def __init__(self, vae_model):
self.vae = vae_model
self.encoder = vae_model.encoder
self.quant_conv = vae_model.quant_conv
if hasattr(self.vae, "_cached_conv_counts"):
self.enc_conv_num = self.vae._cached_conv_counts["encoder"]
else:
count = 0
for m in self.encoder.modules():
if m.__class__.__name__ == "WanCausalConv3d":
count += 1
self.enc_conv_num = count
self.clear_cache()
def clear_cache(self):
self.feat_cache = [None] * self.enc_conv_num
def encode_chunk(self, x_chunk):
if hasattr(self.vae.config, "patch_size") and self.vae.config.patch_size is not None:
x_chunk = patchify(x_chunk, self.vae.config.patch_size)
feat_idx = [0]
out = self.encoder(x_chunk, feat_cache=self.feat_cache, feat_idx=feat_idx)
enc = self.quant_conv(out)
return enc
+53
View File
@@ -105,6 +105,7 @@ def rollout(
seeds: list[int] | None = None,
return_observations: bool = False,
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
predicted_frames_callback: Callable[[PreTrainedPolicy], None] | None = None,
) -> dict:
"""Run a batched policy rollout once through a batch of environments.
@@ -134,6 +135,9 @@ def rollout(
are returned optionally because they typically take more memory to cache. Defaults to False.
render_callback: Optional rendering callback to be used after the environments are reset, and after
every step.
predicted_frames_callback: Optional callback invoked after every ``select_action`` with the policy
itself. World-model policies (e.g. LingBot-VA) stash their decoded predicted video frames on
``policy.last_predicted_frames``; this lets the caller collect them to save predicted-video MP4s.
Returns:
The dictionary described above.
"""
@@ -184,6 +188,8 @@ def rollout(
observation = preprocessor(observation)
with torch.inference_mode():
action = policy.select_action(observation)
if predicted_frames_callback is not None:
predicted_frames_callback(policy)
action = postprocessor(action)
action_transition = {ACTION: action}
@@ -273,6 +279,7 @@ def eval_policy(
videos_dir: Path | None = None,
return_episode_data: bool = False,
start_seed: int | None = None,
save_predicted_video: bool = False,
) -> dict:
"""
Args:
@@ -291,6 +298,11 @@ def eval_policy(
if max_episodes_rendered > 0 and not videos_dir:
raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.")
# World-model policies (e.g. LingBot-VA) opt into predicted-video saving via their config.
save_predicted_video = save_predicted_video or bool(
getattr(getattr(policy, "config", None), "save_predicted_video", False)
)
if not isinstance(policy, PreTrainedPolicy):
exc = ValueError(
f"Policy of type 'PreTrainedPolicy' is expected, but type '{type(policy)}' was provided."
@@ -334,6 +346,21 @@ def eval_policy(
if max_episodes_rendered > 0:
video_paths: list[str] = []
if save_predicted_video:
if not videos_dir:
raise ValueError("If save_predicted_video is True, videos_dir must be provided.")
predicted_video_paths: list[str] = []
n_predicted_rendered = 0
# Collects the policy's decoded predicted-video frames across a rollout (world-model policies only).
def collect_predicted_frames(policy: PreTrainedPolicy):
frames = getattr(policy, "last_predicted_frames", None)
if frames is not None:
pred_frames.append(
np.asarray(frames.detach().to("cpu")) if hasattr(frames, "detach") else np.asarray(frames)
)
policy.last_predicted_frames = None
if return_episode_data:
episode_data: dict | None = None
@@ -345,6 +372,9 @@ def eval_policy(
if max_episodes_rendered > 0:
ep_frames: list[np.ndarray] = []
if save_predicted_video:
pred_frames: list[np.ndarray] = []
if start_seed is None:
seeds = None
else:
@@ -361,6 +391,7 @@ def eval_policy(
seeds=list(seeds) if seeds else None,
return_observations=return_episode_data,
render_callback=render_frame if max_episodes_rendered > 0 else None,
predicted_frames_callback=collect_predicted_frames if save_predicted_video else None,
)
# Figure out where in each rollout sequence the first done condition was encountered (results after
@@ -426,6 +457,25 @@ def eval_policy(
threads.append(thread)
n_episodes_rendered += 1
# Maybe save the policy's predicted (imagined) video for this batch's rollout.
if save_predicted_video and len(pred_frames) > 0:
# pred_frames is a list of [F, H, W, C] uint8 stacks emitted on chunk refills; concat over time.
predicted_video = np.concatenate(pred_frames, axis=0)
videos_dir.mkdir(parents=True, exist_ok=True)
predicted_video_path = videos_dir / f"pred_episode_{n_predicted_rendered}.mp4"
predicted_video_paths.append(str(predicted_video_path))
thread = threading.Thread(
target=write_video,
args=(
str(predicted_video_path),
predicted_video,
env.unwrapped.metadata["render_fps"],
),
)
thread.start()
threads.append(thread)
n_predicted_rendered += 1
progbar.set_postfix(
{"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"}
)
@@ -469,6 +519,9 @@ def eval_policy(
if max_episodes_rendered > 0:
info["video_paths"] = video_paths
if save_predicted_video:
info["predicted_video_paths"] = predicted_video_paths
return info
@@ -0,0 +1,83 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import pytest
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES
def make_config(**overrides) -> LingBotVAConfig:
kwargs = {"device": "cpu"}
kwargs.update(overrides)
return LingBotVAConfig(**kwargs)
def test_registered_in_choice_registry() -> None:
assert "lingbot_va" in PreTrainedConfig.get_known_choices()
assert PreTrainedConfig.get_choice_class("lingbot_va") is LingBotVAConfig
def test_type_property() -> None:
assert make_config().type == "lingbot_va"
def test_chunk_size_and_action_steps() -> None:
cfg = make_config(frame_chunk_size=4, action_per_frame=4)
assert cfg.chunk_size == 16
assert cfg.n_action_steps == 16
assert cfg.action_delta_indices == list(range(16))
assert cfg.observation_delta_indices is None
assert cfg.reward_delta_indices is None
def test_optimizer_and_scheduler_presets() -> None:
cfg = make_config()
opt = cfg.get_optimizer_preset()
assert opt.lr == cfg.optimizer_lr
sched = cfg.get_scheduler_preset()
assert sched.num_warmup_steps == cfg.scheduler_warmup_steps
def test_validate_features_sets_action_feature() -> None:
cfg = make_config()
cfg.input_features = {f"{OBS_IMAGES}.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128))}
cfg.output_features = {}
cfg.validate_features()
assert ACTION in cfg.output_features
assert cfg.output_features[ACTION].shape == (len(cfg.used_action_channel_ids),)
def test_validate_features_no_visual_raises() -> None:
cfg = make_config()
cfg.input_features = {}
cfg.output_features = {}
with pytest.raises(ValueError, match="at least one visual input feature"):
cfg.validate_features()
def test_invalid_attn_mode_raises() -> None:
with pytest.raises(ValueError, match="attn_mode"):
make_config(attn_mode="banana")
def test_quantile_length_mismatch_raises() -> None:
with pytest.raises(ValueError, match="action_q01"):
make_config(used_action_channel_ids=[0, 1, 2], action_q01=[0.0, 0.0], action_q99=[1.0, 1.0, 1.0])
+52
View File
@@ -0,0 +1,52 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import pytest
from lerobot.policies.factory import make_policy_config
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
def test_make_policy_config_returns_lingbot_va() -> None:
cfg = make_policy_config("lingbot_va", device="cpu")
assert isinstance(cfg, LingBotVAConfig)
def test_get_policy_class_resolves_lazily() -> None:
# Importing the policy class pulls in diffusers (Wan2.2 stack); skip if unavailable.
pytest.importorskip("diffusers")
pytest.importorskip("transformers")
from lerobot.policies.factory import get_policy_class
cls = get_policy_class("lingbot_va")
assert cls.name == "lingbot_va"
assert cls.config_class is LingBotVAConfig
def test_convert_build_config_libero() -> None:
pytest.importorskip("diffusers")
from lerobot.policies.lingbot_va.convert_lingbot_va_checkpoints import build_config
cfg = build_config("libero", wan_pretrained_path="dummy/path", dtype="float32")
assert cfg.height == 128 and cfg.width == 128
assert cfg.used_action_channel_ids == list(range(7))
# validate_features (called inside build_config) must have populated the action feature.
from lerobot.utils.constants import ACTION
assert cfg.output_features[ACTION].shape == (7,)
assert len(cfg.obs_cam_keys) == 2
+69
View File
@@ -0,0 +1,69 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pure-torch unit tests for the vendored LingBot-VA helper modules (no diffusers needed)."""
from __future__ import annotations
import torch
from lerobot.policies.lingbot_va.schedulers import FlowMatchScheduler
from lerobot.policies.lingbot_va.wan_utils import data_seq_to_patch, get_mesh_id
def test_flow_match_scheduler_timesteps_monotone_decreasing() -> None:
sch = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True)
sch.set_timesteps(20)
assert sch.timesteps.shape == (20,)
diffs = sch.timesteps[1:] - sch.timesteps[:-1]
assert torch.all(diffs <= 0) # decreasing
def test_flow_match_scheduler_step_preserves_shape() -> None:
sch = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True)
sch.set_timesteps(20)
sample = torch.zeros(1, 48, 4, 8, 16)
out = sch.step(torch.ones_like(sample), sch.timesteps[0], sample)
assert out.shape == sample.shape
def test_flow_match_scheduler_add_noise() -> None:
sch = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True)
sch.set_timesteps(20)
sample = torch.randn(1, 48, 4, 8, 16)
noise = torch.randn_like(sample)
noisy = sch.add_noise(sample, noise, sch.timesteps[:4], t_dim=2)
assert noisy.shape == sample.shape
def test_get_mesh_id_latent_shape() -> None:
grid = get_mesh_id(4, 8, 16, 0, 1, 0)
assert grid.shape == (4, 4 * 8 * 16) # (f, h, w, stream) x tokens
def test_get_mesh_id_action_shape() -> None:
grid = get_mesh_id(4, 4, 1, 1, 1, 0, action=True)
assert grid.shape == (4, 4 * 4 * 1)
# Action rows for h/w are sentinel -1.
assert torch.all(grid[1] < 0)
assert torch.all(grid[2] < 0)
def test_data_seq_to_patch_roundtrip_shape() -> None:
b, f, h, w, c = 1, 4, 8, 16, 48
seq = torch.arange(b * f * h * w * c, dtype=torch.float32).reshape(b, f * h * w, c)
out = data_seq_to_patch((1, 2, 2), seq, f, h, w, batch_size=b)
assert out.shape == (b, c, f, h, w)
@@ -0,0 +1,81 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
from lerobot.policies.lingbot_va.processor_lingbot_va import (
LingBotVAActionUnnormalizeStep,
make_lingbot_va_pre_post_processors,
)
from lerobot.utils.constants import (
OBS_IMAGES,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
def _make_config() -> LingBotVAConfig:
cfg = LingBotVAConfig(device="cpu")
cfg.input_features = {f"{OBS_IMAGES}.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128))}
cfg.output_features = {}
cfg.validate_features()
return cfg
def test_action_unnormalize_inverts_quantile_norm() -> None:
q01 = [-1.0, -0.5, 0.0]
q99 = [1.0, 0.5, 2.0]
step = LingBotVAActionUnnormalizeStep(action_q01=q01, action_q99=q99)
# Forward (the policy-side) quantile normalization: (x - q01) / (q99 - q01 + eps) * 2 - 1.
q01_t = torch.tensor(q01)
q99_t = torch.tensor(q99)
raw = torch.tensor([[0.3, 0.1, 1.0]])
normed = (raw - q01_t) / (q99_t - q01_t + 1e-6) * 2.0 - 1.0
recovered = step.action(normed)
assert torch.allclose(recovered, raw, atol=1e-4)
def test_action_unnormalize_config_roundtrip() -> None:
step = LingBotVAActionUnnormalizeStep(action_q01=[0.0, 1.0], action_q99=[2.0, 3.0])
cfg = step.get_config()
assert cfg == {"action_q01": [0.0, 1.0], "action_q99": [2.0, 3.0]}
rebuilt = LingBotVAActionUnnormalizeStep(**cfg)
assert rebuilt.action_q01 == step.action_q01
assert rebuilt.action_q99 == step.action_q99
def test_make_pre_post_processors_names_and_steps() -> None:
cfg = _make_config()
pre, post = make_lingbot_va_pre_post_processors(cfg, dataset_stats=None)
assert pre.name == POLICY_PREPROCESSOR_DEFAULT_NAME
assert post.name == POLICY_POSTPROCESSOR_DEFAULT_NAME
# The postprocessor must contain the dedicated quantile unnormalize step.
assert any(isinstance(s, LingBotVAActionUnnormalizeStep) for s in post.steps)
def test_postprocessor_applies_unnormalization() -> None:
cfg = _make_config()
_, post = make_lingbot_va_pre_post_processors(cfg, dataset_stats=None)
# A normalized action of all -1 should map back to q01.
normed = torch.full((1, len(cfg.used_action_channel_ids)), -1.0)
out = post(normed)
assert torch.allclose(out, torch.tensor(cfg.action_q01).unsqueeze(0), atol=1e-4)