mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 16:27:04 +00:00
feat(policies): add LingBot-VA autoregressive video-action world model
Port the LingBot-VA policy (Wan2.2 dual-stream video+action world model) into LeRobot, following the EO-1 / VLA-JEPA conventions. Covers inference, checkpoint conversion, and predicted-video saving (training is deferred to a follow-up PR).

- Vendored Wan transformer/attention/flex/VAE/scheduler modules (key names preserved for near-identity conversion); torch SDPA default, flashattn/flex lazy-guarded.
- LingBotVAConfig (registered "lingbot_va") + processor with fixed-quantile action unnormalization; full dual-stream sampling loop with CFG, two flow-matching schedulers and KV cache, mapped onto select_action with observed-keyframe feedback.
- convert_lingbot_va_checkpoints.py (libero/robotwin variants): bundles the ~5B transformer, lazy-pulls the frozen VAE+UMT5 from the source repo.
- Predicted-video plumbing in lerobot_eval (predicted_frames_callback; opt-in via --policy.save_predicted_video) and ConstantWithWarmupSchedulerConfig.
- pyproject: widen diffusers-dep to <0.37, add lingbot_va + imageio-dep extras, add lingbot_va and (missing) eo1 to `all`.
- Factory + policies/__init__ wiring, docs page + toctree, and tests.

Note: the LIBERO success-rate correctness gate must be validated on a CUDA GPU with the converted checkpoint.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@@ -67,6 +67,8 @@
      title: VLA-JEPA
    - local: eo1
      title: EO-1
    - local: lingbot_va
      title: LingBot-VA
    - local: groot
      title: NVIDIA GR00T N1.5
    - local: xvla
@@ -0,0 +1,120 @@
# LingBot-VA

LingBot-VA is an **autoregressive video-action world-model policy** built on the **Wan2.2**
video-diffusion stack. It interleaves, in one autoregressive sequence, the prediction of
future **video latents** and **robot actions** ("VA" = Video-Action). The LeRobot
integration wires LingBot-VA into the standard configuration, evaluation and processor
interfaces; training support is deferred to a follow-up.

## Model Overview

LingBot-VA is a **dual-stream "mixture-of-transformers"**: a video/latent stream
(`patch_embedding_mlp → blocks → proj_out`) and an action stream
(`action_embedder → blocks → action_proj_out`) share the same 30 transformer blocks and
text conditioning. Actions are produced by the dedicated `action_proj_out` head; they are
**not** decoded from predicted pixels, though video and action are co-trained.

| Component | Class | Role |
|---|---|---|
| DiT backbone (trainable) | `WanTransformer3DModel` | ~5B-param dual-stream transformer (the only weights stored in the LeRobot checkpoint). |
| VAE (frozen) | `AutoencoderKLWan` | Wan2.2 VAE, `z_dim=48`. Lazy-pulled from the source repo. |
| Text encoder (frozen) | `UMT5EncoderModel` | UMT5-XXL, `d_model=4096`. Lazy-pulled from the source repo. |

At inference the policy runs an autoregressive loop per chunk: it denoises the video-latent
stream (CFG, ~20 steps) and the action stream (~50 steps) with two independent
flow-matching schedulers, maintaining a KV cache across chunks. Real observed keyframes are
fed back into the KV cache as the chunk is executed (closed-loop world modeling).
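
The two ingredients named above, classifier-free guidance and a flow-matching Euler update, can be pictured with a toy sampler. This is a minimal sketch and not the vendored sampler: the linear sigma schedule, the `toy_velocity` stand-in for the transformer, and every function name here are assumptions made for illustration. Only the step counts (20 and 50) and guidance scales (5 and 1) come from this page.

```python
def cfg_combine(v_uncond, v_cond, scale):
    """Classifier-free guidance: extrapolate from the unconditional velocity."""
    return [u + scale * (c - u) for u, c in zip(v_uncond, v_cond)]


def euler_flow_sample(x, velocity_fn, num_steps, guidance_scale):
    """Integrate dx/dsigma = v from sigma=1 (noise) down to sigma=0 (data)."""
    # Assumed linear schedule; the real schedulers also apply an SNR shift.
    sigmas = [1.0 - i / num_steps for i in range(num_steps + 1)]
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        v_cond = velocity_fn(x, sigma, conditioned=True)
        v_uncond = velocity_fn(x, sigma, conditioned=False)
        v = cfg_combine(v_uncond, v_cond, guidance_scale)
        x = [xi + (sigma_next - sigma) * vi for xi, vi in zip(x, v)]
    return x


# Toy problem: on the straight path x_sigma = (1 - sigma) * data + sigma * noise
# the exact velocity is (noise - data) everywhere.
data = [0.25, -0.5, 1.0]
noise = [1.5, 0.3, -0.7]


def toy_velocity(x, sigma, conditioned):
    # Hypothetical stand-in for the transformer; it ignores its inputs on purpose.
    return [n - d for n, d in zip(noise, data)]


video_like = euler_flow_sample(list(noise), toy_velocity, num_steps=20, guidance_scale=5.0)
action_like = euler_flow_sample(list(noise), toy_velocity, num_steps=50, guidance_scale=1.0)
```

Because the toy velocity is exact, both runs recover `data` up to floating-point error regardless of step count. With a learned model the step count and guidance scale trade compute for sample quality, which is why the two streams are configured separately.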

### What the LeRobot Integration Covers

- Standard `policy.type=lingbot_va` configuration through LeRobot.
- Checkpoint conversion from the released HuggingFace checkpoints.
- Autoregressive dual-stream inference behind the standard `select_action` interface
  (single-environment eval, `--eval.batch_size=1`).
- Opt-in saving of the policy's **predicted (imagined) videos** during eval / training.
- Evaluation with `lerobot-eval` on the LIBERO benchmark.

Training (the flow-matching dual-stream loss + latent dataset) is part of a follow-up
training port and is not yet wired into `lerobot-train`.

## Installation

1. Install LeRobot by following the [Installation Guide](./installation).
2. Install the LingBot-VA extra (brings in `diffusers>=0.36` for the Wan2.2 stack):

```bash
pip install -e ".[lingbot_va]"
# For LIBERO evaluation (Linux only):
pip install -e ".[lingbot_va,libero]"
```

## Checkpoint Conversion

The released checkpoints are diffusers-style directories
(`robbyant/lingbot-va-base`, `robbyant/lingbot-va-posttrain-robotwin`,
`robbyant/lingbot-va-posttrain-libero-long`). Convert one to LeRobot format with:

```bash
python -m lerobot.policies.lingbot_va.convert_lingbot_va_checkpoints \
    --checkpoint robbyant/lingbot-va-posttrain-libero-long \
    --variant libero \
    --output_dir outputs/lingbot_va_libero_long
```

**Packaging:** only the trainable ~5B transformer is stored in the LeRobot
`model.safetensors`. The frozen VAE + UMT5 + tokenizer (~20 GB) are **lazily pulled** from
`config.wan_pretrained_path` at load time (defaults to the source repo). Pass
`--bundle-frozen` to copy those sub-folders next to the converted checkpoint instead.
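
Since only the transformer weights are stored, the conversion script in this commit reduces the weight mapping to a near-identity rename: each source key gains a `transformer.` prefix. The sketch below shows that rename on plain dictionaries, plus a small key comparison of the sort a non-strict load reports. The key names and the `diff_keys` helper are hypothetical and stand in for real tensors and for `load_state_dict(strict=False)`.

```python
def add_prefix(state_dict, prefix="transformer."):
    """Return a copy of the state dict with every key prefixed."""
    return {f"{prefix}{key}": value for key, value in state_dict.items()}


def diff_keys(expected_keys, provided_keys):
    """Report keys the target expects but did not get, and keys it got but does not know."""
    expected, provided = set(expected_keys), set(provided_keys)
    return sorted(expected - provided), sorted(provided - expected)


# Hypothetical key names; real checkpoints hold tensors, not integers.
source = {"blocks.0.attn.weight": 1, "action_proj_out.bias": 2}
remapped = add_prefix(source)

policy_keys = ["transformer.blocks.0.attn.weight", "transformer.action_proj_out.bias"]
missing, unexpected = diff_keys(policy_keys, remapped)
```

An empty `missing` and `unexpected` pair is a useful first check after conversion, before comparing activations against upstream.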

Run conversion on a Linux machine with a CUDA GPU and enough RAM/VRAM to materialize the
transformer.

## Evaluation (LIBERO)

```bash
lerobot-eval \
    --policy.path=outputs/lingbot_va_libero_long \
    --env.type=libero --env.task=libero_10 \
    --eval.n_episodes=50 --eval.batch_size=1 \
    --output_dir=outputs/eval/lingbot_va_libero
```

LingBot-VA's streaming inference (KV cache + observed-keyframe feedback) is implemented for
single-environment eval; use `--eval.batch_size=1`.
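
One way to picture how a chunked policy sits behind a per-step `select_action` call is an action queue that is refilled only when it runs empty. The sketch below is an illustration under assumptions, not the policy's code: `ChunkedActionQueue` and `fake_predict_chunk` are hypothetical names, and the KV cache and keyframe feedback are deliberately left out. It does show why the state is per environment, which is the reason for `--eval.batch_size=1`.

```python
from collections import deque


class ChunkedActionQueue:
    """Serve one action per call; request a new chunk only when the queue is empty."""

    def __init__(self, predict_chunk):
        self._predict_chunk = predict_chunk
        self._queue = deque()
        self.num_chunk_calls = 0

    def reset(self):
        # Called at episode boundaries so stale actions are never replayed.
        self._queue.clear()

    def select_action(self, observation):
        if not self._queue:
            self._queue.extend(self._predict_chunk(observation))
            self.num_chunk_calls += 1
        return self._queue.popleft()


def fake_predict_chunk(observation):
    # Hypothetical model call: 16 actions per chunk (frame_chunk_size 4 x action_per_frame 4).
    return [(observation, i) for i in range(16)]


policy = ChunkedActionQueue(fake_predict_chunk)
actions = [policy.select_action(step) for step in range(20)]
```

Over 20 environment steps the model is called twice: once at step 0 and once at step 16, when the first chunk has been fully executed.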

### Saving predicted (imagined) videos

Set `--policy.save_predicted_video=true` to additionally VAE-decode the predicted video
latents and write `pred_episode_*.mp4` next to the env-rendered `eval_episode_*.mp4` videos.
The same flag works for the periodic eval during `lerobot-train`.

## Inference Hyperparameters (LIBERO)

| Key | Value |
|---|---|
| height × width | 128 × 128 |
| cameras | `observation.images.image` (agentview), `observation.images.image2` (eye-in-hand) |
| action channels used | 0–6 (7-DoF arm + gripper) |
| action_per_frame / frame_chunk_size | 4 / 4 |
| attn_window | 30 |
| video / action denoising steps | 20 / 50 |
| guidance_scale / action_guidance_scale | 5 / 1 |
| snr_shift / action_snr_shift | 5.0 / 0.05 |

These are the defaults of `LingBotVAConfig`; override any of them via `--policy.<name>=...`.
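
Two of these values determine how many actions one autoregressive chunk yields: `chunk_size = frame_chunk_size * action_per_frame`, as defined by the `chunk_size` property of `LingBotVAConfig` in this commit. The sketch below reproduces that arithmetic with a stand-in dataclass. `ToyChunkConfig` is a hypothetical name, not the real config class, and the RoboTwin numbers are taken from the conversion script's preset.

```python
from dataclasses import dataclass


@dataclass
class ToyChunkConfig:
    """Stand-in holding only the two fields that set the chunk length."""

    action_per_frame: int = 4
    frame_chunk_size: int = 4

    @property
    def chunk_size(self) -> int:
        # Number of single-step actions produced per autoregressive chunk.
        return self.frame_chunk_size * self.action_per_frame


libero = ToyChunkConfig()  # LIBERO defaults from the table
robotwin = ToyChunkConfig(action_per_frame=16, frame_chunk_size=2)  # RoboTwin preset

print(libero.chunk_size, robotwin.chunk_size)
```

With the LIBERO defaults a chunk holds 16 actions; the RoboTwin preset yields 32.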

## Notes & Limitations

- **Correctness gate:** matching the upstream LIBERO success rate requires validating the
  converted checkpoint on a GPU and tensor-diffing intermediate activations against the
  upstream implementation. The most sensitive parts are the action quantile normalization,
  the camera ordering, the `action_per_frame`/`frame_chunk_size` alignment, and `attn_mode`.
- **Attention backend:** inference uses the `torch` SDPA backend (always available). The
  `flashattn` and `flex` backends are optional; `flex` is only needed for training.
- **Model size:** the DiT is ~5B params and the frozen VAE+UMT5 add ~20 GB; inference needs
  roughly 18–24 GB of VRAM.
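
Of the sensitive parts listed above, the action quantile normalization is the easiest to check in isolation. The sketch below assumes the common convention that normalized actions in `[-1, 1]` map linearly onto `[q01, q99]` per channel. That formula is an assumption: the processor code is not shown here, although the convention agrees with the conversion script's comment that `[-1, 1]` quantiles mean "no rescale". The quantile values are the LIBERO defaults from `configuration_lingbot_va.py`; the function names are hypothetical.

```python
LIBERO_ACTION_Q01 = [
    -0.6589285731315613, -0.84375, -0.9375, -0.12107142806053162,
    -0.15964286029338837, -0.26571428775787354, -1.0,
]
LIBERO_ACTION_Q99 = [
    0.8999999761581421, 0.8544642925262451, 0.9375, 0.17142857611179352,
    0.1842857152223587, 0.34392857551574707, 1.0,
]


def unnormalize(normalized, q01, q99):
    """Map each channel from [-1, 1] to [q01, q99] (assumed convention)."""
    return [(x + 1.0) / 2.0 * (hi - lo) + lo for x, lo, hi in zip(normalized, q01, q99)]


def normalize(action, q01, q99):
    """Inverse of unnormalize: map each channel from [q01, q99] to [-1, 1]."""
    return [2.0 * (a - lo) / (hi - lo) - 1.0 for a, lo, hi in zip(action, q01, q99)]


lowest = unnormalize([-1.0] * 7, LIBERO_ACTION_Q01, LIBERO_ACTION_Q99)
highest = unnormalize([1.0] * 7, LIBERO_ACTION_Q01, LIBERO_ACTION_Q99)
neutral = unnormalize([0.3, -0.8], [-1.0, -1.0], [1.0, 1.0])
```

A channel-order or quantile mismatch changes every executed action, so comparing a few unnormalized values against upstream is a cheap early check before a full benchmark run.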

## License

LingBot-VA is released under Apache-2.0. See the
[upstream repository](https://github.com/Robbyant/lingbot-va).
+11
-1
@@ -146,7 +146,8 @@ grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
can-dep = ["python-can>=4.2.0,<5.0.0"]
peft-dep = ["peft>=0.18.0,<1.0.0"]
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
diffusers-dep = ["diffusers>=0.27.2,<0.36.0"]
diffusers-dep = ["diffusers>=0.27.2,<0.37.0"]
imageio-dep = ["imageio[ffmpeg]>=2.34.0,<3.0.0"]
qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"]
matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0", "contourpy>=1.3.0,<2.0.0"] # NOTE: Explicitly listing contourpy helps the resolver converge faster.
pyserial-dep = ["pyserial>=3.5,<4.0"]
@@ -218,6 +219,10 @@ xvla = ["lerobot[transformers-dep]"]
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
# LingBot-VA needs the Wan2.2 stack (AutoencoderKLWan z_dim=48 + WanTransformer3DModel config schema),
# which only exists in diffusers>=0.36. Pin the floor explicitly so a standalone `lerobot[lingbot_va]`
# install can't resolve to a pre-Wan2.2 diffusers via the looser diffusers-dep floor.
lingbot_va = ["lerobot[transformers-dep]", "diffusers>=0.36.0,<0.37.0", "lerobot[imageio-dep]"]

# Features
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
@@ -284,6 +289,8 @@ all = [
    "lerobot[xvla]",
    "lerobot[hilserl]",
    "lerobot[vla_jepa]",
    "lerobot[eo1]",
    "lerobot[lingbot_va]",
    "lerobot[async]",
    "lerobot[dev]",
    "lerobot[test]",
@@ -375,6 +382,9 @@ ignore = [
# E402: conditional-import guards (TYPE_CHECKING / is_package_available) must precede the imports they protect
"src/lerobot/scripts/convert_dataset_v21_to_v30.py" = ["E402"]
"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Suppress these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original
# Vendored Wan2.2 / LingBot-VA model code uses tensor-dimension names (B, F, H, W) and `F` for
# torch.nn.functional; keep the upstream naming to make diffing against upstream tractable.
"src/lerobot/policies/lingbot_va/**" = ["N803", "N806", "N812", "SIM102"]

[tool.ruff.lint.isort]
combine-as-imports = true
@@ -83,6 +83,28 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
        return LambdaLR(optimizer, lr_lambda, -1)


@LRSchedulerConfig.register_subclass("constant_with_warmup")
@dataclass
class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig):
    """Linear warmup followed by a constant learning rate.

    Mirrors the ``warmup_constant_lambda`` used by LingBot-VA (upstream ``wan_va/train.py``):
    the LR ramps linearly from 0 to the peak over ``num_warmup_steps`` steps, then stays flat.
    """

    num_warmup_steps: int = 1000

    def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
        warmup_steps = self.num_warmup_steps or 0

        def lr_lambda(current_step):
            if current_step < warmup_steps:
                return float(current_step) / float(max(1, warmup_steps))
            return 1.0

        return LambdaLR(optimizer, lr_lambda, -1)


@LRSchedulerConfig.register_subclass("cosine_decay_with_warmup")
@dataclass
class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
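
The warmup-constant multiplier added in the hunk above can be checked without PyTorch. The sketch below copies the body of `lr_lambda` into a standalone function and evaluates it at a few steps. `warmup_constant_lambda` is used here as a local helper name and the printed learning rates assume the config's default peak of `1e-5`.

```python
def warmup_constant_lambda(current_step, warmup_steps):
    """LR multiplier: linear ramp from 0 to 1 over warmup_steps, then flat at 1."""
    if current_step < warmup_steps:
        return float(current_step) / float(max(1, warmup_steps))
    return 1.0


steps = (0, 250, 500, 1000, 5000)
multipliers = [warmup_constant_lambda(step, 1000) for step in steps]

peak_lr = 1e-5  # default optimizer_lr in LingBotVAConfig
for step, multiplier in zip(steps, multipliers):
    print(f"step {step:>5}: lr = {peak_lr * multiplier:.2e}")
```

Note that the very first optimizer step runs with a multiplier of 0 when warmup is enabled, and that `warmup_steps=0` skips the ramp entirely because the `current_step < warmup_steps` branch is never taken.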
@@ -20,6 +20,7 @@ from .eo1.configuration_eo1 import EO1Config as EO1Config
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
from .groot.configuration_groot import GrootConfig as GrootConfig
from .lingbot_va.configuration_lingbot_va import LingBotVAConfig as LingBotVAConfig
from .molmoact2.configuration_molmoact2 import MolmoAct2Config as MolmoAct2Config
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
@@ -44,6 +45,7 @@ __all__ = [
    "EO1Config",
    "GaussianActorConfig",
    "GrootConfig",
    "LingBotVAConfig",
    "MolmoAct2Config",
    "MultiTaskDiTConfig",
    "PI0Config",
@@ -49,6 +49,7 @@ from .diffusion.configuration_diffusion import DiffusionConfig
from .eo1.configuration_eo1 import EO1Config
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from .groot.configuration_groot import GrootConfig
from .lingbot_va.configuration_lingbot_va import LingBotVAConfig
from .molmoact2.configuration_molmoact2 import MolmoAct2Config
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config
@@ -162,6 +163,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
        from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy

        return VLAJEPAPolicy
    elif name == "lingbot_va":
        from .lingbot_va.modeling_lingbot_va import LingBotVAPolicy

        return LingBotVAPolicy
    else:
        try:
            return _get_policy_cls_from_policy_name(name=name)
@@ -218,6 +223,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
        return MolmoAct2Config(**kwargs)
    elif policy_type == "vla_jepa":
        return VLAJEPAConfig(**kwargs)
    elif policy_type == "lingbot_va":
        return LingBotVAConfig(**kwargs)
    else:
        try:
            config_cls = PreTrainedConfig.get_choice_class(policy_type)
@@ -448,6 +455,14 @@ def make_pre_post_processors(
            dataset_stats=kwargs.get("dataset_stats"),
        )

    elif isinstance(policy_cfg, LingBotVAConfig):
        from .lingbot_va.processor_lingbot_va import make_lingbot_va_pre_post_processors

        processors = make_lingbot_va_pre_post_processors(
            config=policy_cfg,
            dataset_stats=kwargs.get("dataset_stats"),
        )

    else:
        try:
            processors = _make_processors_from_policy_config(
@@ -0,0 +1 @@
../../../../docs/source/lingbot_va.mdx
@@ -0,0 +1,33 @@
#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTE: ``LingBotVAPolicy`` (and the Wan transformer it owns) imports ``diffusers`` as a
# hard dependency at class-definition time (it subclasses diffusers' ModelMixin/ConfigMixin).
# To keep base ``import lerobot`` working without the optional ``lingbot_va`` extra, the
# policy is exposed lazily via module ``__getattr__``; the heavy import only happens when
# ``LingBotVAPolicy`` is actually accessed (mirroring the lazy import in policies/factory.py).
from .configuration_lingbot_va import LingBotVAConfig
from .processor_lingbot_va import make_lingbot_va_pre_post_processors

__all__ = ["LingBotVAConfig", "LingBotVAPolicy", "make_lingbot_va_pre_post_processors"]


def __getattr__(name):
    if name == "LingBotVAPolicy":
        from .modeling_lingbot_va import LingBotVAPolicy

        return LingBotVAPolicy
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
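
The lazy export above relies on module-level `__getattr__` (PEP 562), which Python calls only after normal attribute lookup on the module fails. The sketch below demonstrates the mechanism on a throwaway module object. The names `toy_lingbot_va` and `HeavyPolicy` are hypothetical, and `fractions.Fraction` merely stands in for a class whose import is expensive.

```python
import types

import_log = []  # records each time the deferred import actually runs


def make_lazy_package(name):
    """Build a module object that resolves `HeavyPolicy` on first access."""
    module = types.ModuleType(name)

    def __getattr__(attr):
        if attr == "HeavyPolicy":
            import_log.append(attr)
            # Stand-in for a costly import such as a diffusers-backed model class.
            from fractions import Fraction as HeavyPolicy

            return HeavyPolicy
        raise AttributeError(f"module {name!r} has no attribute {attr!r}")

    # Stored in the module's __dict__, where the PEP 562 hook is looked up.
    module.__getattr__ = __getattr__
    return module


pkg = make_lazy_package("toy_lingbot_va")
imports_before_access = list(import_log)
policy_cls = pkg.HeavyPolicy
```

Creating the module triggers no import; the deferred import runs only when `pkg.HeavyPolicy` is read. Unknown names still raise `AttributeError`, so `hasattr` and `from ... import` keep behaving normally.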
@@ -0,0 +1,198 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Configuration for the LingBot-VA policy.

LingBot-VA is an autoregressive video-action world-model policy built on the Wan2.2
video-diffusion stack. It interleaves prediction of future video latents and robot
actions in a single dual-stream transformer. See ``docs/source/lingbot_va.mdx`` and the
upstream repository (https://github.com/Robbyant/lingbot-va).

Defaults below match the upstream LIBERO configuration (``wan_va/configs/va_libero_cfg.py``)
and the ``transformer/config.json`` of the released checkpoints.
"""

from dataclasses import dataclass, field

from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import LRSchedulerConfig
from lerobot.utils.constants import ACTION

# Upstream LIBERO action-normalization quantiles (single 7-DoF arm + gripper).
# Verbatim from wan_va/configs/va_libero_cfg.py (channels 0-6 of a 30-dim action space).
LIBERO_ACTION_Q01 = [
    -0.6589285731315613,
    -0.84375,
    -0.9375,
    -0.12107142806053162,
    -0.15964286029338837,
    -0.26571428775787354,
    -1.0,
]
LIBERO_ACTION_Q99 = [
    0.8999999761581421,
    0.8544642925262451,
    0.9375,
    0.17142857611179352,
    0.1842857152223587,
    0.34392857551574707,
    1.0,
]


@PreTrainedConfig.register_subclass("lingbot_va")
@dataclass
class LingBotVAConfig(PreTrainedConfig):
    """Configuration for the native LingBot-VA policy integration in LeRobot."""

    # ── Wan transformer architecture (from transformer/config.json) ──
    patch_size: tuple[int, int, int] = (1, 2, 2)
    num_attention_heads: int = 24
    attention_head_dim: int = 128
    in_channels: int = 48
    out_channels: int = 48
    action_dim: int = 30
    text_dim: int = 4096
    freq_dim: int = 256
    ffn_dim: int = 14336
    num_layers: int = 30
    cross_attn_norm: bool = True
    eps: float = 1e-6
    rope_max_seq_len: int = 1024
    # "flex" is supported for training only and needs a recent torch build. Inference uses
    # "torch" SDPA (always available) or, optionally, "flashattn".
    attn_mode: str = "torch"

    # ── Frozen sub-models (VAE + UMT5 text encoder + tokenizer) ──
    # These heavy frozen weights (~20 GB) are NOT bundled into the LeRobot safetensors
    # checkpoint (only the trainable ~5B transformer is). They are lazily pulled from this
    # HF repo / local directory at policy-init time. The directory must contain the
    # diffusers-style ``vae/``, ``text_encoder/`` and ``tokenizer/`` sub-folders.
    wan_pretrained_path: str = "robbyant/lingbot-va-posttrain-libero-long"
    # dtype used for the transformer / VAE / text-encoder weights at inference.
    dtype: str = "bfloat16"  # one of "bfloat16", "float16", "float32"

    # ── Observation cameras (order matters: latents are concatenated on width) ──
    # Defaults match the LIBERO env feature keys (agentview -> image, eye-in-hand -> image2).
    obs_cam_keys: list[str] = field(
        default_factory=lambda: ["observation.images.image", "observation.images.image2"]
    )

    # ── Inference hyperparameters (LIBERO defaults) ──
    n_obs_steps: int = 1
    height: int = 128
    width: int = 128
    action_per_frame: int = 4
    frame_chunk_size: int = 4
    attn_window: int = 30
    num_inference_steps: int = 20
    video_exec_step: int = -1
    action_num_inference_steps: int = 50
    guidance_scale: float = 5.0
    action_guidance_scale: float = 1.0
    snr_shift: float = 5.0
    action_snr_shift: float = 0.05
    max_sequence_length: int = 512  # UMT5 prompt length

    # Subset of the 30-d action space actually used by the benchmark (LIBERO = 7-DoF).
    used_action_channel_ids: list[int] = field(default_factory=lambda: list(range(7)))
    # Fixed quantiles for action (un)normalization on the *used* channels.
    action_q01: list[float] = field(default_factory=lambda: list(LIBERO_ACTION_Q01))
    action_q99: list[float] = field(default_factory=lambda: list(LIBERO_ACTION_Q99))

    # Opt-in: VAE-decode the predicted video latents and stash them on
    # ``self.last_predicted_frames`` so eval/train can save predicted-video MP4s.
    save_predicted_video: bool = False

    # ── Normalization (handled internally / via custom steps, hence IDENTITY here) ──
    # Images are scaled to [-1, 1] and VAE-encoded inside the policy; actions are
    # quantile-(un)normalized by dedicated processor steps using the fixed quantiles above.
    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.IDENTITY,
            "STATE": NormalizationMode.IDENTITY,
            "ACTION": NormalizationMode.IDENTITY,
        }
    )

    # ── Optimizer / scheduler (training; AdamW + warmup-constant per upstream train.py) ──
    optimizer_lr: float = 1e-5
    optimizer_betas: tuple[float, float] = (0.9, 0.95)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 1e-4
    optimizer_grad_clip_norm: float = 1.0
    scheduler_warmup_steps: int = 1000

    def __post_init__(self):
        super().__post_init__()
        if self.attn_mode not in ("torch", "flashattn", "flex"):
            raise ValueError(f"attn_mode must be one of 'torch', 'flashattn', 'flex'; got {self.attn_mode!r}")
        if len(self.action_q01) != len(self.used_action_channel_ids) or len(self.action_q99) != len(
            self.used_action_channel_ids
        ):
            raise ValueError(
                "action_q01 / action_q99 must each have one entry per used_action_channel_ids "
                f"({len(self.used_action_channel_ids)}); got {len(self.action_q01)} / {len(self.action_q99)}."
            )

    @property
    def chunk_size(self) -> int:
        """Number of single-step actions produced per autoregressive chunk."""
        return self.frame_chunk_size * self.action_per_frame

    @property
    def n_action_steps(self) -> int:
        """Number of actions executed before refilling (the whole chunk)."""
        return self.chunk_size

    def validate_features(self) -> None:
        image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
        if not image_features:
            raise ValueError(
                "LingBot-VA requires at least one visual input feature. "
                "No features of type FeatureType.VISUAL found in input_features."
            )
        if ACTION not in self.output_features:
            self.output_features[ACTION] = PolicyFeature(
                type=FeatureType.ACTION, shape=(len(self.used_action_channel_ids),)
            )

    def get_optimizer_preset(self) -> AdamWConfig:
        return AdamWConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
            grad_clip_norm=self.optimizer_grad_clip_norm,
        )

    def get_scheduler_preset(self) -> LRSchedulerConfig | None:
        # Upstream uses a linear warmup followed by a constant LR (warmup_constant_lambda).
        from lerobot.optim.schedulers import ConstantWithWarmupSchedulerConfig

        return ConstantWithWarmupSchedulerConfig(num_warmup_steps=self.scheduler_warmup_steps)

    @property
    def observation_delta_indices(self) -> None:
        return None

    @property
    def action_delta_indices(self) -> list[int]:
        return list(range(self.chunk_size))

    @property
    def reward_delta_indices(self) -> None:
        return None
@@ -0,0 +1,256 @@
#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Convert a released LingBot-VA HuggingFace checkpoint to LeRobot format.

The released checkpoints are diffusers-style directories with ``transformer/``, ``vae/``,
``text_encoder/`` and ``tokenizer/`` sub-folders. This script:

1. loads the (sharded) ``transformer/`` weights with the vendored ``WanTransformer3DModel``;
2. builds a :class:`LingBotVAConfig` for the target benchmark variant;
3. instantiates a :class:`LingBotVAPolicy` and copies the transformer weights into it
   (near-identity: the only key change is the ``transformer.`` prefix);
4. saves the LeRobot policy (``model.safetensors`` + ``config.json``) and its processors.

Packaging decision: only the trainable ~5B transformer is bundled into the LeRobot
``model.safetensors``. The frozen VAE + UMT5 text encoder + tokenizer (~20 GB) are NOT
copied; instead ``config.wan_pretrained_path`` records where to lazily pull them from at
load time (defaults to the source repo/dir). Pass ``--bundle-frozen`` to additionally copy
those sub-folders next to the converted checkpoint and point ``wan_pretrained_path`` at it.

Example (LIBERO-Long, the LIBERO eval gate):

    python -m lerobot.policies.lingbot_va.convert_lingbot_va_checkpoints \
        --checkpoint robbyant/lingbot-va-posttrain-libero-long \
        --variant libero \
        --output_dir outputs/lingbot_va_libero_long

Requires a CUDA GPU with enough RAM/VRAM to materialize the transformer; run on Linux.
"""

import argparse
import shutil
from pathlib import Path

import torch

from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
from lerobot.policies.lingbot_va.modeling_lingbot_va import LingBotVAPolicy
from lerobot.policies.lingbot_va.processor_lingbot_va import make_lingbot_va_pre_post_processors
from lerobot.policies.lingbot_va.wan_transformer import WanTransformer3DModel
from lerobot.utils.constants import ACTION, OBS_IMAGES

# Per-benchmark variant presets (camera keys + action layout). Values mirror the upstream
# configs (wan_va/configs/va_*_cfg.py).
VARIANTS = {
    "libero": {
        "obs_cam_keys": [f"{OBS_IMAGES}.image", f"{OBS_IMAGES}.image2"],
        "height": 128,
        "width": 128,
        "action_per_frame": 4,
        "frame_chunk_size": 4,
        "attn_window": 30,
        "num_inference_steps": 20,
        "action_num_inference_steps": 50,
        "guidance_scale": 5.0,
        "action_guidance_scale": 1.0,
        "snr_shift": 5.0,
        "action_snr_shift": 0.05,
        "used_action_channel_ids": list(range(7)),
        # 7-DoF: agentview + eye-in-hand, single arm. Quantiles are the config defaults.
        "image_shape": (3, 256, 256),
    },
    "robotwin": {
        "obs_cam_keys": [
            f"{OBS_IMAGES}.cam_high",
            f"{OBS_IMAGES}.cam_left_wrist",
            f"{OBS_IMAGES}.cam_right_wrist",
        ],
        "height": 256,
        "width": 320,
        "action_per_frame": 16,
        "frame_chunk_size": 2,
        "attn_window": 72,
        "num_inference_steps": 25,
        "action_num_inference_steps": 50,
        "guidance_scale": 5.0,
        "action_guidance_scale": 1.0,
        "snr_shift": 5.0,
        "action_snr_shift": 1.0,
        # RoboTwin is dual-arm; set the used channels / quantiles to match the deployed config.
        "used_action_channel_ids": list(range(14)),
        "image_shape": (3, 256, 256),
    },
}


def _transformer_dir(checkpoint: str) -> str:
    """Return the path/repo that ``WanTransformer3DModel.from_pretrained`` should read."""
    p = Path(checkpoint)
    if p.is_dir():
        return str(p / "transformer")
    return checkpoint  # HF repo id; use subfolder kwarg below


def load_source_transformer(checkpoint: str, dtype: torch.dtype) -> WanTransformer3DModel:
    p = Path(checkpoint)
    if p.is_dir():
        return WanTransformer3DModel.from_pretrained(
            str(p / "transformer"), torch_dtype=dtype, attn_mode="torch"
        )
    return WanTransformer3DModel.from_pretrained(
        checkpoint, subfolder="transformer", torch_dtype=dtype, attn_mode="torch"
    )


def build_config(variant: str, wan_pretrained_path: str, dtype: str) -> LingBotVAConfig:
    preset = VARIANTS[variant]
    n_used = len(preset["used_action_channel_ids"])
    kwargs = {
        "wan_pretrained_path": wan_pretrained_path,
        "dtype": dtype,
        "obs_cam_keys": preset["obs_cam_keys"],
        "height": preset["height"],
        "width": preset["width"],
        "action_per_frame": preset["action_per_frame"],
        "frame_chunk_size": preset["frame_chunk_size"],
        "attn_window": preset["attn_window"],
        "num_inference_steps": preset["num_inference_steps"],
        "action_num_inference_steps": preset["action_num_inference_steps"],
        "guidance_scale": preset["guidance_scale"],
        "action_guidance_scale": preset["action_guidance_scale"],
        "snr_shift": preset["snr_shift"],
        "action_snr_shift": preset["action_snr_shift"],
        "used_action_channel_ids": preset["used_action_channel_ids"],
        "device": "cpu",
    }
    if variant != "libero":
        # LIBERO keeps the config default quantiles; other variants need their own. Until the
        # exact per-channel quantiles are wired in, use a neutral [-1, 1] mapping (no rescale).
        kwargs["action_q01"] = [-1.0] * n_used
        kwargs["action_q99"] = [1.0] * n_used
    cfg = LingBotVAConfig(**kwargs)
    # Populate input/output features (cameras + action) so validate_features passes.
    img_shape = preset["image_shape"]
    cfg.input_features = {
        k: PolicyFeature(type=FeatureType.VISUAL, shape=img_shape) for k in preset["obs_cam_keys"]
    }
    cfg.output_features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(n_used,))}
    cfg.validate_features()
    return cfg
|
||||
|
||||
|
||||
def convert(
|
||||
checkpoint: str, variant: str, output_dir: str, dtype: str, bundle_frozen: bool, push_to: str | None
|
||||
):
|
||||
torch_dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[dtype]
|
||||
out = Path(output_dir)
|
||||
out.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Decide where frozen modules will be pulled from at load time.
|
||||
if bundle_frozen:
|
||||
wan_pretrained_path = str(out)
|
||||
_copy_frozen_subfolders(checkpoint, out)
|
||||
else:
|
||||
wan_pretrained_path = checkpoint
|
||||
|
||||
print(f"Building LingBot-VA config for variant '{variant}' (frozen modules from: {wan_pretrained_path})")
|
||||
cfg = build_config(variant, wan_pretrained_path, dtype)
|
||||
|
||||
print("Loading source transformer weights ...")
|
||||
src = load_source_transformer(checkpoint, torch_dtype)
|
||||
src_sd = src.state_dict()
|
||||
|
||||
print("Instantiating LingBotVAPolicy and copying transformer weights ...")
|
||||
# Constructing the policy directly does not download the frozen modules; they load lazily on first inference.
|
||||
policy = LingBotVAPolicy(cfg)
|
||||
# Near-identity remap: source transformer keys -> policy "transformer.*".
|
||||
remapped = {f"transformer.{k}": v for k, v in src_sd.items()}
|
||||
missing, unexpected = policy.load_state_dict(remapped, strict=False)
|
||||
_log_load_keys(missing, unexpected)
|
||||
policy = policy.to(torch_dtype)
|
||||
|
||||
print(f"Saving converted policy to {out}")
|
||||
policy.save_pretrained(out)
|
||||
|
||||
preprocessor, postprocessor = make_lingbot_va_pre_post_processors(cfg, dataset_stats=None)
|
||||
preprocessor.save_pretrained(out)
|
||||
postprocessor.save_pretrained(out)
|
||||
|
||||
if push_to:
|
||||
print(f"Pushing to the Hub: {push_to}")
|
||||
policy.push_to_hub(push_to)
|
||||
preprocessor.push_to_hub(push_to)
|
||||
postprocessor.push_to_hub(push_to)
|
||||
|
||||
print("Done.")
|
||||
|
||||
|
||||
def _copy_frozen_subfolders(checkpoint: str, out: Path):
|
||||
p = Path(checkpoint)
|
||||
if not p.is_dir():
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
p = Path(snapshot_download(checkpoint, allow_patterns=["vae/*", "text_encoder/*", "tokenizer/*"]))
|
||||
for sub in ("vae", "text_encoder", "tokenizer"):
|
||||
src_sub = p / sub
|
||||
if src_sub.is_dir():
|
||||
shutil.copytree(src_sub, out / sub, dirs_exist_ok=True)
|
||||
print(f" bundled {sub}/")
|
||||
|
||||
|
||||
def _log_load_keys(missing, unexpected):
|
||||
# The source transformer should account for every "transformer.*" key in the policy.
|
||||
if missing:
|
||||
print(
|
||||
f" [load_state_dict] {len(missing)} missing keys (expected: none for transformer). Sample: {missing[:5]}"
|
||||
)
|
||||
if unexpected:
|
||||
print(f" [load_state_dict] {len(unexpected)} unexpected keys. Sample: {unexpected[:5]}")
|
||||
if not missing and not unexpected:
|
||||
print(" [load_state_dict] perfect match (near-identity remap).")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
parser.add_argument("--checkpoint", required=True, help="HF repo id or local diffusers-style directory.")
|
||||
parser.add_argument("--variant", required=True, choices=sorted(VARIANTS.keys()))
|
||||
parser.add_argument("--output_dir", required=True)
|
||||
parser.add_argument("--dtype", default="bfloat16", choices=["bfloat16", "float16", "float32"])
|
||||
parser.add_argument(
|
||||
"--bundle-frozen",
|
||||
action="store_true",
|
||||
help="Copy the frozen vae/text_encoder/tokenizer into --output_dir instead of lazy-pulling them at load time.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub", default=None, help="Optional HF repo id to push the converted policy to."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert(
|
||||
checkpoint=args.checkpoint,
|
||||
variant=args.variant,
|
||||
output_dir=args.output_dir,
|
||||
dtype=args.dtype,
|
||||
bundle_frozen=args.bundle_frozen,
|
||||
push_to=args.push_to_hub,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
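The near-identity remap performed in `convert` can be illustrated without loading any weights. The following is a minimal, torch-free sketch: plain dicts stand in for state dicts, the key names are made up, and `remap_keys` / `diff_keys` are hypothetical helpers that are not part of the converter. `diff_keys` mimics the missing and unexpected key lists that `load_state_dict(strict=False)` reports.

```python
def remap_keys(src_sd: dict, prefix: str = "transformer.") -> dict:
    """Prefix every source key, as the converter does for the policy sub-module."""
    return {f"{prefix}{k}": v for k, v in src_sd.items()}


def diff_keys(target_keys, loaded_keys):
    """Return (missing, unexpected) in the sense of load_state_dict(strict=False)."""
    target, loaded = set(target_keys), set(loaded_keys)
    missing = sorted(target - loaded)      # keys the policy expects but did not receive
    unexpected = sorted(loaded - target)   # keys supplied that the policy does not own
    return missing, unexpected


# Hypothetical key names, for illustration only.
source = {"blocks.0.attn.weight": 1, "blocks.0.attn.bias": 2}
policy_keys = ["transformer.blocks.0.attn.weight", "transformer.blocks.0.attn.bias"]

remapped = remap_keys(source)
missing, unexpected = diff_keys(policy_keys, remapped)
print(missing, unexpected)
```

When both lists are empty the remap is a perfect match, which is the case `_log_load_keys` reports as "perfect match (near-identity remap)".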
|
||||
@@ -0,0 +1,582 @@
|
||||
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""LingBot-VA policy: an autoregressive video-action world model on the Wan2.2 stack.
|
||||
|
||||
The sampling loop is a faithful re-implementation of the upstream streaming server
|
||||
(``wan_va/wan_va_server.py``) and LIBERO client (``evaluation/libero/client.py``), adapted
|
||||
to LeRobot's ``select_action`` interface:
|
||||
|
||||
* the trainable dual-stream transformer is owned as a sub-module and round-trips in the
|
||||
single ``model.safetensors`` checkpoint;
|
||||
* the frozen Wan VAE + UMT5 text encoder + tokenizer are *lazily pulled* from
|
||||
``config.wan_pretrained_path`` (not bundled), so the LeRobot checkpoint stays small;
|
||||
* ``predict_action_chunk`` runs one autoregressive chunk (video stream then action
|
||||
stream, each with CFG and its own flow-matching scheduler) and updates the KV cache;
|
||||
* ``select_action`` drains a per-step action queue and records the real observed
|
||||
keyframes that are fed back into the KV cache when the queue is refilled.
|
||||
|
||||
NOTE: matching the upstream LIBERO success rate is the Phase-5 correctness gate and must be
|
||||
validated on a CUDA GPU with the converted checkpoint (tensor-diff against upstream on
|
||||
identical inputs). The streaming path is written for single-environment eval
|
||||
(``--eval.batch_size=1``).
|
||||
"""
|
||||
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
from .configuration_lingbot_va import LingBotVAConfig
|
||||
from .schedulers import FlowMatchScheduler
|
||||
from .wan_transformer import WanTransformer3DModel
|
||||
from .wan_utils import data_seq_to_patch, get_mesh_id
|
||||
from .wan_vae import WanVAEStreamingWrapper, denormalize_latents, load_text_encoder, load_tokenizer, load_vae
|
||||
|
||||
|
||||
def _torch_dtype(name: str) -> torch.dtype:
|
||||
return {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[name]
|
||||
|
||||
|
||||
class LingBotVAPolicy(PreTrainedPolicy):
|
||||
"""LeRobot wrapper for the LingBot-VA autoregressive video-action world model."""
|
||||
|
||||
config_class = LingBotVAConfig
|
||||
name = "lingbot_va"
|
||||
|
||||
def __init__(self, config: LingBotVAConfig, **kwargs):
|
||||
require_package("diffusers", extra="lingbot_va")
|
||||
require_package("transformers", extra="lingbot_va")
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.dtype = _torch_dtype(config.dtype)
|
||||
|
||||
# Trainable dual-stream transformer (the only sub-module saved in the LeRobot checkpoint).
|
||||
self.transformer = WanTransformer3DModel(
|
||||
patch_size=tuple(config.patch_size),
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
attention_head_dim=config.attention_head_dim,
|
||||
in_channels=config.in_channels,
|
||||
out_channels=config.out_channels,
|
||||
action_dim=config.action_dim,
|
||||
text_dim=config.text_dim,
|
||||
freq_dim=config.freq_dim,
|
||||
ffn_dim=config.ffn_dim,
|
||||
num_layers=config.num_layers,
|
||||
cross_attn_norm=config.cross_attn_norm,
|
||||
eps=config.eps,
|
||||
rope_max_seq_len=config.rope_max_seq_len,
|
||||
attn_mode=config.attn_mode,
|
||||
)
|
||||
|
||||
# Frozen modules are stored OUTSIDE the nn.Module registry (plain dict) so they are
|
||||
# neither saved into model.safetensors nor moved by ``.to()``. They are lazily loaded
|
||||
# from ``config.wan_pretrained_path`` the first time inference runs.
|
||||
self._frozen: dict = {}
|
||||
|
||||
self.last_predicted_frames: Tensor | None = None
|
||||
self.reset()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Frozen-module lazy loading (VAE + UMT5 + tokenizer)
|
||||
# ------------------------------------------------------------------
|
||||
def _ensure_frozen_modules(self):
|
||||
if self._frozen:
|
||||
return
|
||||
import os
|
||||
|
||||
path = self.config.wan_pretrained_path
|
||||
device = self.config.device
|
||||
|
||||
# Support both local diffusers-style dirs (with vae/, text_encoder/ and tokenizer/ sub-folders)
# and HF repo ids (the loaders accept a subfolder kwarg; it is omitted here, which assumes the repo-root layout).
|
||||
if os.path.isdir(path):
|
||||
vae_path, te_path, tok_path = (
|
||||
os.path.join(path, n) for n in ("vae", "text_encoder", "tokenizer")
|
||||
)
|
||||
else:
|
||||
vae_path = te_path = tok_path = path
|
||||
|
||||
vae = load_vae(vae_path, torch_dtype=self.dtype, torch_device=device)
|
||||
text_encoder = load_text_encoder(te_path, torch_dtype=self.dtype, torch_device=device)
|
||||
tokenizer = load_tokenizer(tok_path)
|
||||
self._frozen = {
|
||||
"vae": vae.eval(),
|
||||
"streaming_vae": WanVAEStreamingWrapper(vae),
|
||||
"text_encoder": text_encoder.eval(),
|
||||
"tokenizer": tokenizer,
|
||||
}
|
||||
|
||||
@property
|
||||
def _vae(self):
|
||||
return self._frozen["vae"]
|
||||
|
||||
@property
|
||||
def _streaming_vae(self):
|
||||
return self._frozen["streaming_vae"]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# PreTrainedPolicy API
|
||||
# ------------------------------------------------------------------
|
||||
def get_optim_params(self) -> dict:
|
||||
# Only the transformer is trainable; the VAE / text encoder stay frozen.
|
||||
return self.transformer.parameters()
|
||||
|
||||
def reset(self):
|
||||
"""Reset all per-episode streaming state (KV cache, queues, frame counter)."""
|
||||
cfg = self.config
|
||||
self._action_queue: deque = deque(maxlen=cfg.n_action_steps)
|
||||
self._obs_buffer: list = [] # keyframe camera tensors observed during the current chunk
|
||||
# Last chunk's actions (model-normalized), kept for KV feedback.
self._executed_actions: Tensor | None = None
|
||||
self._steps_since_refill = 0
|
||||
self._frame_st_id = 0
|
||||
self._first_chunk = True
|
||||
self._prompt: str | None = None
|
||||
self._prompt_embeds = None
|
||||
self._negative_prompt_embeds = None
|
||||
self.last_predicted_frames = None
|
||||
self._use_cfg = (cfg.guidance_scale > 1) or (cfg.action_guidance_scale > 1)
|
||||
# Two independent flow-matching schedulers (video latent + action streams).
|
||||
self._scheduler = FlowMatchScheduler(shift=cfg.snr_shift, sigma_min=0.0, extra_one_step=True)
|
||||
self._action_scheduler = FlowMatchScheduler(
|
||||
shift=cfg.action_snr_shift, sigma_min=0.0, extra_one_step=True
|
||||
)
|
||||
self._scheduler.set_timesteps(1000, training=True)
|
||||
self._action_scheduler.set_timesteps(1000, training=True)
|
||||
self._cache_initialised = False
|
||||
# Clear KV cache on the (already-built) transformer, if present.
|
||||
if hasattr(self, "transformer"):
|
||||
self.transformer.clear_cache("pos")
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
|
||||
"""Training loss. Implemented in the LingBot-VA training PR (Phase 7).
|
||||
|
||||
The flow-matching dual-stream loss needs the pre-extracted latent dataset
|
||||
(see ``LatentLeRobotDataset`` upstream) and ``attn_mode='flex'``; it is intentionally
|
||||
not part of this inference-focused integration.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"LingBot-VA training (flow-matching dual-stream loss) is part of the training port "
|
||||
"(Phase 7 / PR #2) and is not implemented in this inference integration. "
|
||||
"Use this policy for evaluation / inference."
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
|
||||
"""Return one action, refilling the chunk (and feeding back observed keyframes) as needed."""
|
||||
self.eval()
|
||||
self._ensure_frozen_modules()
|
||||
self._maybe_init_prompt(batch)
|
||||
|
||||
# Record the current observation as a keyframe at every frame boundary so that, when the
|
||||
# queue empties, ``predict_action_chunk`` can feed the real observed frames back into the
|
||||
# KV cache (mirroring the upstream ``compute_kv_cache`` call in the LIBERO client loop).
|
||||
# We skip ``steps_since_refill == 0`` (the obs that conditioned the current chunk): only
|
||||
# frames observed *after* executing each frame's actions are fed back.
|
||||
if self._steps_since_refill > 0 and self._steps_since_refill % self.config.action_per_frame == 0:
|
||||
self._obs_buffer.append(self._encode_obs(batch))
|
||||
|
||||
if len(self._action_queue) == 0:
|
||||
actions = self.predict_action_chunk(batch) # [B, chunk_size, n_used]
|
||||
# queue holds per-step actions: shape [chunk_size, B, n_used]
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
self._steps_since_refill = 0
|
||||
|
||||
self._steps_since_refill += 1
|
||||
return self._action_queue.popleft()
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
|
||||
"""Run one autoregressive chunk and return actions ``[B, chunk_size, n_used]`` (normalized)."""
|
||||
self.eval()
|
||||
self._ensure_frozen_modules()
|
||||
self._maybe_init_prompt(batch)
|
||||
|
||||
is_first = self._first_chunk
|
||||
if is_first:
|
||||
init_latent = self._encode_obs(batch)
|
||||
self._init_latent = init_latent
|
||||
self._init_streaming_cache(init_latent)
|
||||
self._obs_buffer = [] # frame 0 (the init obs) conditions the chunk; it is not fed back
|
||||
actions, latents = self._infer(init_latent, frame_st_id=0)
|
||||
self._first_chunk = False
|
||||
else:
|
||||
# Feed the real observed keyframes + the executed actions back into the KV cache.
|
||||
self._compute_kv_cache(self._obs_buffer, self._executed_actions)
|
||||
self._obs_buffer = []
|
||||
actions, latents = self._infer(None, frame_st_id=self._frame_st_id)
|
||||
|
||||
# actions: [B, action_dim, F, action_per_frame, 1] (model-normalized). Keep for KV feedback.
|
||||
self._executed_actions = actions
|
||||
|
||||
if self.config.save_predicted_video:
|
||||
self.last_predicted_frames = self._decode_predicted_video(latents)
|
||||
|
||||
# On the first chunk, frame 0 is the conditioning frame (already "known"): the upstream
|
||||
# LIBERO client skips it (start_idx=1), so we drop the first frame's actions here.
|
||||
used = self.config.used_action_channel_ids
|
||||
a = actions[:, used] # [B, n_used, F, action_per_frame, 1]
|
||||
if is_first:
|
||||
a = a[:, :, 1:] # drop frame 0 -> (F-1) frames of actions
|
||||
a = a.squeeze(-1).flatten(2) # [B, n_used, n_steps]
|
||||
a = a.transpose(1, 2).contiguous() # [B, n_steps, n_used]
|
||||
return a.to(torch.float32)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Prompt / text encoding
|
||||
# ------------------------------------------------------------------
|
||||
def _maybe_init_prompt(self, batch):
|
||||
if self._prompt_embeds is not None:
|
||||
return
|
||||
task = batch.get("task")
|
||||
prompt = task[0] if isinstance(task, list | tuple) else task
|
||||
self._prompt = prompt or ""
|
||||
self._prompt_embeds, self._negative_prompt_embeds = self._encode_prompt(self._prompt)
|
||||
|
||||
def _get_t5_prompt_embeds(self, prompt, max_sequence_length):
|
||||
from diffusers.pipelines.wan.pipeline_wan import prompt_clean
|
||||
|
||||
tokenizer = self._frozen["tokenizer"]
|
||||
text_encoder = self._frozen["text_encoder"]
|
||||
device = self.config.device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt = [prompt_clean(u) for u in prompt]
|
||||
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
|
||||
seq_lens = mask.gt(0).sum(dim=1).long()
|
||||
|
||||
te_device = next(text_encoder.parameters()).device
|
||||
prompt_embeds = text_encoder(text_input_ids.to(te_device), mask.to(te_device)).last_hidden_state
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=device)
|
||||
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens, strict=False)]
|
||||
prompt_embeds = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds],
|
||||
dim=0,
|
||||
)
|
||||
return prompt_embeds.to(device)
|
||||
|
||||
def _encode_prompt(self, prompt):
|
||||
max_len = self.config.max_sequence_length
|
||||
prompt_embeds = self._get_t5_prompt_embeds(prompt, max_len)
|
||||
negative_prompt_embeds = None
|
||||
if self._use_cfg:
|
||||
negative_prompt_embeds = self._get_t5_prompt_embeds("", max_len)
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Observation (image) encoding -> normalized video latents
|
||||
# ------------------------------------------------------------------
|
||||
def _camera_tensor(self, batch, key):
|
||||
"""Return a single-frame camera tensor [B, C, 1, H, W] resized + scaled to [-1, 1]."""
|
||||
img = batch[key]
|
||||
if img.dim() == 3: # [C, H, W]
|
||||
img = img.unsqueeze(0)
|
||||
# LeRobot images arrive as float in [0, 1], shape [B, C, H, W].
|
||||
img = img.to(self.config.device, torch.float32)
|
||||
img = F.interpolate(
|
||||
img, size=(self.config.height, self.config.width), mode="bilinear", align_corners=False
|
||||
)
|
||||
img = img * 2.0 - 1.0
|
||||
return img.unsqueeze(2).to(self.dtype) # [B, C, F=1, H, W]
|
||||
|
||||
@torch.no_grad()
|
||||
def _encode_obs(self, batch) -> Tensor:
|
||||
"""VAE-encode all configured cameras of the current obs and concat latents on width."""
|
||||
videos = [self._camera_tensor(batch, k) for k in self.config.obs_cam_keys]
|
||||
videos = torch.cat(videos, dim=0)  # [num_cam * B, C, F, H, W]; B == 1 in the streaming path
|
||||
vae_device = next(self._vae.parameters()).device
|
||||
enc_out = self._streaming_vae.encode_chunk(videos.to(vae_device).to(self.dtype))
|
||||
mu, _logvar = torch.chunk(enc_out, 2, dim=1)
|
||||
latents_mean = torch.tensor(self._vae.config.latents_mean).to(mu.device)
|
||||
latents_std = torch.tensor(self._vae.config.latents_std).to(mu.device)
|
||||
# Note: upstream passes 1/std so the op is (x - mean) * (1/std).
|
||||
mean = latents_mean.view(1, -1, 1, 1, 1)
|
||||
inv_std = (1.0 / latents_std).view(1, -1, 1, 1, 1)
|
||||
mu_norm = ((mu.float() - mean) * inv_std).to(mu)
|
||||
# Concatenate the per-camera latents along width.
|
||||
video_latent = torch.cat(mu_norm.split(1, dim=0), dim=-1)
|
||||
return video_latent.to(self.config.device)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# KV cache management
|
||||
# ------------------------------------------------------------------
|
||||
@property
|
||||
def _latent_hw(self):
|
||||
h = self.config.height // 16
|
||||
w = (self.config.width // 16) * len(self.config.obs_cam_keys)
|
||||
return h, w
|
||||
|
||||
def _init_streaming_cache(self, init_latent):
|
||||
cfg = self.config
|
||||
latent_h, latent_w = self._latent_hw
|
||||
p = cfg.patch_size
|
||||
latent_token_per_chunk = (cfg.frame_chunk_size * latent_h * latent_w) // (p[0] * p[1] * p[2])
|
||||
action_token_per_chunk = cfg.frame_chunk_size * cfg.action_per_frame
|
||||
self.transformer.create_empty_cache(
|
||||
"pos",
|
||||
cfg.attn_window,
|
||||
latent_token_per_chunk,
|
||||
action_token_per_chunk,
|
||||
device=self.config.device,
|
||||
dtype=self.dtype,
|
||||
batch_size=2 if self._use_cfg else 1,
|
||||
)
|
||||
self._cache_initialised = True
|
||||
|
||||
def _repeat_input_for_cfg(self, input_dict):
|
||||
if self._use_cfg:
|
||||
input_dict["noisy_latents"] = input_dict["noisy_latents"].repeat(2, 1, 1, 1, 1)
|
||||
input_dict["text_emb"] = torch.cat(
|
||||
[
|
||||
self._prompt_embeds.to(self.dtype).clone(),
|
||||
self._negative_prompt_embeds.to(self.dtype).clone(),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
input_dict["grid_id"] = input_dict["grid_id"][None].repeat(2, 1, 1)
|
||||
input_dict["timesteps"] = input_dict["timesteps"][None].repeat(2, 1)
|
||||
else:
|
||||
input_dict["grid_id"] = input_dict["grid_id"][None]
|
||||
input_dict["timesteps"] = input_dict["timesteps"][None]
|
||||
return input_dict
|
||||
|
||||
def _prepare_latent_input(
|
||||
self,
|
||||
latent_model_input,
|
||||
action_model_input,
|
||||
latent_t=0,
|
||||
action_t=0,
|
||||
latent_cond=None,
|
||||
action_cond=None,
|
||||
frame_st_id=0,
|
||||
):
|
||||
cfg = self.config
|
||||
device = self.config.device
|
||||
p = cfg.patch_size
|
||||
out = {}
|
||||
if latent_model_input is not None:
|
||||
out["latent_res_lst"] = {
|
||||
"noisy_latents": latent_model_input,
|
||||
"timesteps": torch.ones([latent_model_input.shape[2]], dtype=torch.float32, device=device)
|
||||
* latent_t,
|
||||
"grid_id": get_mesh_id(
|
||||
latent_model_input.shape[-3] // p[0],
|
||||
latent_model_input.shape[-2] // p[1],
|
||||
latent_model_input.shape[-1] // p[2],
|
||||
0,
|
||||
1,
|
||||
frame_st_id,
|
||||
).to(device),
|
||||
"text_emb": self._prompt_embeds.to(self.dtype).clone(),
|
||||
}
|
||||
if latent_cond is not None:
|
||||
out["latent_res_lst"]["noisy_latents"][:, :, 0:1] = latent_cond[:, :, 0:1]
|
||||
out["latent_res_lst"]["timesteps"][0:1] *= 0
|
||||
if action_model_input is not None:
|
||||
out["action_res_lst"] = {
|
||||
"noisy_latents": action_model_input,
|
||||
"timesteps": torch.ones([action_model_input.shape[2]], dtype=torch.float32, device=device)
|
||||
* action_t,
|
||||
"grid_id": get_mesh_id(
|
||||
action_model_input.shape[-3],
|
||||
action_model_input.shape[-2],
|
||||
action_model_input.shape[-1],
|
||||
1,
|
||||
1,
|
||||
frame_st_id,
|
||||
action=True,
|
||||
).to(device),
|
||||
"text_emb": self._prompt_embeds.to(self.dtype).clone(),
|
||||
}
|
||||
if action_cond is not None:
|
||||
out["action_res_lst"]["noisy_latents"][:, :, 0:1] = action_cond[:, :, 0:1]
|
||||
out["action_res_lst"]["timesteps"][0:1] *= 0
|
||||
out["action_res_lst"]["noisy_latents"][:, ~self._action_mask] *= 0
|
||||
return out
|
||||
|
||||
@property
|
||||
def _action_mask(self):
|
||||
mask = torch.zeros([self.config.action_dim], dtype=torch.bool)
|
||||
mask[self.config.used_action_channel_ids] = True
|
||||
return mask
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Action conditioning (executed action history) (de)normalization
|
||||
# ------------------------------------------------------------------
|
||||
def _preprocess_action_state(self, action_norm: Tensor) -> Tensor:
|
||||
"""Build the action-conditioning tensor from the already-normalized executed actions.
|
||||
|
||||
``action_norm`` is the model-space action chunk ``[B, action_dim, F, action_per_frame, 1]``.
|
||||
Upstream re-derives the conditioning from the raw executed action via quantile norm; here
|
||||
the executed actions are already in the model-normalized space, so we pass them through.
|
||||
"""
|
||||
return action_norm.to(self.config.device, self.dtype)
|
||||
|
||||
def _compute_kv_cache(self, obs_buffer, executed_actions):
|
||||
"""Feed real observed keyframes + executed actions back into the KV cache."""
|
||||
if not obs_buffer or executed_actions is None:
|
||||
return
|
||||
self.transformer.clear_pred_cache("pos")
|
||||
# Concatenate the observed keyframe latents along the frame axis.
|
||||
latent_model_input = torch.cat(obs_buffer, dim=2)
|
||||
# On the first feedback, prepend the init latent so the latent/action frame counts align
|
||||
# (upstream prepends ``init_latent`` to the observed keyframes when frame_st_id == 0).
|
||||
if self._frame_st_id == 0 and getattr(self, "_init_latent", None) is not None:
|
||||
latent_model_input = torch.cat([self._init_latent, latent_model_input], dim=2)
|
||||
action_model_input = self._preprocess_action_state(executed_actions)
|
||||
action_model_input = action_model_input.to(latent_model_input)
|
||||
input_dict = self._prepare_latent_input(
|
||||
latent_model_input, action_model_input, frame_st_id=self._frame_st_id
|
||||
)
|
||||
with torch.no_grad():
|
||||
self.transformer(
|
||||
self._repeat_input_for_cfg(input_dict["latent_res_lst"]),
|
||||
update_cache=2,
|
||||
cache_name="pos",
|
||||
action_mode=False,
|
||||
)
|
||||
self.transformer(
|
||||
self._repeat_input_for_cfg(input_dict["action_res_lst"]),
|
||||
update_cache=2,
|
||||
cache_name="pos",
|
||||
action_mode=True,
|
||||
)
|
||||
self._frame_st_id += latent_model_input.shape[2]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# The core dual-stream denoising loop (one chunk)
|
||||
# ------------------------------------------------------------------
|
||||
@torch.no_grad()
|
||||
def _infer(self, init_latent, frame_st_id=0):
|
||||
cfg = self.config
|
||||
device = self.config.device
|
||||
latent_h, latent_w = self._latent_hw
|
||||
frame_chunk_size = cfg.frame_chunk_size
|
||||
|
||||
latents = torch.randn(1, 48, frame_chunk_size, latent_h, latent_w, device=device, dtype=self.dtype)
|
||||
actions = torch.randn(
|
||||
1, cfg.action_dim, frame_chunk_size, cfg.action_per_frame, 1, device=device, dtype=self.dtype
|
||||
)
|
||||
|
||||
self._scheduler.set_timesteps(cfg.num_inference_steps)
|
||||
self._action_scheduler.set_timesteps(cfg.action_num_inference_steps)
|
||||
timesteps = F.pad(self._scheduler.timesteps, (0, 1), mode="constant", value=0)
|
||||
if cfg.video_exec_step != -1:
|
||||
timesteps = timesteps[: cfg.video_exec_step]
|
||||
action_timesteps = F.pad(self._action_scheduler.timesteps, (0, 1), mode="constant", value=0)
|
||||
|
||||
# 1. Video-latent denoising loop
|
||||
for i, t in enumerate(timesteps):
|
||||
last_step = i == len(timesteps) - 1
|
||||
latent_cond = (
|
||||
init_latent[:, :, 0:1].to(self.dtype)
|
||||
if frame_st_id == 0 and init_latent is not None
|
||||
else None
|
||||
)
|
||||
input_dict = self._prepare_latent_input(
|
||||
latents, None, t, t, latent_cond, None, frame_st_id=frame_st_id
|
||||
)
|
||||
video_noise_pred = self.transformer(
|
||||
self._repeat_input_for_cfg(input_dict["latent_res_lst"]),
|
||||
update_cache=1 if last_step else 0,
|
||||
cache_name="pos",
|
||||
action_mode=False,
|
||||
)
|
||||
if not last_step or cfg.video_exec_step != -1:
|
||||
video_noise_pred = data_seq_to_patch(
|
||||
cfg.patch_size,
|
||||
video_noise_pred,
|
||||
frame_chunk_size,
|
||||
latent_h,
|
||||
latent_w,
|
||||
batch_size=2 if self._use_cfg else 1,
|
||||
)
|
||||
if cfg.guidance_scale > 1:
|
||||
video_noise_pred = video_noise_pred[1:] + cfg.guidance_scale * (
|
||||
video_noise_pred[:1] - video_noise_pred[1:]
|
||||
)
|
||||
else:
|
||||
video_noise_pred = video_noise_pred[:1]
|
||||
latents = self._scheduler.step(video_noise_pred, t, latents, return_dict=False)
|
||||
if frame_st_id == 0 and latent_cond is not None:
|
||||
latents[:, :, 0:1] = latent_cond
|
||||
|
||||
# 2. Action denoising loop
|
||||
for i, t in enumerate(action_timesteps):
|
||||
last_step = i == len(action_timesteps) - 1
|
||||
action_cond = (
|
||||
torch.zeros([1, cfg.action_dim, 1, cfg.action_per_frame, 1], device=device, dtype=self.dtype)
|
||||
if frame_st_id == 0
|
||||
else None
|
||||
)
|
||||
input_dict = self._prepare_latent_input(
|
||||
None, actions, t, t, None, action_cond, frame_st_id=frame_st_id
|
||||
)
|
||||
action_noise_pred = self.transformer(
|
||||
self._repeat_input_for_cfg(input_dict["action_res_lst"]),
|
||||
update_cache=1 if last_step else 0,
|
||||
cache_name="pos",
|
||||
action_mode=True,
|
||||
)
|
||||
if not last_step:
|
||||
from einops import rearrange
|
||||
|
||||
action_noise_pred = rearrange(action_noise_pred, "b (f n) c -> b c f n 1", f=frame_chunk_size)
|
||||
if cfg.action_guidance_scale > 1:
|
||||
action_noise_pred = action_noise_pred[1:] + cfg.action_guidance_scale * (
|
||||
action_noise_pred[:1] - action_noise_pred[1:]
|
||||
)
|
||||
else:
|
||||
action_noise_pred = action_noise_pred[:1]
|
||||
actions = self._action_scheduler.step(action_noise_pred, t, actions, return_dict=False)
|
||||
if frame_st_id == 0 and action_cond is not None:
|
||||
actions[:, :, 0:1] = action_cond
|
||||
|
||||
actions[:, ~self._action_mask] *= 0
|
||||
return actions, latents
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Predicted-video decoding (opt-in)
|
||||
# ------------------------------------------------------------------
|
||||
@torch.no_grad()
|
||||
def _decode_predicted_video(self, latents) -> Tensor:
|
||||
"""VAE-decode predicted latents into a uint8 frame stack ``[T, H, W, 3]`` on CPU."""
|
||||
vae = self._vae
|
||||
z_dim = vae.config.z_dim
|
||||
latents = denormalize_latents(
|
||||
latents.to(vae.dtype), vae.config.latents_mean, vae.config.latents_std, z_dim
|
||||
)
|
||||
video = vae.decode(latents, return_dict=False)[0] # [B, C, F, H, W] in [-1, 1]
|
||||
video = (video.float().clamp(-1, 1) + 1.0) / 2.0
|
||||
video = (video[0].permute(1, 2, 3, 0) * 255.0).round().to(torch.uint8) # [F, H, W, C]
|
||||
return video.cpu()
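The interplay between the action queue and keyframe recording in `select_action` is easier to follow in isolation. The sketch below is a pure-Python simulation under stated assumptions: `simulate_refills` is a hypothetical helper, observations are only counted rather than VAE-encoded, the queue has no `maxlen`, and `frame_chunk_size >= 2` so that the first chunk (which drops frame 0) still yields actions. It returns, for every refill, the call index and how many observed keyframes were available for KV-cache feedback.

```python
from collections import deque


def simulate_refills(n_calls: int, action_per_frame: int, frame_chunk_size: int):
    """Replay the select_action bookkeeping and log (call_index, keyframes_fed_back)."""
    queue = deque()
    keyframes = 0              # stands in for len(self._obs_buffer)
    steps_since_refill = 0
    first_chunk = True
    refills = []
    for call in range(n_calls):
        # Record a keyframe at every frame boundary, except the obs that conditioned the chunk.
        if steps_since_refill > 0 and steps_since_refill % action_per_frame == 0:
            keyframes += 1
        if not queue:
            refills.append((call, keyframes))
            keyframes = 0
            # The first chunk drops frame 0 (the conditioning frame).
            frames = frame_chunk_size - 1 if first_chunk else frame_chunk_size
            first_chunk = False
            queue.extend(range(frames * action_per_frame))
            steps_since_refill = 0
        steps_since_refill += 1
        queue.popleft()
    return refills


print(simulate_refills(n_calls=13, action_per_frame=4, frame_chunk_size=2))
# [(0, 0), (4, 1), (12, 2)]
```

With these illustrative settings the second refill sees a single observed keyframe; `_compute_kv_cache` prepends the init latent on that first feedback so the latent frame count matches the two frames of executed actions, and later refills see one keyframe per executed frame.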
|
||||
@@ -0,0 +1,113 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pre/post-processor pipelines for the LingBot-VA policy.

The policy itself handles image resizing, scaling to [-1, 1] and VAE encoding (the VAE
lives inside the policy), so the preprocessor only renames, batches, normalizes (IDENTITY)
and moves to device. The postprocessor reverses the *fixed* action quantile normalization
(``(action + 1) / 2 * (q99 - q01 + 1e-6) + q01``) baked into the released checkpoints. This
is a fixed transform, not a dataset-stats one, so it cannot use the standard
``UnnormalizerProcessorStep`` and is implemented as a dedicated step below.
"""

from dataclasses import dataclass, field
from typing import Any

import torch

from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.processor import (
    AddBatchDimensionProcessorStep,
    DeviceProcessorStep,
    NormalizerProcessorStep,
    PolicyAction,
    PolicyActionProcessorStep,
    PolicyProcessorPipeline,
    ProcessorStep,
    ProcessorStepRegistry,
    RenameObservationsProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import (
    POLICY_POSTPROCESSOR_DEFAULT_NAME,
    POLICY_PREPROCESSOR_DEFAULT_NAME,
)

from .configuration_lingbot_va import LingBotVAConfig


@dataclass
@ProcessorStepRegistry.register(name="lingbot_va_action_unnormalize")
class LingBotVAActionUnnormalizeStep(PolicyActionProcessorStep):
    """Reverse LingBot-VA's fixed per-channel quantile normalization on predicted actions.

    The policy emits actions in the normalized ``[-1, 1]`` space of the used action channels.
    This step maps them back to physical units via the fixed quantiles stored in the config.
    """

    action_q01: list[float] = field(default_factory=list)
    action_q99: list[float] = field(default_factory=list)

    def action(self, action: PolicyAction) -> PolicyAction:
        q01 = torch.as_tensor(self.action_q01, dtype=action.dtype, device=action.device)
        q99 = torch.as_tensor(self.action_q99, dtype=action.dtype, device=action.device)
        return (action + 1.0) / 2.0 * (q99 - q01 + 1e-6) + q01

    def get_config(self) -> dict[str, Any]:
        return {"action_q01": list(self.action_q01), "action_q99": list(self.action_q99)}

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        return features


def make_lingbot_va_pre_post_processors(
    config: LingBotVAConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """Build the pre/post processor pipelines for LingBot-VA."""

    input_steps: list[ProcessorStep] = [
        RenameObservationsProcessorStep(rename_map={}),
        AddBatchDimensionProcessorStep(),
        NormalizerProcessorStep(
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
        DeviceProcessorStep(device=config.device),
    ]

    output_steps: list[ProcessorStep] = [
        LingBotVAActionUnnormalizeStep(action_q01=config.action_q01, action_q99=config.action_q99),
        DeviceProcessorStep(device="cpu"),
    ]

    return (
        PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
            steps=input_steps,
            name=POLICY_PREPROCESSOR_DEFAULT_NAME,
        ),
        PolicyProcessorPipeline[PolicyAction, PolicyAction](
            steps=output_steps,
            name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
            to_transition=policy_action_to_transition,
            to_output=transition_to_policy_action,
        ),
    )
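Reviewer sketch (not part of the diff): the fixed quantile mapping used by the action postprocessor can be checked by hand with a dependency-free version. The quantile values and the `normalize` helper are assumptions made for illustration; only the unnormalization expression comes from the file above.

```python
# Dependency-free sketch of the fixed quantile mapping; plain lists stand in for tensors.
EPS = 1e-6


def unnormalize(action, q01, q99):
    """Map per-channel values in [-1, 1] back to physical units."""
    return [(a + 1.0) / 2.0 * (hi - lo + EPS) + lo for a, lo, hi in zip(action, q01, q99)]


def normalize(raw, q01, q99):
    """Inverse mapping, obtained by solving the expression above for the normalized value."""
    return [(x - lo) / (hi - lo + EPS) * 2.0 - 1.0 for x, lo, hi in zip(raw, q01, q99)]


# Hypothetical quantiles for a two-channel action (illustrative values only).
q01 = [-0.5, 0.0]
q99 = [0.5, 1.0]

low = unnormalize([-1.0, -1.0], q01, q99)  # lands on q01
high = unnormalize([1.0, 1.0], q01, q99)  # lands on q99 + EPS
roundtrip = normalize(unnormalize([0.25, -0.75], q01, q99), q01, q99)
```

A normalized value of `-1` maps to `q01` and `+1` maps to `q99 + 1e-6`, so the epsilon only guards against a zero-width quantile range and shifts the upper end by a negligible amount.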
@@ -0,0 +1,155 @@
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Flow-matching scheduler for LingBot-VA.

Vendored verbatim from the upstream LingBot-VA repository
(https://github.com/Robbyant/lingbot-va, ``wan_va/utils/scheduler.py``). LingBot-VA uses
two independent instances of this scheduler at inference time, one for the video-latent
stream and one for the action stream, each with its own ``shift`` (signal-to-noise ratio
shift) and number of denoising steps.
"""

import math

import torch

__all__ = ["FlowMatchScheduler"]


class FlowMatchScheduler:
    def __init__(
        self,
        num_inference_steps=100,
        num_train_timesteps=1000,
        shift=3.0,
        sigma_max=1.0,
        sigma_min=0.003 / 1.002,
        inverse_timesteps=False,
        extra_one_step=False,
        reverse_sigmas=False,
        exponential_shift=False,
        exponential_shift_mu=None,
        shift_terminal=None,
    ):
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        self.inverse_timesteps = inverse_timesteps
        self.extra_one_step = extra_one_step
        self.reverse_sigmas = reverse_sigmas
        self.exponential_shift = exponential_shift
        self.exponential_shift_mu = exponential_shift_mu
        self.shift_terminal = shift_terminal
        self.set_timesteps(num_inference_steps)

    def set_timesteps(
        self,
        num_inference_steps=100,
        denoising_strength=1.0,
        training=False,
        shift=None,
        dynamic_shift_len=None,
    ):
        if shift is not None:
            self.shift = shift
        sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
        if self.extra_one_step:
            self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
        else:
            self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
        if self.inverse_timesteps:
            self.sigmas = torch.flip(self.sigmas, dims=[0])
        if self.exponential_shift:
            mu = (
                self.calculate_shift(dynamic_shift_len)
                if dynamic_shift_len is not None
                else self.exponential_shift_mu
            )
            self.sigmas = math.exp(mu) / (math.exp(mu) + (1 / self.sigmas - 1))
        else:
            self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
        if self.shift_terminal is not None:
            one_minus_z = 1 - self.sigmas
            scale_factor = one_minus_z[-1] / (1 - self.shift_terminal)
            self.sigmas = 1 - (one_minus_z / scale_factor)
        if self.reverse_sigmas:
            self.sigmas = 1 - self.sigmas
        self.timesteps = self.sigmas * self.num_train_timesteps
        if training:
            x = self.timesteps
            y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
            y_shifted = y - y.min()
            bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
            self.linear_timesteps_weights = bsmntw_weighing
            self.training = True
        else:
            self.training = False

    def step(self, model_output, timestep, sample, to_final=False, **kwargs):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        if to_final or timestep_id + 1 >= len(self.timesteps):
            sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
        else:
            sigma_ = self.sigmas[timestep_id + 1]
        prev_sample = sample + model_output * (sigma_ - sigma)
        return prev_sample

    def return_to_timestep(self, timestep, sample, sample_stablized):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        model_output = (sample - sample_stablized) / sigma
        return model_output

    def add_noise(self, original_samples, noise, timestep, t_dim=2):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep = timestep[None]
        timestep_id = torch.argmin((self.timesteps[:, None] - timestep).abs(), dim=0)
        shape = [1] * noise.ndim
        shape[t_dim] = timestep_id.shape[0]
        sigma = self.sigmas[timestep_id].to(original_samples).view(shape)
        sample = (1 - sigma) * original_samples + sigma * noise
        return sample

    def training_target(self, sample, noise, timestep):
        target = noise - sample
        return target

    def training_weight(self, timestep):
        timestep_id = torch.argmin(
            (self.timesteps[:, None].to(timestep.device) - timestep[None]).abs(), dim=0
        )
        weights = self.linear_timesteps_weights.to(timestep.device)[timestep_id].to(timestep.device)
        return weights

    def calculate_shift(
        self,
        image_seq_len,
        base_seq_len: int = 256,
        max_seq_len: int = 8192,
        base_shift: float = 0.5,
        max_shift: float = 0.9,
    ):
        m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
        b = base_shift - m * base_seq_len
        mu = image_seq_len * m + b
        return mu
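Reviewer sketch (not part of the diff): a scalar, torch-free rendering of the default path through `set_timesteps` and `step`, assuming `exponential_shift`, `shift_terminal`, `inverse_timesteps`, `reverse_sigmas` and `extra_one_step` are all left at their defaults. The helper names `shifted_sigmas` and `euler_denoise` are hypothetical, and an oracle velocity replaces the transformer so the arithmetic can be followed end to end.

```python
def shifted_sigmas(num_steps, shift=3.0, sigma_max=1.0, sigma_min=0.003 / 1.002):
    """Evenly spaced sigmas from sigma_max down to sigma_min, warped by the shift."""
    if num_steps == 1:
        linear = [sigma_max]
    else:
        gap = (sigma_max - sigma_min) / (num_steps - 1)
        linear = [sigma_max - i * gap for i in range(num_steps)]
    # Same warp as the non-exponential branch: shift * s / (1 + (shift - 1) * s).
    return [shift * s / (1.0 + (shift - 1.0) * s) for s in linear]


def euler_denoise(noise, clean, sigmas):
    """Integrate from pure noise to the clean sample with an oracle velocity."""
    x = noise
    for i, sigma in enumerate(sigmas):
        # After the last scheduled sigma, the scheduler steps to sigma = 0.
        sigma_next = sigmas[i + 1] if i + 1 < len(sigmas) else 0.0
        velocity = noise - clean  # what training_target returns; stands in for the model
        x = x + velocity * (sigma_next - sigma)
    return x


sigmas = shifted_sigmas(num_steps=5)
result = euler_denoise(noise=2.0, clean=-0.5, sigmas=sigmas)
```

Because the first shifted sigma equals 1 and the final step targets 0, the per-step increments telescope and an exact velocity recovers the clean sample regardless of the number of steps. With a learned velocity the step count and `shift` control how the discretization error is distributed.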
@@ -0,0 +1,286 @@
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Attention and rotary-position-embedding modules for the LingBot-VA Wan transformer.

Vendored and lightly adapted from the upstream LingBot-VA repository
(https://github.com/Robbyant/lingbot-va, ``wan_va/modules/model.py``). The ``torch``
SDPA backend is the default and is always available; the ``flashattn`` and ``flex``
backends are imported lazily and only required when the corresponding ``attn_mode`` is
selected. State-dict parameter names are preserved verbatim so that conversion from the
original diffusers-style checkpoint is near-identity.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

# ``flash_attn`` and the flex-attention APIs are optional. We import them lazily inside the
# backends that need them so that the (default) ``torch`` SDPA path works on any platform,
# including CPU-only and macOS where neither package is available.


def custom_sdpa(q, k, v):
    """Scaled-dot-product attention operating on ``(B, S, H, D)`` tensors."""
    out = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
    return out.transpose(1, 2)


def _load_flash_attn_func():
    try:
        from flash_attn_interface import flash_attn_func
    except ImportError:
        try:
            from flash_attn import flash_attn_func
        except ImportError as e:
            raise ImportError(
                "attn_mode='flashattn' requires the `flash_attn` package, which is not installed. "
                "Install it, or use attn_mode='torch' (the default)."
            ) from e
    return flash_attn_func


class WanRotaryPosEmbed(nn.Module):
    """Rotary position embedding with separate frequency bases for frame / height / width."""

    def __init__(
        self,
        attention_head_dim: int,
        patch_size,
        max_seq_len: int,
        theta: float = 10000.0,
    ):
        super().__init__()

        self.attention_head_dim = attention_head_dim
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len
        self.theta = theta

        self.f_dim = self.attention_head_dim - 2 * (self.attention_head_dim // 3)
        self.h_dim = self.attention_head_dim // 3
        self.w_dim = self.attention_head_dim // 3

        f_freqs_base, h_freqs_base, w_freqs_base = self._precompute_freqs_base()
        self.f_freqs_base = f_freqs_base
        self.h_freqs_base = h_freqs_base
        self.w_freqs_base = w_freqs_base

    def _precompute_freqs_base(self):
        # freqs_base = 1.0 / (theta ** (2k / dim))
        f_freqs_base = 1.0 / (
            self.theta ** (torch.arange(0, self.f_dim, 2)[: (self.f_dim // 2)].double() / self.f_dim)
        )
        h_freqs_base = 1.0 / (
            self.theta ** (torch.arange(0, self.h_dim, 2)[: (self.h_dim // 2)].double() / self.h_dim)
        )
        w_freqs_base = 1.0 / (
            self.theta ** (torch.arange(0, self.w_dim, 2)[: (self.w_dim // 2)].double() / self.w_dim)
        )
        return f_freqs_base, h_freqs_base, w_freqs_base

    def forward(self, grid_ids):
        with torch.no_grad():
            f_freqs = grid_ids[:, 0, :].unsqueeze(-1) * self.f_freqs_base.to(grid_ids.device)
            h_freqs = grid_ids[:, 1, :].unsqueeze(-1) * self.h_freqs_base.to(grid_ids.device)
            w_freqs = grid_ids[:, 2, :].unsqueeze(-1) * self.w_freqs_base.to(grid_ids.device)
            freqs = torch.cat([f_freqs, h_freqs, w_freqs], dim=-1).float()
            freqs_cis = torch.polar(torch.ones_like(freqs), freqs)

        return freqs_cis


class WanAttention(nn.Module):
    """Self/cross attention with KV-caching for autoregressive streaming inference.

    Backends:
    * ``torch`` (default): standard SDPA, available everywhere.
    * ``flashattn``: FlashAttention kernels (optional dependency).
    * ``flex``: PyTorch flex-attention (optional, used for block-causal training masks).
    """

    def __init__(
        self,
        dim,
        heads=8,
        dim_head=64,
        eps=1e-5,
        dropout=0.0,
        cross_attention_dim_head=None,
        attn_mode="torch",
    ):
        super().__init__()
        if attn_mode == "torch":
            self.attn_op = custom_sdpa
        elif attn_mode == "flashattn":
            self.attn_op = _load_flash_attn_func()
        elif attn_mode == "flex":
            # Imported lazily to avoid a hard dependency on torch flex-attention at import time.
            from .wan_flex_attention import FlexAttnFunc

            self.attn_op = FlexAttnFunc(cross_attention_dim_head is not None)
        else:
            raise ValueError(
                f"Unsupported attention mode: {attn_mode}, only support 'torch', 'flashattn' and 'flex'"
            )

        self.inner_dim = dim_head * heads
        self.heads = heads
        self.cross_attention_dim_head = cross_attention_dim_head
        self.kv_inner_dim = (
            self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
        )

        self.to_q = nn.Linear(dim, self.inner_dim, bias=True)
        self.to_k = nn.Linear(dim, self.kv_inner_dim, bias=True)
        self.to_v = nn.Linear(dim, self.kv_inner_dim, bias=True)
        self.to_out = nn.ModuleList(
            [
                nn.Linear(self.inner_dim, dim, bias=True),
                nn.Dropout(dropout),
            ]
        )
        self.norm_q = nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
        self.norm_k = nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
        # KV cache only lives on self-attention modules (cross_attention_dim_head is None).
        self.attn_caches = {} if cross_attention_dim_head is None else None

    def clear_pred_cache(self, cache_name):
        if self.attn_caches is None:
            return
        cache = self.attn_caches[cache_name]
        is_pred = cache["is_pred"]
        cache["mask"][is_pred] = False

    def clear_cache(self, cache_name):
        if self.attn_caches is None:
            return
        self.attn_caches[cache_name] = None

    def init_kv_cache(self, cache_name, total_tolen, num_head, head_dim, device, dtype, batch_size):
        if self.attn_caches is None:
            return
        self.attn_caches[cache_name] = {
            "k": torch.empty([batch_size, total_tolen, num_head, head_dim], device=device, dtype=dtype),
            "v": torch.empty([batch_size, total_tolen, num_head, head_dim], device=device, dtype=dtype),
            "id": torch.full((total_tolen,), -1, device=device),
            "mask": torch.zeros((total_tolen,), dtype=torch.bool, device=device),
            "is_pred": torch.zeros((total_tolen,), dtype=torch.bool, device=device),
        }

    def allocate_slots(self, cache_name, key_size):
        cache = self.attn_caches[cache_name]
        mask = cache["mask"]
        ids = cache["id"]
        free = (~mask).nonzero(as_tuple=False).squeeze(-1)

        if free.numel() < key_size:
            used = mask.nonzero(as_tuple=False).squeeze(-1)

            used_ids = ids[used]
            order = torch.argsort(used_ids)
            need = key_size - free.numel()
            to_free = used[order[:need]]

            mask[to_free] = False
            ids[to_free] = -1
            free = (~mask).nonzero(as_tuple=False).squeeze(-1)

        assert free.numel() >= key_size
        return free[:key_size]

    def _next_cache_id(self, cache_name):
        ids = self.attn_caches[cache_name]["id"]
        mask = self.attn_caches[cache_name]["mask"]

        if mask.any():
            return ids[mask].max() + 1
        else:
            return torch.tensor(0, device=ids.device, dtype=ids.dtype)

    def update_cache(self, cache_name, key, value, is_pred):
        cache = self.attn_caches[cache_name]

        key_size = key.shape[1]
        slots = self.allocate_slots(cache_name, key_size)

        new_id = self._next_cache_id(cache_name)

        cache["k"][:, slots] = key
        cache["v"][:, slots] = value
        cache["mask"][slots] = True
        cache["id"][slots] = new_id
        cache["is_pred"][slots] = is_pred
        return slots

    def restore_cache(self, cache_name, slots):
        self.attn_caches[cache_name]["mask"][slots] = False

    def forward(
        self,
        q,
        k,
        v,
        rotary_emb,
        update_cache=0,
        cache_name="pos",
    ):
        kv_cache = (
            self.attn_caches[cache_name]
            if (self.attn_caches is not None) and (cache_name in self.attn_caches)
            else None
        )

        query, key, value = self.to_q(q), self.to_k(k), self.to_v(v)
        query = self.norm_q(query)
        query = query.unflatten(2, (self.heads, -1))
        key = self.norm_k(key)
        key = key.unflatten(2, (self.heads, -1))
        value = value.unflatten(2, (self.heads, -1))
        if rotary_emb is not None:

            def apply_rotary_emb(x, freqs):
                x_out = torch.view_as_complex(
                    x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
                )
                x_out = torch.view_as_real(x_out * freqs).flatten(3)
                return x_out.to(x.dtype)

            query = apply_rotary_emb(query, rotary_emb)
            key = apply_rotary_emb(key, rotary_emb)
        slots = None
        if kv_cache is not None and kv_cache["k"] is not None:
            slots = self.update_cache(cache_name, key, value, is_pred=(update_cache == 1))
            key_pool = self.attn_caches[cache_name]["k"]
            value_pool = self.attn_caches[cache_name]["v"]
            mask = self.attn_caches[cache_name]["mask"]
            valid = mask.nonzero(as_tuple=False).squeeze(-1)
            key = key_pool[:, valid]
            value = value_pool[:, valid]

        hidden_states = self.attn_op(query, key, value)

        if update_cache == 0:
            if kv_cache is not None and kv_cache["k"] is not None:
                self.restore_cache(cache_name, slots)

        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.type_as(query)
        hidden_states = self.to_out[0](hidden_states)
        hidden_states = self.to_out[1](hidden_states)
        return hidden_states


__all__ = ["WanAttention", "WanRotaryPosEmbed", "custom_sdpa"]
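Reviewer sketch (not part of the diff): the slot bookkeeping in `allocate_slots`, `_next_cache_id` and `update_cache` can be followed with a small list-based model. `SlotPool` is a hypothetical name; it tracks only the `mask` and `id` arrays and omits the key/value tensors and the `is_pred` flag. Ties between equal ids are broken by slot index here, which the tensor version does not guarantee.

```python
class SlotPool:
    """List-based model of the KV-cache slot bookkeeping (mask and id arrays only)."""

    def __init__(self, capacity):
        self.mask = [False] * capacity  # True where a slot holds a live entry
        self.ids = [-1] * capacity  # write id per slot, -1 when empty

    def allocate(self, n):
        free = [i for i, live in enumerate(self.mask) if not live]
        if len(free) < n:
            # Not enough room: evict the live slots with the smallest (oldest) ids.
            occupied = [i for i, live in enumerate(self.mask) if live]
            occupied.sort(key=lambda i: self.ids[i])
            for i in occupied[: n - len(free)]:
                self.mask[i] = False
                self.ids[i] = -1
            free = [i for i, live in enumerate(self.mask) if not live]
        assert len(free) >= n
        return free[:n]

    def write(self, n):
        slots = self.allocate(n)
        # The new id is computed after eviction, from whatever is still live.
        live_ids = [self.ids[i] for i, live in enumerate(self.mask) if live]
        new_id = max(live_ids) + 1 if live_ids else 0
        for i in slots:
            self.mask[i] = True
            self.ids[i] = new_id
        return slots


pool = SlotPool(capacity=4)
first = pool.write(2)
second = pool.write(2)
third = pool.write(3)  # pool is full, so the three oldest slots are recycled
```

In this model a write that needs more room than is free can evict part of an earlier write (slot 2 above is recycled while slot 3, written in the same call, stays live), so the cache capacity presumably needs to be sized to hold whole chunks within the intended window.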
@@ -0,0 +1,207 @@
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Flex-attention backend for the LingBot-VA Wan transformer (training only).

This module is imported lazily and ONLY when ``attn_mode='flex'`` is requested. It builds
the block-causal / window / noise-vs-clean attention masks used during the dual-stream
flow-matching training described in the LingBot-VA paper. Inference uses the ``torch``
SDPA backend (see :mod:`wan_attention`) which does not need flex-attention.

``torch.nn.attention.flex_attention`` requires a recent PyTorch build with the relevant
inductor support; importing this module on an unsupported build raises ``ImportError``.
"""

from collections.abc import Callable
from functools import partial
from typing import ClassVar

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.attention.flex_attention import (
    BlockMask,
    and_masks,
    create_block_mask,
    flex_attention,
    or_masks,
)


class FlexAttnFunc(nn.Module):
    flex_attn: ClassVar[Callable] = torch.compile(flex_attention, dynamic=True)
    compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask)
    attention_mask: ClassVar[BlockMask] = None
    cross_attention_mask: ClassVar[BlockMask] = None

    def __init__(self, is_cross=False) -> None:
        super().__init__()
        self.is_cross = is_cross

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        dtype=torch.bfloat16,
    ) -> torch.Tensor:
        q_varlen = rearrange(query[0], "s n d -> 1 n s d")
        k_varlen = rearrange(key[0], "s n d -> 1 n s d")
        v_varlen = rearrange(value[0], "s n d -> 1 n s d")

        half_dtypes = (torch.float16, torch.bfloat16)
        assert dtype in half_dtypes

        def half(x):
            return x if x.dtype in half_dtypes else x.to(dtype)

        q_varlen = half(q_varlen)
        k_varlen = half(k_varlen)
        v_varlen = half(v_varlen)
        q_varlen = q_varlen.to(v_varlen.dtype)
        k_varlen = k_varlen.to(v_varlen.dtype)

        block_mask = FlexAttnFunc.cross_attention_mask if self.is_cross else FlexAttnFunc.attention_mask

        x_out = FlexAttnFunc.flex_attn(
            q_varlen,
            k_varlen,
            v_varlen,
            block_mask=block_mask,
            kernel_options={
                "BLOCK_M": 64,
                "BLOCK_N": 64,
                "BLOCK_M1": 32,
                "BLOCK_N1": 64,
                "BLOCK_M2": 64,
                "BLOCK_N2": 32,
            },
        )

        x_out = rearrange(x_out, "b n s d -> b s n d")
        return x_out

    @staticmethod
    @torch.no_grad()
    def init_mask(
        latent_shape,
        action_shape,
        padded_length,
        chunk_size,
        window_size,
        patch_size,
        device,
    ):
        torch._inductor.config.realize_opcount_threshold = 100
        B, _, L_F, L_H, L_W = latent_shape
        _, _, A_F, A_H, A_W = action_shape

        latent_seq_id = (
            torch.arange(B)[:, None, None, None]
            .expand(-1, L_F // patch_size[0], L_H // patch_size[1], L_W // patch_size[2])
            .flatten()
        )
        action_seq_id = torch.arange(B)[:, None, None, None].expand(-1, A_F, A_H, A_W).flatten()
        seq_ids = torch.cat([latent_seq_id] * 2 + [action_seq_id] * 2)

        latent_frame_id = (
            torch.arange(L_F)[None, :, None, None]
            .expand(B, -1, L_H // patch_size[1], L_W // patch_size[2])[None]
            .flatten()
        )
        action_frame_id = torch.arange(A_F)[None, :, None, None].expand(B, -1, A_H, A_W)[None].flatten()
        frame_ids = torch.cat(
            [latent_frame_id // chunk_size * 2] * 2 + [action_frame_id // chunk_size * 2 + 1] * 2
        )

        noise_ids = torch.cat(
            [
                torch.zeros_like(latent_frame_id),
                torch.ones_like(latent_frame_id),
                torch.zeros_like(action_frame_id),
                torch.ones_like(action_frame_id),
            ]
        )

        seq_ids = F.pad(seq_ids, (0, padded_length), value=-1)
        frame_ids = F.pad(frame_ids, (0, padded_length), value=-1)
        noise_ids = F.pad(noise_ids, (0, padded_length), value=-1)

        mask_mod = FlexAttnFunc._get_mask_mod(
            seq_ids.long().to(device), frame_ids.long().to(device), noise_ids.long().to(device), window_size
        )
        block_mask = FlexAttnFunc.compiled_create_block_mask(
            mask_mod, 1, 1, len(seq_ids), len(seq_ids), device=device, _compile=True
        )
        FlexAttnFunc.attention_mask = block_mask

        text_seq_ids = torch.arange(B)[:, None].expand(-1, 512).flatten()
        mask_mod_cross = FlexAttnFunc._get_cross_mask_mod(
            seq_ids.long().to(device), text_seq_ids.long().to(device)
        )
        block_mask_cross = FlexAttnFunc.compiled_create_block_mask(
            mask_mod_cross, 1, 1, len(seq_ids), len(text_seq_ids), device=device, _compile=True
        )
        FlexAttnFunc.cross_attention_mask = block_mask_cross

    @staticmethod
    @torch.no_grad()
    def _get_cross_mask_mod(seq_ids, text_seq_ids):
        def seq_mask(b, h, q_idx, kv_idx):
            return (
                (seq_ids[q_idx] == text_seq_ids[kv_idx]) & (seq_ids[q_idx] >= 0) & (text_seq_ids[kv_idx] >= 0)
            )

        return seq_mask

    @staticmethod
    @torch.no_grad()
    def _get_mask_mod(seq_ids, frame_ids, noise_ids, window_size):
        def seq_mask(b, h, q_idx, kv_idx):
            return (seq_ids[q_idx] == seq_ids[kv_idx]) & (seq_ids[q_idx] >= 0) & (seq_ids[kv_idx] >= 0)

        def block_causal_mask(b, h, q_idx, kv_idx):
            return frame_ids[kv_idx] <= frame_ids[q_idx]

        def block_causal_mask_exclude_self(b, h, q_idx, kv_idx):
            return frame_ids[kv_idx] < frame_ids[q_idx]

        def block_self_mask(b, h, q_idx, kv_idx):
            return frame_ids[kv_idx] == frame_ids[q_idx]

        def clean2clean_mask(b, h, q_idx, kv_idx):
            return (noise_ids[q_idx] == 1) & (noise_ids[kv_idx] == 1)

        def noise2clean_mask(b, h, q_idx, kv_idx):
            return (noise_ids[q_idx] == 0) & (noise_ids[kv_idx] == 1)

        def noise2noise_mask(b, h, q_idx, kv_idx):
            return (noise_ids[q_idx] == 0) & (noise_ids[kv_idx] == 0)

        def block_window_mask(b, h, q_idx, kv_idx, window_size: int):
            return (frame_ids[q_idx] - frame_ids[kv_idx]).abs() <= window_size

        mask_list = []
        mask_list.append(and_masks(clean2clean_mask, block_causal_mask))
        mask_list.append(and_masks(noise2clean_mask, block_causal_mask_exclude_self))
        mask_list.append(and_masks(noise2noise_mask, block_self_mask))
        mask = or_masks(*mask_list)
        mask = and_masks(mask, seq_mask)
        mask = and_masks(mask, partial(block_window_mask, window_size=window_size))
        return mask


__all__ = ["FlexAttnFunc"]
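Reviewer sketch (not part of the diff): the composition in `_get_mask_mod` can be evaluated token by token without flex-attention. The function `allowed` and the five-token layout below are hypothetical; the id conventions (noise id `0` for the noisy copy, `1` for the clean copy, `-1` for padding) are read off the mask names and `init_mask` above.

```python
def allowed(q, kv, seq_ids, frame_ids, noise_ids, window_size):
    """Dense, per-pair evaluation of the self-attention mask rules."""
    same_seq = seq_ids[q] == seq_ids[kv] and seq_ids[q] >= 0 and seq_ids[kv] >= 0
    # Clean tokens see clean tokens of the same or earlier blocks.
    clean2clean = noise_ids[q] == 1 and noise_ids[kv] == 1 and frame_ids[kv] <= frame_ids[q]
    # Noisy tokens see clean tokens of strictly earlier blocks only.
    noise2clean = noise_ids[q] == 0 and noise_ids[kv] == 1 and frame_ids[kv] < frame_ids[q]
    # Noisy tokens see noisy tokens of their own block only.
    noise2noise = noise_ids[q] == 0 and noise_ids[kv] == 0 and frame_ids[kv] == frame_ids[q]
    in_window = abs(frame_ids[q] - frame_ids[kv]) <= window_size
    return (clean2clean or noise2clean or noise2noise) and same_seq and in_window


# Hypothetical layout: noisy block 0, noisy block 1, clean block 0, clean block 1, padding.
seq_ids = [0, 0, 0, 0, -1]
frame_ids = [0, 1, 0, 1, -1]
noise_ids = [0, 0, 1, 1, -1]

dense = [
    [allowed(q, kv, seq_ids, frame_ids, noise_ids, window_size=2) for kv in range(5)]
    for q in range(5)
]
```

Reading row 1 of `dense`, the noisy copy of block 1 attends to itself and to the clean copy of block 0, but not to the clean copy of its own block. Under these rules a token being denoised conditions on clean history rather than on its own ground truth.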
@@ -0,0 +1,514 @@
|
||||
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The dual-stream Wan2.2 video-action transformer backbone for LingBot-VA.
|
||||
|
||||
Vendored and lightly adapted from the upstream LingBot-VA repository
|
||||
(https://github.com/Robbyant/lingbot-va, ``wan_va/modules/model.py``).
|
||||
|
||||
The model keeps the diffusers ``ModelMixin``/``ConfigMixin`` mixins so the original
|
||||
sharded ``transformer/`` checkpoint can be loaded with ``from_pretrained`` during
|
||||
conversion, but in LeRobot it is owned as a plain ``nn.Module`` sub-component of
|
||||
:class:`~lerobot.policies.lingbot_va.modeling_lingbot_va.LingBotVAPolicy`. State-dict
|
||||
parameter names are preserved verbatim so conversion is near-identity.
|
||||
"""
|
||||
|
||||
import math
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.models.attention import FeedForward
|
||||
from diffusers.models.embeddings import (
|
||||
PixArtAlphaTextProjection,
|
||||
TimestepEmbedding,
|
||||
Timesteps,
|
||||
)
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.models.normalization import FP32LayerNorm
|
||||
from einops import rearrange
|
||||
|
||||
from .wan_attention import WanAttention, WanRotaryPosEmbed
|
||||
|
||||
__all__ = ["WanTransformer3DModel", "WanTransformerBlock", "WanTimeTextImageEmbedding"]
|
||||
|
||||
|
||||
class WanTimeTextImageEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
time_freq_dim,
|
||||
time_proj_dim,
|
||||
text_embed_dim,
|
||||
pos_embed_seq_len,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.timesteps_proj = Timesteps(
|
||||
num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0
|
||||
)
|
||||
self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
|
||||
self.act_fn = nn.SiLU()
|
||||
self.time_proj = nn.Linear(dim, time_proj_dim)
|
||||
self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
|
||||
|
||||
def forward(self, timestep: torch.Tensor, dtype=None):
|
||||
B, L = timestep.shape
|
||||
timestep = timestep.reshape(-1)
|
||||
timestep = self.timesteps_proj(timestep)
|
||||
time_embedder_dtype = self.time_embedder.linear_1.weight.dtype
|
||||
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
|
||||
timestep = timestep.to(time_embedder_dtype)
|
||||
temb = self.time_embedder(timestep).to(dtype=dtype)
|
||||
timestep_proj = self.time_proj(self.act_fn(temb))
|
||||
return temb.reshape(B, L, -1), timestep_proj.reshape(B, L, -1)
|
||||
|
||||
|
||||
class WanTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        ffn_dim,
        num_heads,
        cross_attn_norm=False,
        eps=1e-6,
        attn_mode: str = "torch",
    ):
        super().__init__()
        self.attn_mode = attn_mode

        # 1. Self-attention
        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
        self.attn1 = WanAttention(
            dim=dim,
            heads=num_heads,
            dim_head=dim // num_heads,
            eps=eps,
            cross_attention_dim_head=None,
            attn_mode=attn_mode,
        )

        # 2. Cross-attention
        self.attn2 = WanAttention(
            dim=dim,
            heads=num_heads,
            dim_head=dim // num_heads,
            eps=eps,
            cross_attention_dim_head=dim // num_heads,
            attn_mode=attn_mode,
        )
        self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()

        # 3. Feed-forward
        self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
        self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)

        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states,
        temb,
        rotary_emb,
        update_cache=0,
        cache_name="pos",
    ) -> torch.Tensor:
        temb_scale_shift_table = self.scale_shift_table[None] + temb.float()
        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = rearrange(
            temb_scale_shift_table, "b l n c -> b n l c"
        ).chunk(6, dim=1)
        shift_msa = shift_msa.squeeze(1)
        scale_msa = scale_msa.squeeze(1)
        gate_msa = gate_msa.squeeze(1)
        c_shift_msa = c_shift_msa.squeeze(1)
        c_scale_msa = c_scale_msa.squeeze(1)
        c_gate_msa = c_gate_msa.squeeze(1)
        # 1. Self-attention
        norm_hidden_states = (self.norm1(hidden_states.float()) * (1.0 + scale_msa) + shift_msa).type_as(
            hidden_states
        )
        attn_output = self.attn1(
            norm_hidden_states,
            norm_hidden_states,
            norm_hidden_states,
            rotary_emb,
            update_cache=update_cache,
            cache_name=cache_name,
        )
        hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)

        # 2. Cross-attention
        norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
        attn_output = self.attn2(
            norm_hidden_states,
            encoder_hidden_states,
            encoder_hidden_states,
            None,
            update_cache=0,
            cache_name=cache_name,
        )
        hidden_states = hidden_states + attn_output

        # 3. Feed-forward
        norm_hidden_states = (self.norm3(hidden_states.float()) * (1.0 + c_scale_msa) + c_shift_msa).type_as(
            hidden_states
        )

        ff_output = self.ffn(norm_hidden_states)

        hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
        return hidden_states
|
||||
|
||||
|
||||
class WanTransformer3DModel(ModelMixin, ConfigMixin):
    """Dual-stream (video + action) Wan2.2 DiT backbone with autoregressive KV caching."""

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = [
        "patch_embedding_mlp",
        "condition_embedder",
        "condition_embedder_action",
        "norm",
    ]
    _no_split_modules = ["WanTransformerBlock"]
    _keep_in_fp32_modules = [
        "time_embedder",
        "scale_shift_table",
        "scale_shift_table_action",
        "norm1",
        "action_norm1",
        "text_norm1",
        "norm2",
        "action_norm2",
        "text_norm2",
        "norm3",
        "action_norm3",
        "text_norm3",
    ]
    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
    _repeated_blocks = ["WanTransformerBlock"]

    @register_to_config
    def __init__(
        self,
        patch_size=(1, 2, 2),
        num_attention_heads=24,
        attention_head_dim=128,
        in_channels=48,
        out_channels=48,
        action_dim=30,
        text_dim=4096,
        freq_dim=256,
        ffn_dim=14336,
        num_layers=30,
        cross_attn_norm=True,
        eps=1e-06,
        rope_max_seq_len=1024,
        pos_embed_seq_len=None,
        attn_mode="torch",
    ):
        super().__init__()
        self.patch_size = patch_size
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim
        self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
        self.patch_embedding_mlp = nn.Linear(
            in_channels * patch_size[0] * patch_size[1] * patch_size[2], inner_dim
        )
        self.action_embedder = nn.Linear(action_dim, inner_dim)
        self.condition_embedder = WanTimeTextImageEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
            pos_embed_seq_len=pos_embed_seq_len,
        )
        self.condition_embedder_action = deepcopy(self.condition_embedder)

        self.blocks = nn.ModuleList(
            [
                WanTransformerBlock(
                    inner_dim, ffn_dim, num_attention_heads, cross_attn_norm, eps, attn_mode=attn_mode
                )
                for _ in range(num_layers)
            ]
        )

        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
        self.action_proj_out = nn.Linear(inner_dim, action_dim)
        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
|
||||
|
||||
    # ------------------------------------------------------------------
    # KV-cache management for autoregressive streaming inference
    # ------------------------------------------------------------------
    def clear_cache(self, cache_name):
        for block in self.blocks:
            block.attn1.clear_cache(cache_name)

    def clear_pred_cache(self, cache_name):
        for block in self.blocks:
            block.attn1.clear_pred_cache(cache_name)

    def create_empty_cache(
        self,
        cache_name,
        attn_window,
        latent_token_per_chunk,
        action_token_per_chunk,
        device,
        dtype,
        batch_size,
    ):
        total_tokens = (attn_window // 2) * latent_token_per_chunk + (
            attn_window // 2
        ) * action_token_per_chunk
        for block in self.blocks:
            block.attn1.init_kv_cache(
                cache_name,
                total_tokens,
                self.num_attention_heads,
                self.attention_head_dim,
                device,
                dtype,
                batch_size,
            )
|
||||
|
||||
    # ------------------------------------------------------------------
    # Embedding helpers (shared by train + inference paths)
    # ------------------------------------------------------------------
    def _input_embed(self, latents, input_type="latent"):
        if input_type == "latent":
            hidden_states = rearrange(
                latents,
                "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
                p1=self.patch_size[0],
                p2=self.patch_size[1],
                p3=self.patch_size[2],
            )
            hidden_states = self.patch_embedding_mlp(hidden_states)
        elif input_type == "action":
            hidden_states = rearrange(latents, "b c f h w -> b (f h w) c")
            hidden_states = self.action_embedder(hidden_states)
        elif input_type == "text":
            hidden_states = self.condition_embedder.text_embedder(latents)
        else:
            raise ValueError(f"Unsupported input type: {input_type}")
        return hidden_states

    def _time_embed(self, timesteps, H, W, dtype, action_mode=False):
        patch_scale_h, patch_scale_w = (1, 1) if action_mode else (self.patch_size[1], self.patch_size[2])
        latent_time_steps = torch.repeat_interleave(
            timesteps, (H // patch_scale_h) * (W // patch_scale_w), dim=1
        )
        current_condition_embedder = (
            self.condition_embedder_action if action_mode else self.condition_embedder
        )
        temb, timestep_proj = current_condition_embedder(latent_time_steps, dtype=dtype)
        timestep_proj = timestep_proj.unflatten(2, (6, -1))  # B L 6 C
        return temb, timestep_proj
|
||||
|
||||
    # ------------------------------------------------------------------
    # Dual-stream training forward (flow matching). Requires attn_mode='flex'.
    # ------------------------------------------------------------------
    def forward_train(self, input_dict):
        from .wan_flex_attention import FlexAttnFunc

        input_dict["latent_dict"]["noisy_latents"] = input_dict["latent_dict"]["noisy_latents"].to(
            torch.bfloat16
        )
        input_dict["latent_dict"]["latent"] = input_dict["latent_dict"]["latent"].to(torch.bfloat16)
        input_dict["action_dict"]["noisy_latents"] = input_dict["action_dict"]["noisy_latents"].to(
            torch.bfloat16
        )
        input_dict["action_dict"]["latent"] = input_dict["action_dict"]["latent"].to(torch.bfloat16)

        latent_dict = input_dict["latent_dict"]
        action_dict = input_dict["action_dict"]
        batch_size = latent_dict["noisy_latents"].shape[0]

        latent_hidden_states = self._input_embed(latent_dict["noisy_latents"], input_type="latent").flatten(
            0, 1
        )[None]
        action_hidden_states = self._input_embed(action_dict["noisy_latents"], input_type="action").flatten(
            0, 1
        )[None]
        text_hidden_states = self._input_embed(latent_dict["text_emb"], input_type="text")

        text_hidden_states = text_hidden_states.flatten(0, 1)[None]

        condition_latent_hidden_states = self._input_embed(
            latent_dict["latent"], input_type="latent"
        ).flatten(0, 1)[None]
        condition_action_hidden_states = self._input_embed(
            action_dict["latent"], input_type="action"
        ).flatten(0, 1)[None]

        hidden_states = torch.cat(
            [
                latent_hidden_states,
                condition_latent_hidden_states,
                action_hidden_states,
                condition_action_hidden_states,
            ],
            dim=1,
        )

        latent_grid_id = latent_dict["grid_id"].permute(1, 0, 2).flatten(1)[None]
        action_grid_id = action_dict["grid_id"].permute(1, 0, 2).flatten(1)[None]
        full_grid_id = torch.cat([latent_grid_id] * 2 + [action_grid_id] * 2, dim=2)

        rotary_emb = self.rope(full_grid_id)[:, :, None]

        latent_time_steps = torch.cat(
            [latent_dict["timesteps"].flatten(0, 1), latent_dict["cond_timesteps"].flatten(0, 1)]
        )[None]
        action_time_steps = torch.cat(
            [action_dict["timesteps"].flatten(0, 1), action_dict["cond_timesteps"].flatten(0, 1)]
        )[None]
        latent_temb, latent_timestep_proj = self._time_embed(
            latent_time_steps,
            latent_dict["noisy_latents"].shape[-2],
            latent_dict["noisy_latents"].shape[-1],
            dtype=hidden_states.dtype,
            action_mode=False,
        )
        action_temb, action_timestep_proj = self._time_embed(
            action_time_steps,
            action_dict["noisy_latents"].shape[-2],
            action_dict["noisy_latents"].shape[-1],
            dtype=hidden_states.dtype,
            action_mode=True,
        )
        temb = torch.cat([latent_temb, action_temb], dim=1)
        timestep_proj = torch.cat([latent_timestep_proj, action_timestep_proj], dim=1)

        total_length = hidden_states.shape[1]
        padded_length = (128 - total_length % 128) % 128
        hidden_states = F.pad(hidden_states, (0, 0, 0, padded_length))
        rotary_emb = F.pad(rotary_emb, (0, 0, 0, 0, 0, padded_length))
        temb = F.pad(temb, (0, 0, 0, padded_length))
        timestep_proj = F.pad(timestep_proj, (0, 0, 0, 0, 0, padded_length))

        split_list = [
            latent_hidden_states.shape[1],
            condition_latent_hidden_states.shape[1],
            action_hidden_states.shape[1],
            condition_action_hidden_states.shape[1],
            padded_length,
        ]

        FlexAttnFunc.init_mask(
            latent_dict["noisy_latents"].shape,
            action_dict["noisy_latents"].shape,
            padded_length,
            input_dict["chunk_size"],
            window_size=input_dict["window_size"],
            patch_size=self.patch_size,
            device=hidden_states.device,
        )

        for block in self.blocks:
            hidden_states = block(
                hidden_states, text_hidden_states, timestep_proj, rotary_emb, update_cache=False
            )
        temb_scale_shift_table = self.scale_shift_table[None] + temb[:, :, None, ...]
        shift, scale = rearrange(temb_scale_shift_table, "b l n c -> b n l c").chunk(2, dim=1)
        shift = shift.to(hidden_states.device).squeeze(1)
        scale = scale.to(hidden_states.device).squeeze(1)
        hidden_states = (self.norm_out(hidden_states.float()) * (1.0 + scale) + shift).type_as(hidden_states)
        latent_hidden_states, _, action_hidden_states, _, _ = torch.split(hidden_states, split_list, dim=1)
        latent_hidden_states = self.proj_out(latent_hidden_states)
        latent_hidden_states = rearrange(
            latent_hidden_states, "1 (b l) (n c) -> b (l n) c", n=math.prod(self.patch_size), b=batch_size
        )
        action_hidden_states = self.action_proj_out(action_hidden_states)
        action_hidden_states = rearrange(action_hidden_states, "1 (b l) c -> b l c", b=batch_size)

        return latent_hidden_states, action_hidden_states
|
||||
|
||||
    # ------------------------------------------------------------------
    # Single-stream inference forward (one denoising step for one stream)
    # ------------------------------------------------------------------
    def forward(
        self,
        input_dict,
        update_cache=0,
        cache_name="pos",
        action_mode=False,
        train_mode=False,
    ):
        if train_mode:
            return self.forward_train(input_dict)
        if action_mode:  # action input emb
            latent_hidden_states = rearrange(input_dict["noisy_latents"], "b c f h w -> b (f h w) c")
            latent_hidden_states = self.action_embedder(latent_hidden_states)  # B L1 C
        else:  # latent input emb
            latent_hidden_states = rearrange(
                input_dict["noisy_latents"],
                "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
                p1=self.patch_size[0],
                p2=self.patch_size[1],
                p3=self.patch_size[2],
            )
            latent_hidden_states = self.patch_embedding_mlp(latent_hidden_states)
        text_hidden_states = self.condition_embedder.text_embedder(input_dict["text_emb"])  # B L2 C

        latent_grid_id = input_dict["grid_id"]
        rotary_emb = self.rope(latent_grid_id)[:, :, None]  # 1 L 1 C
        patch_scale_h, patch_scale_w = (1, 1) if action_mode else (self.patch_size[1], self.patch_size[2])

        latent_time_steps = torch.repeat_interleave(
            input_dict["timesteps"],
            (input_dict["noisy_latents"].shape[-2] // patch_scale_h)
            * (input_dict["noisy_latents"].shape[-1] // patch_scale_w),
            dim=1,
        )  # L
        current_condition_embedder = (
            self.condition_embedder_action if action_mode else self.condition_embedder
        )
        temb, timestep_proj = current_condition_embedder(latent_time_steps, dtype=latent_hidden_states.dtype)
        timestep_proj = timestep_proj.unflatten(2, (6, -1))  # B L 6 C

        for block in self.blocks:
            latent_hidden_states = block(
                latent_hidden_states,
                text_hidden_states,
                timestep_proj,
                rotary_emb,
                update_cache=update_cache,
                cache_name=cache_name,
            )
        temb_scale_shift_table = self.scale_shift_table[None] + temb[:, :, None, ...]
        shift, scale = rearrange(temb_scale_shift_table, "b l n c -> b n l c").chunk(2, dim=1)
        shift = shift.to(latent_hidden_states.device).squeeze(1)
        scale = scale.to(latent_hidden_states.device).squeeze(1)
        latent_hidden_states = (self.norm_out(latent_hidden_states.float()) * (1.0 + scale) + shift).type_as(
            latent_hidden_states
        )

        if action_mode:
            latent_hidden_states = self.action_proj_out(latent_hidden_states)
        else:
            latent_hidden_states = self.proj_out(latent_hidden_states)
            latent_hidden_states = rearrange(
                latent_hidden_states, "b l (n c) -> b (l n) c", n=math.prod(self.patch_size)
            )

        return latent_hidden_states
|
||||
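The two pieces of integer arithmetic in the transformer above (the pad-to-128 step in `forward_train` and the cache capacity in `create_empty_cache`) can be checked in isolation. The following is a minimal pure-Python sketch; the helper names `padding_to_multiple` and `kv_cache_tokens` are hypothetical and do not exist in the vendored module, and the reading that the 128 alignment serves the flex-attention mask is an assumption.

```python
def padding_to_multiple(total_length: int, multiple: int = 128) -> int:
    """Pad tokens needed to reach the next multiple (0 if already aligned).

    Mirrors `(128 - total_length % 128) % 128` from forward_train.
    """
    return (multiple - total_length % multiple) % multiple


def kv_cache_tokens(attn_window: int, latent_token_per_chunk: int, action_token_per_chunk: int) -> int:
    """Token capacity of the KV cache, mirroring create_empty_cache.

    Half of the window is counted in video-latent chunks and half in action chunks.
    """
    half_window = attn_window // 2
    return half_window * latent_token_per_chunk + half_window * action_token_per_chunk


if __name__ == "__main__":
    # Hypothetical sizes, chosen only to exercise the arithmetic.
    for n in (1, 128, 130):
        print(n, "->", padding_to_multiple(n))
    print(kv_cache_tokens(attn_window=8, latent_token_per_chunk=100, action_token_per_chunk=16))
```

The outer `% multiple` is what keeps an already aligned sequence from receiving a full extra block of padding.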
@@ -0,0 +1,56 @@
|
||||
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Grid-id / patch utilities for the LingBot-VA autoregressive inference loop.
|
||||
|
||||
Vendored verbatim from the upstream LingBot-VA repository
|
||||
(https://github.com/Robbyant/lingbot-va, ``wan_va/utils/utils.py``).
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = ["get_mesh_id", "data_seq_to_patch"]
|
||||
|
||||
|
||||
def data_seq_to_patch(patch_size, data_seq, latent_num_frames, latent_height, latent_width, batch_size=1):
    """Reshape a flattened patch sequence back into a ``(B, C, F, H, W)`` latent grid."""
    p_t, p_h, p_w = patch_size
    post_patch_num_frames = latent_num_frames // p_t
    post_patch_height = latent_height // p_h
    post_patch_width = latent_width // p_w

    data_patch = data_seq.reshape(
        batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
    )
    data_patch = data_patch.permute(0, 7, 1, 4, 2, 5, 3, 6)
    data_patch = data_patch.flatten(6, 7).flatten(4, 5).flatten(2, 3)
    return data_patch
|
||||
|
||||
|
||||
def get_mesh_id(f, h, w, t, f_w=1, f_shift=0, action=False):
    """Build the (frame, height, width, stream) grid ids used to index the rotary embedding."""
    f_idx = torch.arange(f_shift, f + f_shift) * f_w
    h_idx = torch.arange(h)
    w_idx = torch.arange(w)
    ff, hh, ww = torch.meshgrid(f_idx, h_idx, w_idx, indexing="ij")
    if action:
        ff_offset = (torch.ones([h]).cumsum(0) / (h + 1)).view(1, -1, 1)
        ff = ff + ff_offset
        hh = torch.ones_like(hh) * -1
        ww = torch.ones_like(ww) * -1

    grid_id = torch.cat([ff.unsqueeze(0), hh.unsqueeze(0), ww.unsqueeze(0)], dim=0).flatten(1)
    grid_id = torch.cat([grid_id, torch.full_like(grid_id[:1], t)], dim=0)
    return grid_id
|
||||
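To make the layout produced by `get_mesh_id` easier to inspect without tensors, here is a minimal pure-Python sketch that follows the same index arithmetic. The function `mesh_id_rows` is a hypothetical illustration, not part of the vendored utilities, and it returns plain lists rather than a `torch.Tensor`.

```python
from itertools import product


def mesh_id_rows(f, h, w, t, f_w=1, f_shift=0, action=False):
    """Return four rows (frame, height, width, stream) in the same order as get_mesh_id.

    Tokens are enumerated frame-major, then height, then width, matching
    torch.meshgrid(..., indexing="ij") followed by flatten.
    """
    frames, heights, widths = [], [], []
    for fi, hi, wi in product(range(f_shift, f + f_shift), range(h), range(w)):
        if action:
            # Action tokens get a fractional frame offset (hi + 1) / (h + 1)
            # and a sentinel spatial position of -1.
            frames.append(fi * f_w + (hi + 1) / (h + 1))
            heights.append(-1)
            widths.append(-1)
        else:
            frames.append(fi * f_w)
            heights.append(hi)
            widths.append(wi)
    stream = [t] * len(frames)
    return [frames, heights, widths, stream]


if __name__ == "__main__":
    print(mesh_id_rows(2, 1, 2, t=0))               # video tokens: 2 frames, 1x2 grid
    print(mesh_id_rows(1, 3, 1, t=1, action=True))  # 3 action tokens within one frame
```

In the action case the fractional offsets place the `h` action tokens strictly between consecutive frame indices, so they are ordered within a frame step without colliding with the integer positions used by video tokens.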
@@ -0,0 +1,120 @@
|
||||
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Thin helpers around the stock diffusers ``AutoencoderKLWan`` (Wan2.2, ``z_dim=48``).
|
||||
|
||||
The VAE class itself is NOT vendored — it lives in ``diffusers>=0.36``. This module
|
||||
provides:
|
||||
* loaders for the VAE / text encoder / tokenizer / transformer sub-checkpoints,
|
||||
* the streaming-encoder wrapper used for autoregressive frame-by-frame VAE encoding
|
||||
(it caches the causal-conv state across chunks),
|
||||
* latent (de)normalization helpers using the VAE's ``latents_mean`` / ``latents_std``.
|
||||
|
||||
Vendored and adapted from ``wan_va/modules/utils.py`` upstream.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = [
|
||||
"WanVAEStreamingWrapper",
|
||||
"load_vae",
|
||||
"load_text_encoder",
|
||||
"load_tokenizer",
|
||||
"normalize_latents",
|
||||
"denormalize_latents",
|
||||
"patchify",
|
||||
]
|
||||
|
||||
|
||||
def load_vae(vae_path, torch_dtype, torch_device):
    from diffusers import AutoencoderKLWan

    vae = AutoencoderKLWan.from_pretrained(vae_path, torch_dtype=torch_dtype)
    return vae.to(torch_device)


def load_text_encoder(text_encoder_path, torch_dtype, torch_device):
    from transformers import UMT5EncoderModel

    text_encoder = UMT5EncoderModel.from_pretrained(text_encoder_path, torch_dtype=torch_dtype)
    return text_encoder.to(torch_device)


def load_tokenizer(tokenizer_path):
    from transformers import T5TokenizerFast

    return T5TokenizerFast.from_pretrained(tokenizer_path)
|
||||
|
||||
|
||||
def patchify(x, patch_size):
    if patch_size is None or patch_size == 1:
        return x
    batch_size, channels, frames, height, width = x.shape
    x = x.view(
        batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size
    )
    x = x.permute(0, 1, 6, 4, 2, 3, 5).contiguous()
    x = x.view(
        batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size
    )
    return x


def normalize_latents(
    latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
) -> torch.Tensor:
    """Apply ``(x - mean) * std`` channel-wise (note: upstream passes ``1/std`` as ``latents_std``)."""
    latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device)
    latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device)
    latents = ((latents.float() - latents_mean) * latents_std).to(latents)
    return latents


def denormalize_latents(latents: torch.Tensor, latents_mean, latents_std, z_dim) -> torch.Tensor:
    """Inverse of the normalization applied at encode time, for VAE decoding of predicted latents."""
    mean = torch.tensor(latents_mean).view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
    inv_std = 1.0 / torch.tensor(latents_std).view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
    return latents / inv_std + mean
|
||||
|
||||
|
||||
class WanVAEStreamingWrapper:
    """Wraps an ``AutoencoderKLWan`` encoder to support causal streaming encoding across chunks."""

    def __init__(self, vae_model):
        self.vae = vae_model
        self.encoder = vae_model.encoder
        self.quant_conv = vae_model.quant_conv

        if hasattr(self.vae, "_cached_conv_counts"):
            self.enc_conv_num = self.vae._cached_conv_counts["encoder"]
        else:
            count = 0
            for m in self.encoder.modules():
                if m.__class__.__name__ == "WanCausalConv3d":
                    count += 1
            self.enc_conv_num = count

        self.clear_cache()

    def clear_cache(self):
        self.feat_cache = [None] * self.enc_conv_num

    def encode_chunk(self, x_chunk):
        if hasattr(self.vae.config, "patch_size") and self.vae.config.patch_size is not None:
            x_chunk = patchify(x_chunk, self.vae.config.patch_size)
        feat_idx = [0]
        out = self.encoder(x_chunk, feat_cache=self.feat_cache, feat_idx=feat_idx)
        enc = self.quant_conv(out)
        return enc
|
||||
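The two latent helpers above use opposite conventions for their `latents_std` argument: per its docstring, `normalize_latents` is handed `1/std`, while `denormalize_latents` takes the plain `std` and inverts it internally. The sketch below reproduces that round trip on per-channel Python lists; the names `normalize` and `denormalize` and the numeric values are hypothetical and chosen only for illustration.

```python
def normalize(x, mean, inv_std):
    """(x - mean) * inv_std per channel; the caller supplies 1/std, as upstream does."""
    return [(v - m) * s for v, m, s in zip(x, mean, inv_std)]


def denormalize(z, mean, std):
    """z / (1/std) + mean per channel, which is the same as z * std + mean."""
    inv_std = [1.0 / s for s in std]
    return [v / s + m for v, m, s in zip(z, mean, inv_std)]


if __name__ == "__main__":
    mean = [0.5, -1.0]
    std = [2.0, 4.0]
    x = [2.5, 3.0]
    z = normalize(x, mean, [1.0 / s for s in std])
    print(z)
    print(denormalize(z, mean, std))
```

Passing the plain `std` to the normalizing helper by mistake would scale the latents by `std**2` relative to the intended result, so keeping the convention explicit at each call site is worthwhile.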
@@ -105,6 +105,7 @@ def rollout(
|
||||
seeds: list[int] | None = None,
|
||||
return_observations: bool = False,
|
||||
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
|
||||
predicted_frames_callback: Callable[[PreTrainedPolicy], None] | None = None,
|
||||
) -> dict:
|
||||
"""Run a batched policy rollout once through a batch of environments.
|
||||
|
||||
@@ -134,6 +135,9 @@ def rollout(
|
||||
are returned optionally because they typically take more memory to cache. Defaults to False.
|
||||
render_callback: Optional rendering callback to be used after the environments are reset, and after
|
||||
every step.
|
||||
predicted_frames_callback: Optional callback invoked after every ``select_action`` with the policy
|
||||
itself. World-model policies (e.g. LingBot-VA) stash their decoded predicted video frames on
|
||||
``policy.last_predicted_frames``; this lets the caller collect them to save predicted-video MP4s.
|
||||
Returns:
|
||||
The dictionary described above.
|
||||
"""
|
||||
@@ -184,6 +188,8 @@ def rollout(
|
||||
observation = preprocessor(observation)
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation)
|
||||
if predicted_frames_callback is not None:
|
||||
predicted_frames_callback(policy)
|
||||
action = postprocessor(action)
|
||||
|
||||
action_transition = {ACTION: action}
|
||||
@@ -273,6 +279,7 @@ def eval_policy(
|
||||
videos_dir: Path | None = None,
|
||||
return_episode_data: bool = False,
|
||||
start_seed: int | None = None,
|
||||
save_predicted_video: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Args:
|
||||
@@ -291,6 +298,11 @@ def eval_policy(
|
||||
if max_episodes_rendered > 0 and not videos_dir:
|
||||
raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.")
|
||||
|
||||
# World-model policies (e.g. LingBot-VA) opt into predicted-video saving via their config.
|
||||
save_predicted_video = save_predicted_video or bool(
|
||||
getattr(getattr(policy, "config", None), "save_predicted_video", False)
|
||||
)
|
||||
|
||||
if not isinstance(policy, PreTrainedPolicy):
|
||||
exc = ValueError(
|
||||
f"Policy of type 'PreTrainedPolicy' is expected, but type '{type(policy)}' was provided."
|
||||
@@ -334,6 +346,21 @@ def eval_policy(
|
||||
if max_episodes_rendered > 0:
|
||||
video_paths: list[str] = []
|
||||
|
||||
if save_predicted_video:
|
||||
if not videos_dir:
|
||||
raise ValueError("If save_predicted_video is True, videos_dir must be provided.")
|
||||
predicted_video_paths: list[str] = []
|
||||
n_predicted_rendered = 0
|
||||
|
||||
# Collects the policy's decoded predicted-video frames across a rollout (world-model policies only).
|
||||
def collect_predicted_frames(policy: PreTrainedPolicy):
|
||||
frames = getattr(policy, "last_predicted_frames", None)
|
||||
if frames is not None:
|
||||
pred_frames.append(
|
||||
np.asarray(frames.detach().to("cpu")) if hasattr(frames, "detach") else np.asarray(frames)
|
||||
)
|
||||
policy.last_predicted_frames = None
|
||||
|
||||
if return_episode_data:
|
||||
episode_data: dict | None = None
|
||||
|
||||
@@ -345,6 +372,9 @@ def eval_policy(
|
||||
if max_episodes_rendered > 0:
|
||||
ep_frames: list[np.ndarray] = []
|
||||
|
||||
if save_predicted_video:
|
||||
pred_frames: list[np.ndarray] = []
|
||||
|
||||
if start_seed is None:
|
||||
seeds = None
|
||||
else:
|
||||
@@ -361,6 +391,7 @@ def eval_policy(
|
||||
seeds=list(seeds) if seeds else None,
|
||||
return_observations=return_episode_data,
|
||||
render_callback=render_frame if max_episodes_rendered > 0 else None,
|
||||
predicted_frames_callback=collect_predicted_frames if save_predicted_video else None,
|
||||
)
|
||||
|
||||
# Figure out where in each rollout sequence the first done condition was encountered (results after
|
||||
@@ -426,6 +457,25 @@ def eval_policy(
|
||||
threads.append(thread)
|
||||
n_episodes_rendered += 1
|
||||
|
||||
# Maybe save the policy's predicted (imagined) video for this batch's rollout.
|
||||
if save_predicted_video and len(pred_frames) > 0:
|
||||
# pred_frames is a list of [F, H, W, C] uint8 stacks emitted on chunk refills; concat over time.
|
||||
predicted_video = np.concatenate(pred_frames, axis=0)
|
||||
videos_dir.mkdir(parents=True, exist_ok=True)
|
||||
predicted_video_path = videos_dir / f"pred_episode_{n_predicted_rendered}.mp4"
|
||||
predicted_video_paths.append(str(predicted_video_path))
|
||||
thread = threading.Thread(
|
||||
target=write_video,
|
||||
args=(
|
||||
str(predicted_video_path),
|
||||
predicted_video,
|
||||
env.unwrapped.metadata["render_fps"],
|
||||
),
|
||||
)
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
n_predicted_rendered += 1
|
||||
|
||||
progbar.set_postfix(
|
||||
{"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"}
|
||||
)
|
||||
@@ -469,6 +519,9 @@ def eval_policy(
|
||||
if max_episodes_rendered > 0:
|
||||
info["video_paths"] = video_paths
|
||||
|
||||
if save_predicted_video:
|
||||
info["predicted_video_paths"] = predicted_video_paths
|
||||
|
||||
return info
|
||||
|
||||
|
||||
|
||||
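The predicted-video plumbing above relies on a simple handshake: the policy stashes frames on `last_predicted_frames`, and the callback collects them and resets the attribute so the same chunk is not appended twice. The following is a minimal sketch of that pattern with a stand-in policy; `FakeWorldModelPolicy`, its every-other-step refill schedule, and the string "frames" are all hypothetical and stand in for the real policy and its uint8 frame stacks.

```python
class FakeWorldModelPolicy:
    """Stand-in policy that emits predicted frames only on every second call."""

    def __init__(self):
        self.last_predicted_frames = None
        self._step = 0

    def select_action(self, observation):
        if self._step % 2 == 0:  # pretend this step refilled the action chunk
            self.last_predicted_frames = [f"frame-{self._step}"]
        self._step += 1
        return 0  # dummy action


collected = []


def collect_predicted_frames(policy):
    """Append any pending frames, then clear them so they are collected once."""
    frames = getattr(policy, "last_predicted_frames", None)
    if frames is not None:
        collected.append(frames)
        policy.last_predicted_frames = None


def run_rollout(policy, n_steps, predicted_frames_callback=None):
    """Call the callback after every select_action, as the rollout hunk does."""
    for _ in range(n_steps):
        policy.select_action(observation=None)
        if predicted_frames_callback is not None:
            predicted_frames_callback(policy)


if __name__ == "__main__":
    run_rollout(FakeWorldModelPolicy(), 4, collect_predicted_frames)
    print(collected)
```

Because the callback uses `getattr` with a default, policies that never set `last_predicted_frames` pass through without contributing frames.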
@@ -0,0 +1,83 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES
|
||||
|
||||
|
||||
def make_config(**overrides) -> LingBotVAConfig:
    kwargs = {"device": "cpu"}
    kwargs.update(overrides)
    return LingBotVAConfig(**kwargs)


def test_registered_in_choice_registry() -> None:
    assert "lingbot_va" in PreTrainedConfig.get_known_choices()
    assert PreTrainedConfig.get_choice_class("lingbot_va") is LingBotVAConfig


def test_type_property() -> None:
    assert make_config().type == "lingbot_va"


def test_chunk_size_and_action_steps() -> None:
    cfg = make_config(frame_chunk_size=4, action_per_frame=4)
    assert cfg.chunk_size == 16
    assert cfg.n_action_steps == 16
    assert cfg.action_delta_indices == list(range(16))
    assert cfg.observation_delta_indices is None
    assert cfg.reward_delta_indices is None


def test_optimizer_and_scheduler_presets() -> None:
    cfg = make_config()
    opt = cfg.get_optimizer_preset()
    assert opt.lr == cfg.optimizer_lr
    sched = cfg.get_scheduler_preset()
    assert sched.num_warmup_steps == cfg.scheduler_warmup_steps


def test_validate_features_sets_action_feature() -> None:
    cfg = make_config()
    cfg.input_features = {f"{OBS_IMAGES}.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128))}
    cfg.output_features = {}
    cfg.validate_features()
    assert ACTION in cfg.output_features
    assert cfg.output_features[ACTION].shape == (len(cfg.used_action_channel_ids),)


def test_validate_features_no_visual_raises() -> None:
    cfg = make_config()
    cfg.input_features = {}
    cfg.output_features = {}
    with pytest.raises(ValueError, match="at least one visual input feature"):
        cfg.validate_features()


def test_invalid_attn_mode_raises() -> None:
    with pytest.raises(ValueError, match="attn_mode"):
        make_config(attn_mode="banana")


def test_quantile_length_mismatch_raises() -> None:
    with pytest.raises(ValueError, match="action_q01"):
        make_config(used_action_channel_ids=[0, 1, 2], action_q01=[0.0, 0.0], action_q99=[1.0, 1.0, 1.0])
|
||||
@@ -0,0 +1,52 @@
#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import pytest

from lerobot.policies.factory import make_policy_config
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig


def test_make_policy_config_returns_lingbot_va() -> None:
    cfg = make_policy_config("lingbot_va", device="cpu")
    assert isinstance(cfg, LingBotVAConfig)


def test_get_policy_class_resolves_lazily() -> None:
    # Importing the policy class pulls in diffusers (Wan2.2 stack); skip if unavailable.
    pytest.importorskip("diffusers")
    pytest.importorskip("transformers")
    from lerobot.policies.factory import get_policy_class

    cls = get_policy_class("lingbot_va")
    assert cls.name == "lingbot_va"
    assert cls.config_class is LingBotVAConfig


def test_convert_build_config_libero() -> None:
    pytest.importorskip("diffusers")
    from lerobot.policies.lingbot_va.convert_lingbot_va_checkpoints import build_config

    cfg = build_config("libero", wan_pretrained_path="dummy/path", dtype="float32")
    assert cfg.height == 128 and cfg.width == 128
    assert cfg.used_action_channel_ids == list(range(7))
    # validate_features (called inside build_config) must have populated the action feature.
    from lerobot.utils.constants import ACTION

    assert cfg.output_features[ACTION].shape == (7,)
    assert len(cfg.obs_cam_keys) == 2
@@ -0,0 +1,69 @@
#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pure-torch unit tests for the vendored LingBot-VA helper modules (no diffusers needed)."""

from __future__ import annotations

import torch

from lerobot.policies.lingbot_va.schedulers import FlowMatchScheduler
from lerobot.policies.lingbot_va.wan_utils import data_seq_to_patch, get_mesh_id


def test_flow_match_scheduler_timesteps_monotone_decreasing() -> None:
    sch = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True)
    sch.set_timesteps(20)
    assert sch.timesteps.shape == (20,)
    diffs = sch.timesteps[1:] - sch.timesteps[:-1]
    assert torch.all(diffs <= 0)  # decreasing


def test_flow_match_scheduler_step_preserves_shape() -> None:
    sch = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True)
    sch.set_timesteps(20)
    sample = torch.zeros(1, 48, 4, 8, 16)
    out = sch.step(torch.ones_like(sample), sch.timesteps[0], sample)
    assert out.shape == sample.shape


def test_flow_match_scheduler_add_noise() -> None:
    sch = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True)
    sch.set_timesteps(20)
    sample = torch.randn(1, 48, 4, 8, 16)
    noise = torch.randn_like(sample)
    noisy = sch.add_noise(sample, noise, sch.timesteps[:4], t_dim=2)
    assert noisy.shape == sample.shape


def test_get_mesh_id_latent_shape() -> None:
    grid = get_mesh_id(4, 8, 16, 0, 1, 0)
    assert grid.shape == (4, 4 * 8 * 16)  # (f, h, w, stream) x tokens


def test_get_mesh_id_action_shape() -> None:
    grid = get_mesh_id(4, 4, 1, 1, 1, 0, action=True)
    assert grid.shape == (4, 4 * 4 * 1)
    # Action rows for h/w are sentinel -1.
    assert torch.all(grid[1] < 0)
    assert torch.all(grid[2] < 0)


def test_data_seq_to_patch_roundtrip_shape() -> None:
    b, f, h, w, c = 1, 4, 8, 16, 48
    seq = torch.arange(b * f * h * w * c, dtype=torch.float32).reshape(b, f * h * w, c)
    out = data_seq_to_patch((1, 2, 2), seq, f, h, w, batch_size=b)
    assert out.shape == (b, c, f, h, w)
@@ -0,0 +1,81 @@
#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import torch

from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
from lerobot.policies.lingbot_va.processor_lingbot_va import (
    LingBotVAActionUnnormalizeStep,
    make_lingbot_va_pre_post_processors,
)
from lerobot.utils.constants import (
    OBS_IMAGES,
    POLICY_POSTPROCESSOR_DEFAULT_NAME,
    POLICY_PREPROCESSOR_DEFAULT_NAME,
)


def _make_config() -> LingBotVAConfig:
    cfg = LingBotVAConfig(device="cpu")
    cfg.input_features = {f"{OBS_IMAGES}.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128))}
    cfg.output_features = {}
    cfg.validate_features()
    return cfg


def test_action_unnormalize_inverts_quantile_norm() -> None:
    q01 = [-1.0, -0.5, 0.0]
    q99 = [1.0, 0.5, 2.0]
    step = LingBotVAActionUnnormalizeStep(action_q01=q01, action_q99=q99)

    # Forward (the policy-side) quantile normalization: (x - q01) / (q99 - q01 + eps) * 2 - 1.
    q01_t = torch.tensor(q01)
    q99_t = torch.tensor(q99)
    raw = torch.tensor([[0.3, 0.1, 1.0]])
    normed = (raw - q01_t) / (q99_t - q01_t + 1e-6) * 2.0 - 1.0

    recovered = step.action(normed)
    assert torch.allclose(recovered, raw, atol=1e-4)


def test_action_unnormalize_config_roundtrip() -> None:
    step = LingBotVAActionUnnormalizeStep(action_q01=[0.0, 1.0], action_q99=[2.0, 3.0])
    cfg = step.get_config()
    assert cfg == {"action_q01": [0.0, 1.0], "action_q99": [2.0, 3.0]}
    rebuilt = LingBotVAActionUnnormalizeStep(**cfg)
    assert rebuilt.action_q01 == step.action_q01
    assert rebuilt.action_q99 == step.action_q99


def test_make_pre_post_processors_names_and_steps() -> None:
    cfg = _make_config()
    pre, post = make_lingbot_va_pre_post_processors(cfg, dataset_stats=None)
    assert pre.name == POLICY_PREPROCESSOR_DEFAULT_NAME
    assert post.name == POLICY_POSTPROCESSOR_DEFAULT_NAME
    # The postprocessor must contain the dedicated quantile unnormalize step.
    assert any(isinstance(s, LingBotVAActionUnnormalizeStep) for s in post.steps)


def test_postprocessor_applies_unnormalization() -> None:
    cfg = _make_config()
    _, post = make_lingbot_va_pre_post_processors(cfg, dataset_stats=None)
    # A normalized action of all -1 should map back to q01.
    normed = torch.full((1, len(cfg.used_action_channel_ids)), -1.0)
    out = post(normed)
    assert torch.allclose(out, torch.tensor(cfg.action_q01).unsqueeze(0), atol=1e-4)