diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 0d4e36172..e62d2b3e5 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -67,6 +67,8 @@ title: VLA-JEPA - local: eo1 title: EO-1 + - local: lingbot_va + title: LingBot-VA - local: groot title: NVIDIA GR00T N1.5 - local: xvla diff --git a/docs/source/lingbot_va.mdx b/docs/source/lingbot_va.mdx new file mode 100644 index 000000000..af3a97fa8 --- /dev/null +++ b/docs/source/lingbot_va.mdx @@ -0,0 +1,120 @@ +# LingBot-VA + +LingBot-VA is an **autoregressive video-action world-model policy** built on the **Wan2.2** +video-diffusion stack. It interleaves, in one autoregressive sequence, the prediction of +future **video latents** and **robot actions** ("VA" = Video-Action). The LeRobot +integration wires LingBot-VA into the standard training, evaluation and processor +interfaces. + +## Model Overview + +LingBot-VA is a **dual-stream "mixture-of-transformers"**: a video/latent stream +(`patch_embedding_mlp → blocks → proj_out`) and an action stream +(`action_embedder → blocks → action_proj_out`) share the same 30 transformer blocks and +text conditioning. Actions are produced by the dedicated `action_proj_out` head — they are +**not** decoded from predicted pixels, though video and action are co-trained. + +| Component | Class | Role | +|---|---|---| +| DiT backbone (trainable) | `WanTransformer3DModel` | ~5B-param dual-stream transformer (the only weights stored in the LeRobot checkpoint). | +| VAE (frozen) | `AutoencoderKLWan` | Wan2.2 VAE, `z_dim=48`. Lazy-pulled from the source repo. | +| Text encoder (frozen) | `UMT5EncoderModel` | UMT5-XXL, `d_model=4096`. Lazy-pulled from the source repo. | + +At inference the policy runs an autoregressive loop per chunk: it denoises the video-latent +stream (CFG, ~20 steps) and the action stream (~50 steps) with two independent +flow-matching schedulers, maintaining a KV cache across chunks. Real observed keyframes are +fed back into the KV cache as the chunk is executed (closed-loop world modeling). + +### What the LeRobot Integration Covers + +- Standard `policy.type=lingbot_va` configuration through LeRobot. +- Checkpoint conversion from the released HuggingFace checkpoints. +- Autoregressive dual-stream inference behind the standard `select_action` interface + (single-environment eval, `--eval.batch_size=1`). +- Opt-in saving of the policy's **predicted (imagined) videos** during eval / training. +- Evaluation with `lerobot-eval` on the LIBERO benchmark. + +Training (the flow-matching dual-stream loss + latent dataset) is part of a follow-up +training port and is not yet wired into `lerobot-train`. + +## Installation + +1. Install LeRobot by following the [Installation Guide](./installation). +2. Install the LingBot-VA extra (brings in `diffusers>=0.36` for the Wan2.2 stack): + +```bash +pip install -e ".[lingbot_va]" +# For LIBERO evaluation (Linux only): +pip install -e ".[lingbot_va,libero]" +``` + +## Checkpoint Conversion + +The released checkpoints are diffusers-style directories +(`robbyant/lingbot-va-base`, `robbyant/lingbot-va-posttrain-robotwin`, +`robbyant/lingbot-va-posttrain-libero-long`). Convert one to LeRobot format with: + +```bash +python -m lerobot.policies.lingbot_va.convert_lingbot_va_checkpoints \ + --checkpoint robbyant/lingbot-va-posttrain-libero-long \ + --variant libero \ + --output_dir outputs/lingbot_va_libero_long +``` + +**Packaging:** only the trainable ~5B transformer is stored in the LeRobot +`model.safetensors`. The frozen VAE + UMT5 + tokenizer (~20 GB) are **lazily pulled** from +`config.wan_pretrained_path` at load time (defaults to the source repo). Pass +`--bundle-frozen` to copy those sub-folders next to the converted checkpoint instead. + +Run conversion on a Linux machine with a CUDA GPU and enough RAM/VRAM to materialize the +transformer. + +## Evaluation (LIBERO) + +```bash +lerobot-eval \ + --policy.path=outputs/lingbot_va_libero_long \ + --env.type=libero --env.task=libero_10 \ + --eval.n_episodes=50 --eval.batch_size=1 \ + --output_dir=outputs/eval/lingbot_va_libero +``` + +LingBot-VA's streaming inference (KV cache + observed-keyframe feedback) is implemented for +single-environment eval; use `--eval.batch_size=1`. + +### Saving predicted (imagined) videos + +Set `--policy.save_predicted_video=true` to additionally VAE-decode the predicted video +latents and write `pred_episode_*.mp4` next to the env-rendered `eval_episode_*.mp4` videos. +The same flag works for the periodic eval during `lerobot-train`. + +## Inference Hyperparameters (LIBERO) + +| Key | Value | +|---|---| +| height × width | 128 × 128 | +| cameras | `observation.images.image` (agentview), `observation.images.image2` (eye-in-hand) | +| action channels used | 0–6 (7-DoF arm + gripper) | +| action_per_frame / frame_chunk_size | 4 / 4 | +| attn_window | 30 | +| video / action denoising steps | 20 / 50 | +| guidance_scale / action_guidance_scale | 5 / 1 | +| snr_shift / action_snr_shift | 5.0 / 0.05 | + +These are the defaults of `LingBotVAConfig`; override any of them via `--policy.=...`. + +## Notes & Limitations + +- **Correctness gate:** matching the upstream LIBERO success rate requires validating the + converted checkpoint on a GPU and tensor-diffing intermediate activations against the + upstream implementation. The most sensitive parts are the action quantile normalization, + the camera ordering, the `action_per_frame`/`frame_chunk_size` alignment, and `attn_mode`. +- **Attention backend:** inference uses the `torch` SDPA backend (always available). The + `flashattn` and `flex` backends are optional; `flex` is only needed for training. +- **Model size:** the DiT is ~5B params and the frozen VAE+UMT5 add ~20 GB; inference needs + roughly 18–24 GB of VRAM. + +## License + +LingBot-VA is released under Apache-2.0. See the +[upstream repository](https://github.com/Robbyant/lingbot-va). diff --git a/pyproject.toml b/pyproject.toml index 2b4c22f12..c1c11c99c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -146,7 +146,8 @@ grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"] can-dep = ["python-can>=4.2.0,<5.0.0"] peft-dep = ["peft>=0.18.0,<1.0.0"] scipy-dep = ["scipy>=1.14.0,<2.0.0"] -diffusers-dep = ["diffusers>=0.27.2,<0.36.0"] +diffusers-dep = ["diffusers>=0.27.2,<0.37.0"] +imageio-dep = ["imageio[ffmpeg]>=2.34.0,<3.0.0"] qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"] matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0", "contourpy>=1.3.0,<2.0.0"] # NOTE: Explicitly listing contourpy helps the resolver converge faster. pyserial-dep = ["pyserial>=3.5,<4.0"] @@ -218,6 +219,10 @@ xvla = ["lerobot[transformers-dep]"] eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"] hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"] +# LingBot-VA needs the Wan2.2 stack (AutoencoderKLWan z_dim=48 + WanTransformer3DModel config schema), +# which only exists in diffusers>=0.36. Pin the floor explicitly so a standalone `lerobot[lingbot_va]` +# install can't resolve to a pre-Wan2.2 diffusers via the looser diffusers-dep floor. +lingbot_va = ["lerobot[transformers-dep]", "diffusers>=0.36.0,<0.37.0", "lerobot[imageio-dep]"] # Features async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"] @@ -284,6 +289,8 @@ all = [ "lerobot[xvla]", "lerobot[hilserl]", "lerobot[vla_jepa]", + "lerobot[eo1]", + "lerobot[lingbot_va]", "lerobot[async]", "lerobot[dev]", "lerobot[test]", @@ -375,6 +382,9 @@ ignore = [ # E402: conditional-import guards (TYPE_CHECKING / is_package_available) must precede the imports they protect "src/lerobot/scripts/convert_dataset_v21_to_v30.py" = ["E402"] "src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original +# Vendored Wan2.2 / LingBot-VA model code uses tensor-dimension names (B, F, H, W) and `F` for +# torch.nn.functional; keep the upstream naming to make diffing against upstream tractable. +"src/lerobot/policies/lingbot_va/**" = ["N803", "N806", "N812", "SIM102"] [tool.ruff.lint.isort] combine-as-imports = true diff --git a/src/lerobot/optim/schedulers.py b/src/lerobot/optim/schedulers.py index 250650089..74111e7ef 100644 --- a/src/lerobot/optim/schedulers.py +++ b/src/lerobot/optim/schedulers.py @@ -83,6 +83,28 @@ class VQBeTSchedulerConfig(LRSchedulerConfig): return LambdaLR(optimizer, lr_lambda, -1) +@LRSchedulerConfig.register_subclass("constant_with_warmup") +@dataclass +class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig): + """Linear warmup followed by a constant learning rate. + + Mirrors the ``warmup_constant_lambda`` used by LingBot-VA (upstream ``wan_va/train.py``): + the LR ramps linearly from 0 to the peak over ``num_warmup_steps`` steps, then stays flat. + """ + + num_warmup_steps: int = 1000 + + def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR: + warmup_steps = self.num_warmup_steps or 0 + + def lr_lambda(current_step): + if current_step < warmup_steps: + return float(current_step) / float(max(1, warmup_steps)) + return 1.0 + + return LambdaLR(optimizer, lr_lambda, -1) + + @LRSchedulerConfig.register_subclass("cosine_decay_with_warmup") @dataclass class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig): diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py index 68d23c9ca..0e6e5949e 100644 --- a/src/lerobot/policies/__init__.py +++ b/src/lerobot/policies/__init__.py @@ -20,6 +20,7 @@ from .eo1.configuration_eo1 import EO1Config as EO1Config from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig from .groot.configuration_groot import GrootConfig as GrootConfig +from .lingbot_va.configuration_lingbot_va import LingBotVAConfig as LingBotVAConfig from .molmoact2.configuration_molmoact2 import MolmoAct2Config as MolmoAct2Config from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig from .pi0.configuration_pi0 import PI0Config as PI0Config @@ -44,6 +45,7 @@ __all__ = [ "EO1Config", "GaussianActorConfig", "GrootConfig", + "LingBotVAConfig", "MolmoAct2Config", "MultiTaskDiTConfig", "PI0Config", diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index a42b38ba4..136f00058 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -49,6 +49,7 @@ from .diffusion.configuration_diffusion import DiffusionConfig from .eo1.configuration_eo1 import EO1Config from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig from .groot.configuration_groot import GrootConfig +from .lingbot_va.configuration_lingbot_va import LingBotVAConfig from .molmoact2.configuration_molmoact2 import MolmoAct2Config from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig from .pi0.configuration_pi0 import PI0Config @@ -162,6 +163,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy return VLAJEPAPolicy + elif name == "lingbot_va": + from .lingbot_va.modeling_lingbot_va import LingBotVAPolicy + + return LingBotVAPolicy else: try: return _get_policy_cls_from_policy_name(name=name) @@ -218,6 +223,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return MolmoAct2Config(**kwargs) elif policy_type == "vla_jepa": return VLAJEPAConfig(**kwargs) + elif policy_type == "lingbot_va": + return LingBotVAConfig(**kwargs) else: try: config_cls = PreTrainedConfig.get_choice_class(policy_type) @@ -448,6 +455,14 @@ def make_pre_post_processors( dataset_stats=kwargs.get("dataset_stats"), ) + elif isinstance(policy_cfg, LingBotVAConfig): + from .lingbot_va.processor_lingbot_va import make_lingbot_va_pre_post_processors + + processors = make_lingbot_va_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + else: try: processors = _make_processors_from_policy_config( diff --git a/src/lerobot/policies/lingbot_va/README.md b/src/lerobot/policies/lingbot_va/README.md new file mode 120000 index 000000000..2ec3c82af --- /dev/null +++ b/src/lerobot/policies/lingbot_va/README.md @@ -0,0 +1 @@ +../../../../docs/source/lingbot_va.mdx \ No newline at end of file diff --git a/src/lerobot/policies/lingbot_va/__init__.py b/src/lerobot/policies/lingbot_va/__init__.py new file mode 100644 index 000000000..30092b2ff --- /dev/null +++ b/src/lerobot/policies/lingbot_va/__init__.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# NOTE: ``LingBotVAPolicy`` (and the Wan transformer it owns) imports ``diffusers`` as a +# hard dependency at class-definition time (it subclasses diffusers' ModelMixin/ConfigMixin). +# To keep base ``import lerobot`` working without the optional ``lingbot_va`` extra, the +# policy is exposed lazily via module ``__getattr__`` — the heavy import only happens when +# ``LingBotVAPolicy`` is actually accessed (mirroring the lazy import in policies/factory.py). +from .configuration_lingbot_va import LingBotVAConfig +from .processor_lingbot_va import make_lingbot_va_pre_post_processors + +__all__ = ["LingBotVAConfig", "LingBotVAPolicy", "make_lingbot_va_pre_post_processors"] + + +def __getattr__(name): + if name == "LingBotVAPolicy": + from .modeling_lingbot_va import LingBotVAPolicy + + return LingBotVAPolicy + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/lerobot/policies/lingbot_va/configuration_lingbot_va.py b/src/lerobot/policies/lingbot_va/configuration_lingbot_va.py new file mode 100644 index 000000000..04c903e3f --- /dev/null +++ b/src/lerobot/policies/lingbot_va/configuration_lingbot_va.py @@ -0,0 +1,198 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration for the LingBot-VA policy. + +LingBot-VA is an autoregressive video-action world-model policy built on the Wan2.2 +video-diffusion stack. It interleaves prediction of future video latents and robot +actions in a single dual-stream transformer. See ``docs/source/lingbot_va.mdx`` and the +upstream repository (https://github.com/Robbyant/lingbot-va). + +Defaults below match the upstream LIBERO configuration (``wan_va/configs/va_libero_cfg.py``) +and the ``transformer/config.json`` of the released checkpoints. +""" + +from dataclasses import dataclass, field + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.optim.optimizers import AdamWConfig +from lerobot.optim.schedulers import LRSchedulerConfig +from lerobot.utils.constants import ACTION + +# Upstream LIBERO action-normalization quantiles (single 7-DoF arm + gripper). +# Verbatim from wan_va/configs/va_libero_cfg.py (channels 0-6 of a 30-dim action space). +LIBERO_ACTION_Q01 = [ + -0.6589285731315613, + -0.84375, + -0.9375, + -0.12107142806053162, + -0.15964286029338837, + -0.26571428775787354, + -1.0, +] +LIBERO_ACTION_Q99 = [ + 0.8999999761581421, + 0.8544642925262451, + 0.9375, + 0.17142857611179352, + 0.1842857152223587, + 0.34392857551574707, + 1.0, +] + + +@PreTrainedConfig.register_subclass("lingbot_va") +@dataclass +class LingBotVAConfig(PreTrainedConfig): + """Configuration for the native LingBot-VA policy integration in LeRobot.""" + + # ── Wan transformer architecture (from transformer/config.json) ── + patch_size: tuple[int, int, int] = (1, 2, 2) + num_attention_heads: int = 24 + attention_head_dim: int = 128 + in_channels: int = 48 + out_channels: int = 48 + action_dim: int = 30 + text_dim: int = 4096 + freq_dim: int = 256 + ffn_dim: int = 14336 + num_layers: int = 30 + cross_attn_norm: bool = True + eps: float = 1e-6 + rope_max_seq_len: int = 1024 + # "flex" is supported for training only and needs a recent torch build. Inference uses + # "torch" SDPA (always available) or, optionally, "flashattn". + attn_mode: str = "torch" + + # ── Frozen sub-models (VAE + UMT5 text encoder + tokenizer) ── + # These heavy frozen weights (~20 GB) are NOT bundled into the LeRobot safetensors + # checkpoint (only the trainable ~5B transformer is). They are lazily pulled from this + # HF repo / local directory at policy-init time. The directory must contain the + # diffusers-style ``vae/``, ``text_encoder/`` and ``tokenizer/`` sub-folders. + wan_pretrained_path: str = "robbyant/lingbot-va-posttrain-libero-long" + # dtype used for the transformer / VAE / text-encoder weights at inference. + dtype: str = "bfloat16" # one of "bfloat16", "float16", "float32" + + # ── Observation cameras (order matters: latents are concatenated on width) ── + # Defaults match the LIBERO env feature keys (agentview -> image, eye-in-hand -> image2). + obs_cam_keys: list[str] = field( + default_factory=lambda: ["observation.images.image", "observation.images.image2"] + ) + + # ── Inference hyperparameters (LIBERO defaults) ── + n_obs_steps: int = 1 + height: int = 128 + width: int = 128 + action_per_frame: int = 4 + frame_chunk_size: int = 4 + attn_window: int = 30 + num_inference_steps: int = 20 + video_exec_step: int = -1 + action_num_inference_steps: int = 50 + guidance_scale: float = 5.0 + action_guidance_scale: float = 1.0 + snr_shift: float = 5.0 + action_snr_shift: float = 0.05 + max_sequence_length: int = 512 # UMT5 prompt length + + # Subset of the 30-d action space actually used by the benchmark (LIBERO = 7-DoF). + used_action_channel_ids: list[int] = field(default_factory=lambda: list(range(7))) + # Fixed quantiles for action (un)normalization on the *used* channels. + action_q01: list[float] = field(default_factory=lambda: list(LIBERO_ACTION_Q01)) + action_q99: list[float] = field(default_factory=lambda: list(LIBERO_ACTION_Q99)) + + # Opt-in: VAE-decode the predicted video latents and stash them on + # ``self.last_predicted_frames`` so eval/train can save predicted-video MP4s. + save_predicted_video: bool = False + + # ── Normalization (handled internally / via custom steps, hence IDENTITY here) ── + # Images are scaled to [-1, 1] and VAE-encoded inside the policy; actions are + # quantile-(un)normalized by dedicated processor steps using the fixed quantiles above. + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.IDENTITY, + "ACTION": NormalizationMode.IDENTITY, + } + ) + + # ── Optimizer / scheduler (training; AdamW + warmup-constant per upstream train.py) ── + optimizer_lr: float = 1e-5 + optimizer_betas: tuple[float, float] = (0.9, 0.95) + optimizer_eps: float = 1e-8 + optimizer_weight_decay: float = 1e-4 + optimizer_grad_clip_norm: float = 1.0 + scheduler_warmup_steps: int = 1000 + + def __post_init__(self): + super().__post_init__() + if self.attn_mode not in ("torch", "flashattn", "flex"): + raise ValueError(f"attn_mode must be one of 'torch', 'flashattn', 'flex'; got {self.attn_mode!r}") + if len(self.action_q01) != len(self.used_action_channel_ids) or len(self.action_q99) != len( + self.used_action_channel_ids + ): + raise ValueError( + "action_q01 / action_q99 must each have one entry per used_action_channel_ids " + f"({len(self.used_action_channel_ids)}); got {len(self.action_q01)} / {len(self.action_q99)}." + ) + + @property + def chunk_size(self) -> int: + """Number of single-step actions produced per autoregressive chunk.""" + return self.frame_chunk_size * self.action_per_frame + + @property + def n_action_steps(self) -> int: + """Number of actions executed before refilling (the whole chunk).""" + return self.chunk_size + + def validate_features(self) -> None: + image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL] + if not image_features: + raise ValueError( + "LingBot-VA requires at least one visual input feature. " + "No features of type FeatureType.VISUAL found in input_features." + ) + if ACTION not in self.output_features: + self.output_features[ACTION] = PolicyFeature( + type=FeatureType.ACTION, shape=(len(self.used_action_channel_ids),) + ) + + def get_optimizer_preset(self) -> AdamWConfig: + return AdamWConfig( + lr=self.optimizer_lr, + betas=self.optimizer_betas, + eps=self.optimizer_eps, + weight_decay=self.optimizer_weight_decay, + grad_clip_norm=self.optimizer_grad_clip_norm, + ) + + def get_scheduler_preset(self) -> LRSchedulerConfig | None: + # Upstream uses a linear warmup followed by a constant LR (warmup_constant_lambda). + from lerobot.optim.schedulers import ConstantWithWarmupSchedulerConfig + + return ConstantWithWarmupSchedulerConfig(num_warmup_steps=self.scheduler_warmup_steps) + + @property + def observation_delta_indices(self) -> None: + return None + + @property + def action_delta_indices(self) -> list[int]: + return list(range(self.chunk_size)) + + @property + def reward_delta_indices(self) -> None: + return None diff --git a/src/lerobot/policies/lingbot_va/convert_lingbot_va_checkpoints.py b/src/lerobot/policies/lingbot_va/convert_lingbot_va_checkpoints.py new file mode 100644 index 000000000..86a49f51d --- /dev/null +++ b/src/lerobot/policies/lingbot_va/convert_lingbot_va_checkpoints.py @@ -0,0 +1,256 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert a released LingBot-VA HuggingFace checkpoint to LeRobot format. + +The released checkpoints are diffusers-style directories with ``transformer/``, ``vae/``, +``text_encoder/`` and ``tokenizer/`` sub-folders. This script: + + 1. loads the (sharded) ``transformer/`` weights with the vendored ``WanTransformer3DModel``; + 2. builds a :class:`LingBotVAConfig` for the target benchmark variant; + 3. instantiates a :class:`LingBotVAPolicy` and copies the transformer weights into it + (near-identity: the only key change is the ``transformer.`` prefix); + 4. saves the LeRobot policy (``model.safetensors`` + ``config.json``) and its processors. + +Packaging decision: only the trainable ~5B transformer is bundled into the LeRobot +``model.safetensors``. The frozen VAE + UMT5 text encoder + tokenizer (~20 GB) are NOT +copied; instead ``config.wan_pretrained_path`` records where to lazily pull them from at +load time (defaults to the source repo/dir). Pass ``--bundle-frozen`` to additionally copy +those sub-folders next to the converted checkpoint and point ``wan_pretrained_path`` at it. + +Example (LIBERO-Long, the LIBERO eval gate): + + python -m lerobot.policies.lingbot_va.convert_lingbot_va_checkpoints \ + --checkpoint robbyant/lingbot-va-posttrain-libero-long \ + --variant libero \ + --output_dir outputs/lingbot_va_libero_long + +Requires a CUDA GPU with enough RAM/VRAM to materialize the transformer; run on Linux. +""" + +import argparse +import shutil +from pathlib import Path + +import torch + +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig +from lerobot.policies.lingbot_va.modeling_lingbot_va import LingBotVAPolicy +from lerobot.policies.lingbot_va.processor_lingbot_va import make_lingbot_va_pre_post_processors +from lerobot.policies.lingbot_va.wan_transformer import WanTransformer3DModel +from lerobot.utils.constants import ACTION, OBS_IMAGES + +# Per-benchmark variant presets (camera keys + action layout). Values mirror the upstream +# configs (wan_va/configs/va_*_cfg.py). +VARIANTS = { + "libero": { + "obs_cam_keys": [f"{OBS_IMAGES}.image", f"{OBS_IMAGES}.image2"], + "height": 128, + "width": 128, + "action_per_frame": 4, + "frame_chunk_size": 4, + "attn_window": 30, + "num_inference_steps": 20, + "action_num_inference_steps": 50, + "guidance_scale": 5.0, + "action_guidance_scale": 1.0, + "snr_shift": 5.0, + "action_snr_shift": 0.05, + "used_action_channel_ids": list(range(7)), + # 7-DoF: agentview + eye-in-hand, single arm. Quantiles are the config defaults. + "image_shape": (3, 256, 256), + }, + "robotwin": { + "obs_cam_keys": [ + f"{OBS_IMAGES}.cam_high", + f"{OBS_IMAGES}.cam_left_wrist", + f"{OBS_IMAGES}.cam_right_wrist", + ], + "height": 256, + "width": 320, + "action_per_frame": 16, + "frame_chunk_size": 2, + "attn_window": 72, + "num_inference_steps": 25, + "action_num_inference_steps": 50, + "guidance_scale": 5.0, + "action_guidance_scale": 1.0, + "snr_shift": 5.0, + "action_snr_shift": 1.0, + # RoboTwin is dual-arm; set the used channels / quantiles to match the deployed config. + "used_action_channel_ids": list(range(14)), + "image_shape": (3, 256, 256), + }, +} + + +def _transformer_dir(checkpoint: str) -> str: + """Return the path/repo that ``WanTransformer3DModel.from_pretrained`` should read.""" + p = Path(checkpoint) + if p.is_dir(): + return str(p / "transformer") + return checkpoint # HF repo id; use subfolder kwarg below + + +def load_source_transformer(checkpoint: str, dtype: torch.dtype) -> WanTransformer3DModel: + p = Path(checkpoint) + if p.is_dir(): + return WanTransformer3DModel.from_pretrained( + str(p / "transformer"), torch_dtype=dtype, attn_mode="torch" + ) + return WanTransformer3DModel.from_pretrained( + checkpoint, subfolder="transformer", torch_dtype=dtype, attn_mode="torch" + ) + + +def build_config(variant: str, wan_pretrained_path: str, dtype: str) -> LingBotVAConfig: + preset = VARIANTS[variant] + n_used = len(preset["used_action_channel_ids"]) + kwargs = { + "wan_pretrained_path": wan_pretrained_path, + "dtype": dtype, + "obs_cam_keys": preset["obs_cam_keys"], + "height": preset["height"], + "width": preset["width"], + "action_per_frame": preset["action_per_frame"], + "frame_chunk_size": preset["frame_chunk_size"], + "attn_window": preset["attn_window"], + "num_inference_steps": preset["num_inference_steps"], + "action_num_inference_steps": preset["action_num_inference_steps"], + "guidance_scale": preset["guidance_scale"], + "action_guidance_scale": preset["action_guidance_scale"], + "snr_shift": preset["snr_shift"], + "action_snr_shift": preset["action_snr_shift"], + "used_action_channel_ids": preset["used_action_channel_ids"], + "device": "cpu", + } + if variant != "libero": + # LIBERO keeps the config default quantiles; other variants need their own. Until the + # exact per-channel quantiles are wired in, use a neutral [-1, 1] mapping (no rescale). + kwargs["action_q01"] = [-1.0] * n_used + kwargs["action_q99"] = [1.0] * n_used + cfg = LingBotVAConfig(**kwargs) + # Populate input/output features (cameras + action) so validate_features passes. + img_shape = preset["image_shape"] + cfg.input_features = { + k: PolicyFeature(type=FeatureType.VISUAL, shape=img_shape) for k in preset["obs_cam_keys"] + } + cfg.output_features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(n_used,))} + cfg.validate_features() + return cfg + + +def convert( + checkpoint: str, variant: str, output_dir: str, dtype: str, bundle_frozen: bool, push_to: str | None +): + torch_dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[dtype] + out = Path(output_dir) + out.mkdir(parents=True, exist_ok=True) + + # Decide where frozen modules will be pulled from at load time. + if bundle_frozen: + wan_pretrained_path = str(out) + _copy_frozen_subfolders(checkpoint, out) + else: + wan_pretrained_path = checkpoint + + print(f"Building LingBot-VA config for variant '{variant}' (frozen modules from: {wan_pretrained_path})") + cfg = build_config(variant, wan_pretrained_path, dtype) + + print("Loading source transformer weights ...") + src = load_source_transformer(checkpoint, torch_dtype) + src_sd = src.state_dict() + + print("Instantiating LingBotVAPolicy and copying transformer weights ...") + # Build the policy without triggering frozen-module download by constructing directly. + policy = LingBotVAPolicy(cfg) + # Near-identity remap: source transformer keys -> policy "transformer.*". + remapped = {f"transformer.{k}": v for k, v in src_sd.items()} + missing, unexpected = policy.load_state_dict(remapped, strict=False) + _log_load_keys(missing, unexpected) + policy = policy.to(torch_dtype) + + print(f"Saving converted policy to {out}") + policy.save_pretrained(out) + + preprocessor, postprocessor = make_lingbot_va_pre_post_processors(cfg, dataset_stats=None) + preprocessor.save_pretrained(out) + postprocessor.save_pretrained(out) + + if push_to: + print(f"Pushing to the Hub: {push_to}") + policy.push_to_hub(push_to) + preprocessor.push_to_hub(push_to) + postprocessor.push_to_hub(push_to) + + print("Done.") + + +def _copy_frozen_subfolders(checkpoint: str, out: Path): + p = Path(checkpoint) + if not p.is_dir(): + from huggingface_hub import snapshot_download + + p = Path(snapshot_download(checkpoint, allow_patterns=["vae/*", "text_encoder/*", "tokenizer/*"])) + for sub in ("vae", "text_encoder", "tokenizer"): + src_sub = p / sub + if src_sub.is_dir(): + shutil.copytree(src_sub, out / sub, dirs_exist_ok=True) + print(f" bundled {sub}/") + + +def _log_load_keys(missing, unexpected): + # The source transformer should account for every "transformer.*" key in the policy. + if missing: + print( + f" [load_state_dict] {len(missing)} missing keys (expected: none for transformer). Sample: {missing[:5]}" + ) + if unexpected: + print(f" [load_state_dict] {len(unexpected)} unexpected keys. Sample: {unexpected[:5]}") + if not missing and not unexpected: + print(" [load_state_dict] perfect match (near-identity remap).") + + +def main(): + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) + parser.add_argument("--checkpoint", required=True, help="HF repo id or local diffusers-style directory.") + parser.add_argument("--variant", required=True, choices=sorted(VARIANTS.keys())) + parser.add_argument("--output_dir", required=True) + parser.add_argument("--dtype", default="bfloat16", choices=["bfloat16", "float16", "float32"]) + parser.add_argument( + "--bundle-frozen", + action="store_true", + help="Copy the frozen vae/text_encoder/tokenizer next to the checkpoint instead of lazy-pulling.", + ) + parser.add_argument( + "--push_to_hub", default=None, help="Optional HF repo id to push the converted policy to." + ) + args = parser.parse_args() + convert( + checkpoint=args.checkpoint, + variant=args.variant, + output_dir=args.output_dir, + dtype=args.dtype, + bundle_frozen=args.bundle_frozen, + push_to=args.push_to_hub, + ) + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/policies/lingbot_va/modeling_lingbot_va.py b/src/lerobot/policies/lingbot_va/modeling_lingbot_va.py new file mode 100644 index 000000000..e2550c2cc --- /dev/null +++ b/src/lerobot/policies/lingbot_va/modeling_lingbot_va.py @@ -0,0 +1,582 @@ +# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved. +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""LingBot-VA policy: an autoregressive video-action world model on the Wan2.2 stack. + +The sampling loop is a faithful re-implementation of the upstream streaming server +(``wan_va/wan_va_server.py``) and LIBERO client (``evaluation/libero/client.py``), adapted +to LeRobot's ``select_action`` interface: + + * the trainable dual-stream transformer is owned as a sub-module and round-trips in the + single ``model.safetensors`` checkpoint; + * the frozen Wan VAE + UMT5 text encoder + tokenizer are *lazily pulled* from + ``config.wan_pretrained_path`` (not bundled), so the LeRobot checkpoint stays small; + * ``predict_action_chunk`` runs one autoregressive chunk (video stream then action + stream, each with CFG and its own flow-matching scheduler) and updates the KV cache; + * ``select_action`` drains a per-step action queue and records the real observed + keyframes that are fed back into the KV cache when the queue is refilled. + +NOTE: matching the upstream LIBERO success rate is the Phase-5 correctness gate and must be +validated on a CUDA GPU with the converted checkpoint (tensor-diff against upstream on +identical inputs). The streaming path is written for single-environment eval +(``--eval.batch_size=1``). +""" + +from collections import deque + +import torch +import torch.nn.functional as F +from torch import Tensor + +from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.utils.import_utils import require_package + +from .configuration_lingbot_va import LingBotVAConfig +from .schedulers import FlowMatchScheduler +from .wan_transformer import WanTransformer3DModel +from .wan_utils import data_seq_to_patch, get_mesh_id +from .wan_vae import WanVAEStreamingWrapper, denormalize_latents, load_text_encoder, load_tokenizer, load_vae + + +def _torch_dtype(name: str) -> torch.dtype: + return {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[name] + + +class LingBotVAPolicy(PreTrainedPolicy): + """LeRobot wrapper for the LingBot-VA autoregressive video-action world model.""" + + config_class = LingBotVAConfig + name = "lingbot_va" + + def __init__(self, config: LingBotVAConfig, **kwargs): + require_package("diffusers", extra="lingbot_va") + require_package("transformers", extra="lingbot_va") + super().__init__(config) + config.validate_features() + self.config = config + + self.dtype = _torch_dtype(config.dtype) + + # Trainable dual-stream transformer (the only sub-module saved in the LeRobot checkpoint). + self.transformer = WanTransformer3DModel( + patch_size=tuple(config.patch_size), + num_attention_heads=config.num_attention_heads, + attention_head_dim=config.attention_head_dim, + in_channels=config.in_channels, + out_channels=config.out_channels, + action_dim=config.action_dim, + text_dim=config.text_dim, + freq_dim=config.freq_dim, + ffn_dim=config.ffn_dim, + num_layers=config.num_layers, + cross_attn_norm=config.cross_attn_norm, + eps=config.eps, + rope_max_seq_len=config.rope_max_seq_len, + attn_mode=config.attn_mode, + ) + + # Frozen modules are stored OUTSIDE the nn.Module registry (plain dict) so they are + # neither saved into model.safetensors nor moved by ``.to()``. They are lazily loaded + # from ``config.wan_pretrained_path`` the first time inference runs. + self._frozen: dict = {} + + self.last_predicted_frames: Tensor | None = None + self.reset() + + # ------------------------------------------------------------------ + # Frozen-module lazy loading (VAE + UMT5 + tokenizer) + # ------------------------------------------------------------------ + def _ensure_frozen_modules(self): + if self._frozen: + return + import os + + path = self.config.wan_pretrained_path + device = self.config.device + + # Support both local diffusers-style dirs (with vae/ text_encoder/ tokenizer/ sub-folders) + # and HF repo ids (loaders accept a subfolder kwarg, omitted here = repo root layout). + if os.path.isdir(path): + vae_path, te_path, tok_path = ( + os.path.join(path, n) for n in ("vae", "text_encoder", "tokenizer") + ) + else: + vae_path = te_path = tok_path = path + + vae = load_vae(vae_path, torch_dtype=self.dtype, torch_device=device) + text_encoder = load_text_encoder(te_path, torch_dtype=self.dtype, torch_device=device) + tokenizer = load_tokenizer(tok_path) + self._frozen = { + "vae": vae.eval(), + "streaming_vae": WanVAEStreamingWrapper(vae), + "text_encoder": text_encoder.eval(), + "tokenizer": tokenizer, + } + + @property + def _vae(self): + return self._frozen["vae"] + + @property + def _streaming_vae(self): + return self._frozen["streaming_vae"] + + # ------------------------------------------------------------------ + # PreTrainedPolicy API + # ------------------------------------------------------------------ + def get_optim_params(self) -> dict: + # Only the transformer is trainable; the VAE / text encoder stay frozen. + return self.transformer.parameters() + + def reset(self): + """Reset all per-episode streaming state (KV cache, queues, frame counter).""" + cfg = self.config + self._action_queue: deque = deque(maxlen=cfg.n_action_steps) + self._obs_buffer: list = [] # keyframe camera tensors observed during the current chunk + self._executed_actions: Tensor | None = ( + None # last chunk's actions (model-normalized) for KV feedback + ) + self._steps_since_refill = 0 + self._frame_st_id = 0 + self._first_chunk = True + self._prompt: str | None = None + self._prompt_embeds = None + self._negative_prompt_embeds = None + self.last_predicted_frames = None + self._use_cfg = (cfg.guidance_scale > 1) or (cfg.action_guidance_scale > 1) + # Two independent flow-matching schedulers (video latent + action streams). + self._scheduler = FlowMatchScheduler(shift=cfg.snr_shift, sigma_min=0.0, extra_one_step=True) + self._action_scheduler = FlowMatchScheduler( + shift=cfg.action_snr_shift, sigma_min=0.0, extra_one_step=True + ) + self._scheduler.set_timesteps(1000, training=True) + self._action_scheduler.set_timesteps(1000, training=True) + self._cache_initialised = False + # Clear KV cache on the (already-built) transformer, if present. + if hasattr(self, "transformer"): + self.transformer.clear_cache("pos") + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]: + """Training loss. Implemented in the LingBot-VA training PR (Phase 7). + + The flow-matching dual-stream loss needs the pre-extracted latent dataset + (see ``LatentLeRobotDataset`` upstream) and ``attn_mode='flex'``; it is intentionally + not part of this inference-focused integration. + """ + raise NotImplementedError( + "LingBot-VA training (flow-matching dual-stream loss) is part of the training port " + "(Phase 7 / PR #2) and is not implemented in this inference integration. " + "Use this policy for evaluation / inference." + ) + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor: + """Return one action, refilling the chunk (and feeding back observed keyframes) as needed.""" + self.eval() + self._ensure_frozen_modules() + self._maybe_init_prompt(batch) + + # Record the current observation as a keyframe at every frame boundary so that, when the + # queue empties, ``predict_action_chunk`` can feed the real observed frames back into the + # KV cache (mirroring the upstream ``compute_kv_cache`` call in the LIBERO client loop). + # We skip ``steps_since_refill == 0`` (the obs that conditioned the current chunk): only + # frames observed *after* executing each frame's actions are fed back. + if self._steps_since_refill > 0 and self._steps_since_refill % self.config.action_per_frame == 0: + self._obs_buffer.append(self._encode_obs(batch)) + + if len(self._action_queue) == 0: + actions = self.predict_action_chunk(batch) # [B, chunk_size, n_used] + # queue holds per-step actions: shape [chunk_size, B, n_used] + self._action_queue.extend(actions.transpose(0, 1)) + self._steps_since_refill = 0 + + self._steps_since_refill += 1 + return self._action_queue.popleft() + + @torch.no_grad() + def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor: + """Run one autoregressive chunk and return actions ``[B, chunk_size, n_used]`` (normalized).""" + self.eval() + self._ensure_frozen_modules() + self._maybe_init_prompt(batch) + + is_first = self._first_chunk + if is_first: + init_latent = self._encode_obs(batch) + self._init_latent = init_latent + self._init_streaming_cache(init_latent) + self._obs_buffer = [] # frame 0 (the init obs) conditions the chunk; it is not fed back + actions, latents = self._infer(init_latent, frame_st_id=0) + self._first_chunk = False + else: + # Feed the real observed keyframes + the executed actions back into the KV cache. + self._compute_kv_cache(self._obs_buffer, self._executed_actions) + self._obs_buffer = [] + actions, latents = self._infer(None, frame_st_id=self._frame_st_id) + + # actions: [B, action_dim, F, action_per_frame, 1] (model-normalized). Keep for KV feedback. + self._executed_actions = actions + + if self.config.save_predicted_video: + self.last_predicted_frames = self._decode_predicted_video(latents) + + # On the first chunk, frame 0 is the conditioning frame (already "known"): the upstream + # LIBERO client skips it (start_idx=1), so we drop the first frame's actions here. + used = self.config.used_action_channel_ids + a = actions[:, used] # [B, n_used, F, action_per_frame, 1] + if is_first: + a = a[:, :, 1:] # drop frame 0 -> (F-1) frames of actions + a = a.squeeze(-1).flatten(2) # [B, n_used, n_steps] + a = a.transpose(1, 2).contiguous() # [B, n_steps, n_used] + return a.to(torch.float32) + + # ------------------------------------------------------------------ + # Prompt / text encoding + # ------------------------------------------------------------------ + def _maybe_init_prompt(self, batch): + if self._prompt_embeds is not None: + return + task = batch.get("task") + prompt = task[0] if isinstance(task, list | tuple) else task + self._prompt = prompt or "" + self._prompt_embeds, self._negative_prompt_embeds = self._encode_prompt(self._prompt) + + def _get_t5_prompt_embeds(self, prompt, max_sequence_length): + from diffusers.pipelines.wan.pipeline_wan import prompt_clean + + tokenizer = self._frozen["tokenizer"] + text_encoder = self._frozen["text_encoder"] + device = self.config.device + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + te_device = next(text_encoder.parameters()).device + prompt_embeds = text_encoder(text_input_ids.to(te_device), mask.to(te_device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens, strict=False)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], + dim=0, + ) + return prompt_embeds.to(device) + + def _encode_prompt(self, prompt): + max_len = self.config.max_sequence_length + prompt_embeds = self._get_t5_prompt_embeds(prompt, max_len) + negative_prompt_embeds = None + if self._use_cfg: + negative_prompt_embeds = self._get_t5_prompt_embeds("", max_len) + return prompt_embeds, negative_prompt_embeds + + # ------------------------------------------------------------------ + # Observation (image) encoding -> normalized video latents + # ------------------------------------------------------------------ + def _camera_tensor(self, batch, key): + """Return a single-frame camera tensor [B, C, 1, H, W] resized + scaled to [-1, 1].""" + img = batch[key] + if img.dim() == 3: # [C, H, W] + img = img.unsqueeze(0) + # LeRobot images arrive as float in [0, 1], shape [B, C, H, W]. + img = img.to(self.config.device, torch.float32) + img = F.interpolate( + img, size=(self.config.height, self.config.width), mode="bilinear", align_corners=False + ) + img = img * 2.0 - 1.0 + return img.unsqueeze(2).to(self.dtype) # [B, C, F=1, H, W] + + @torch.no_grad() + def _encode_obs(self, batch) -> Tensor: + """VAE-encode all configured cameras of the current obs and concat latents on width.""" + videos = [self._camera_tensor(batch, k) for k in self.config.obs_cam_keys] + videos = torch.cat(videos, dim=0) # [num_cam, C, F, H, W] + vae_device = next(self._vae.parameters()).device + enc_out = self._streaming_vae.encode_chunk(videos.to(vae_device).to(self.dtype)) + mu, _logvar = torch.chunk(enc_out, 2, dim=1) + latents_mean = torch.tensor(self._vae.config.latents_mean).to(mu.device) + latents_std = torch.tensor(self._vae.config.latents_std).to(mu.device) + # Note: upstream passes 1/std so the op is (x - mean) * (1/std). + mean = latents_mean.view(1, -1, 1, 1, 1) + inv_std = (1.0 / latents_std).view(1, -1, 1, 1, 1) + mu_norm = ((mu.float() - mean) * inv_std).to(mu) + # Concatenate the per-camera latents along width. + video_latent = torch.cat(mu_norm.split(1, dim=0), dim=-1) + return video_latent.to(self.config.device) + + # ------------------------------------------------------------------ + # KV cache management + # ------------------------------------------------------------------ + @property + def _latent_hw(self): + h = self.config.height // 16 + w = (self.config.width // 16) * len(self.config.obs_cam_keys) + return h, w + + def _init_streaming_cache(self, init_latent): + cfg = self.config + latent_h, latent_w = self._latent_hw + p = cfg.patch_size + latent_token_per_chunk = (cfg.frame_chunk_size * latent_h * latent_w) // (p[0] * p[1] * p[2]) + action_token_per_chunk = cfg.frame_chunk_size * cfg.action_per_frame + self.transformer.create_empty_cache( + "pos", + cfg.attn_window, + latent_token_per_chunk, + action_token_per_chunk, + device=self.config.device, + dtype=self.dtype, + batch_size=2 if self._use_cfg else 1, + ) + self._cache_initialised = True + + def _repeat_input_for_cfg(self, input_dict): + if self._use_cfg: + input_dict["noisy_latents"] = input_dict["noisy_latents"].repeat(2, 1, 1, 1, 1) + input_dict["text_emb"] = torch.cat( + [ + self._prompt_embeds.to(self.dtype).clone(), + self._negative_prompt_embeds.to(self.dtype).clone(), + ], + dim=0, + ) + input_dict["grid_id"] = input_dict["grid_id"][None].repeat(2, 1, 1) + input_dict["timesteps"] = input_dict["timesteps"][None].repeat(2, 1) + else: + input_dict["grid_id"] = input_dict["grid_id"][None] + input_dict["timesteps"] = input_dict["timesteps"][None] + return input_dict + + def _prepare_latent_input( + self, + latent_model_input, + action_model_input, + latent_t=0, + action_t=0, + latent_cond=None, + action_cond=None, + frame_st_id=0, + ): + cfg = self.config + device = self.config.device + p = cfg.patch_size + out = {} + if latent_model_input is not None: + out["latent_res_lst"] = { + "noisy_latents": latent_model_input, + "timesteps": torch.ones([latent_model_input.shape[2]], dtype=torch.float32, device=device) + * latent_t, + "grid_id": get_mesh_id( + latent_model_input.shape[-3] // p[0], + latent_model_input.shape[-2] // p[1], + latent_model_input.shape[-1] // p[2], + 0, + 1, + frame_st_id, + ).to(device), + "text_emb": self._prompt_embeds.to(self.dtype).clone(), + } + if latent_cond is not None: + out["latent_res_lst"]["noisy_latents"][:, :, 0:1] = latent_cond[:, :, 0:1] + out["latent_res_lst"]["timesteps"][0:1] *= 0 + if action_model_input is not None: + out["action_res_lst"] = { + "noisy_latents": action_model_input, + "timesteps": torch.ones([action_model_input.shape[2]], dtype=torch.float32, device=device) + * action_t, + "grid_id": get_mesh_id( + action_model_input.shape[-3], + action_model_input.shape[-2], + action_model_input.shape[-1], + 1, + 1, + frame_st_id, + action=True, + ).to(device), + "text_emb": self._prompt_embeds.to(self.dtype).clone(), + } + if action_cond is not None: + out["action_res_lst"]["noisy_latents"][:, :, 0:1] = action_cond[:, :, 0:1] + out["action_res_lst"]["timesteps"][0:1] *= 0 + out["action_res_lst"]["noisy_latents"][:, ~self._action_mask] *= 0 + return out + + @property + def _action_mask(self): + mask = torch.zeros([self.config.action_dim], dtype=torch.bool) + mask[self.config.used_action_channel_ids] = True + return mask + + # ------------------------------------------------------------------ + # Action conditioning (executed action history) (de)normalization + # ------------------------------------------------------------------ + def _preprocess_action_state(self, action_norm: Tensor) -> Tensor: + """Build the action-conditioning tensor from the already-normalized executed actions. + + ``action_norm`` is the model-space action chunk ``[B, action_dim, F, action_per_frame, 1]``. + Upstream re-derives the conditioning from the raw executed action via quantile norm; here + the executed actions are already in the model-normalized space, so we pass them through. + """ + return action_norm.to(self.config.device, self.dtype) + + def _compute_kv_cache(self, obs_buffer, executed_actions): + """Feed real observed keyframes + executed actions back into the KV cache.""" + if not obs_buffer or executed_actions is None: + return + self.transformer.clear_pred_cache("pos") + # Concatenate the observed keyframe latents along the frame axis. + latent_model_input = torch.cat(obs_buffer, dim=2) + # On the first feedback, prepend the init latent so the latent/action frame counts align + # (upstream prepends ``init_latent`` to the observed keyframes when frame_st_id == 0). + if self._frame_st_id == 0 and getattr(self, "_init_latent", None) is not None: + latent_model_input = torch.cat([self._init_latent, latent_model_input], dim=2) + action_model_input = self._preprocess_action_state(executed_actions) + action_model_input = action_model_input.to(latent_model_input) + input_dict = self._prepare_latent_input( + latent_model_input, action_model_input, frame_st_id=self._frame_st_id + ) + with torch.no_grad(): + self.transformer( + self._repeat_input_for_cfg(input_dict["latent_res_lst"]), + update_cache=2, + cache_name="pos", + action_mode=False, + ) + self.transformer( + self._repeat_input_for_cfg(input_dict["action_res_lst"]), + update_cache=2, + cache_name="pos", + action_mode=True, + ) + self._frame_st_id += latent_model_input.shape[2] + + # ------------------------------------------------------------------ + # The core dual-stream denoising loop (one chunk) + # ------------------------------------------------------------------ + @torch.no_grad() + def _infer(self, init_latent, frame_st_id=0): + cfg = self.config + device = self.config.device + latent_h, latent_w = self._latent_hw + frame_chunk_size = cfg.frame_chunk_size + + latents = torch.randn(1, 48, frame_chunk_size, latent_h, latent_w, device=device, dtype=self.dtype) + actions = torch.randn( + 1, cfg.action_dim, frame_chunk_size, cfg.action_per_frame, 1, device=device, dtype=self.dtype + ) + + self._scheduler.set_timesteps(cfg.num_inference_steps) + self._action_scheduler.set_timesteps(cfg.action_num_inference_steps) + timesteps = F.pad(self._scheduler.timesteps, (0, 1), mode="constant", value=0) + if cfg.video_exec_step != -1: + timesteps = timesteps[: cfg.video_exec_step] + action_timesteps = F.pad(self._action_scheduler.timesteps, (0, 1), mode="constant", value=0) + + # 1. Video-latent denoising loop + for i, t in enumerate(timesteps): + last_step = i == len(timesteps) - 1 + latent_cond = ( + init_latent[:, :, 0:1].to(self.dtype) + if frame_st_id == 0 and init_latent is not None + else None + ) + input_dict = self._prepare_latent_input( + latents, None, t, t, latent_cond, None, frame_st_id=frame_st_id + ) + video_noise_pred = self.transformer( + self._repeat_input_for_cfg(input_dict["latent_res_lst"]), + update_cache=1 if last_step else 0, + cache_name="pos", + action_mode=False, + ) + if not last_step or cfg.video_exec_step != -1: + video_noise_pred = data_seq_to_patch( + cfg.patch_size, + video_noise_pred, + frame_chunk_size, + latent_h, + latent_w, + batch_size=2 if self._use_cfg else 1, + ) + if cfg.guidance_scale > 1: + video_noise_pred = video_noise_pred[1:] + cfg.guidance_scale * ( + video_noise_pred[:1] - video_noise_pred[1:] + ) + else: + video_noise_pred = video_noise_pred[:1] + latents = self._scheduler.step(video_noise_pred, t, latents, return_dict=False) + if frame_st_id == 0 and latent_cond is not None: + latents[:, :, 0:1] = latent_cond + + # 2. Action denoising loop + for i, t in enumerate(action_timesteps): + last_step = i == len(action_timesteps) - 1 + action_cond = ( + torch.zeros([1, cfg.action_dim, 1, cfg.action_per_frame, 1], device=device, dtype=self.dtype) + if frame_st_id == 0 + else None + ) + input_dict = self._prepare_latent_input( + None, actions, t, t, None, action_cond, frame_st_id=frame_st_id + ) + action_noise_pred = self.transformer( + self._repeat_input_for_cfg(input_dict["action_res_lst"]), + update_cache=1 if last_step else 0, + cache_name="pos", + action_mode=True, + ) + if not last_step: + from einops import rearrange + + action_noise_pred = rearrange(action_noise_pred, "b (f n) c -> b c f n 1", f=frame_chunk_size) + if cfg.action_guidance_scale > 1: + action_noise_pred = action_noise_pred[1:] + cfg.action_guidance_scale * ( + action_noise_pred[:1] - action_noise_pred[1:] + ) + else: + action_noise_pred = action_noise_pred[:1] + actions = self._action_scheduler.step(action_noise_pred, t, actions, return_dict=False) + if frame_st_id == 0 and action_cond is not None: + actions[:, :, 0:1] = action_cond + + actions[:, ~self._action_mask] *= 0 + return actions, latents + + # ------------------------------------------------------------------ + # Predicted-video decoding (opt-in) + # ------------------------------------------------------------------ + @torch.no_grad() + def _decode_predicted_video(self, latents) -> Tensor: + """VAE-decode predicted latents into a uint8 frame stack ``[T, H, W, 3]`` on CPU.""" + vae = self._vae + z_dim = vae.config.z_dim + latents = denormalize_latents( + latents.to(vae.dtype), vae.config.latents_mean, vae.config.latents_std, z_dim + ) + video = vae.decode(latents, return_dict=False)[0] # [B, C, F, H, W] in [-1, 1] + video = (video.float().clamp(-1, 1) + 1.0) / 2.0 + video = (video[0].permute(1, 2, 3, 0) * 255.0).round().to(torch.uint8) # [F, H, W, C] + return video.cpu() diff --git a/src/lerobot/policies/lingbot_va/processor_lingbot_va.py b/src/lerobot/policies/lingbot_va/processor_lingbot_va.py new file mode 100644 index 000000000..45b6b9077 --- /dev/null +++ b/src/lerobot/policies/lingbot_va/processor_lingbot_va.py @@ -0,0 +1,113 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pre/post-processor pipelines for the LingBot-VA policy. + +The policy itself handles image resizing, scaling to [-1, 1] and VAE encoding (the VAE +lives inside the policy), so the preprocessor only renames, batches, normalizes (IDENTITY) +and moves to device. The postprocessor reverses the *fixed* action quantile normalization +(``(action + 1) / 2 * (q99 - q01 + 1e-6) + q01``) baked into the released checkpoints — this +is a fixed transform, not a dataset-stats one, so it cannot use the standard +``UnnormalizerProcessorStep`` and is implemented as a dedicated step below. +""" + +from dataclasses import dataclass, field +from typing import Any + +import torch + +from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyActionProcessorStep, + PolicyProcessorPipeline, + ProcessorStep, + ProcessorStepRegistry, + RenameObservationsProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.utils.constants import ( + POLICY_POSTPROCESSOR_DEFAULT_NAME, + POLICY_PREPROCESSOR_DEFAULT_NAME, +) + +from .configuration_lingbot_va import LingBotVAConfig + + +@dataclass +@ProcessorStepRegistry.register(name="lingbot_va_action_unnormalize") +class LingBotVAActionUnnormalizeStep(PolicyActionProcessorStep): + """Reverse LingBot-VA's fixed per-channel quantile normalization on predicted actions. + + The policy emits actions in the normalized ``[-1, 1]`` space of the used action channels. + This step maps them back to physical units via the fixed quantiles stored in the config. + """ + + action_q01: list[float] = field(default_factory=list) + action_q99: list[float] = field(default_factory=list) + + def action(self, action: PolicyAction) -> PolicyAction: + q01 = torch.as_tensor(self.action_q01, dtype=action.dtype, device=action.device) + q99 = torch.as_tensor(self.action_q99, dtype=action.dtype, device=action.device) + return (action + 1.0) / 2.0 * (q99 - q01 + 1e-6) + q01 + + def get_config(self) -> dict[str, Any]: + return {"action_q01": list(self.action_q01), "action_q99": list(self.action_q99)} + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features + + +def make_lingbot_va_pre_post_processors( + config: LingBotVAConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """Build the pre/post processor pipelines for LingBot-VA.""" + + input_steps: list[ProcessorStep] = [ + RenameObservationsProcessorStep(rename_map={}), + AddBatchDimensionProcessorStep(), + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + DeviceProcessorStep(device=config.device), + ] + + output_steps: list[ProcessorStep] = [ + LingBotVAActionUnnormalizeStep(action_q01=config.action_q01, action_q99=config.action_q99), + DeviceProcessorStep(device="cpu"), + ] + + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/src/lerobot/policies/lingbot_va/schedulers.py b/src/lerobot/policies/lingbot_va/schedulers.py new file mode 100644 index 000000000..a3ab3bea0 --- /dev/null +++ b/src/lerobot/policies/lingbot_va/schedulers.py @@ -0,0 +1,155 @@ +# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved. +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Flow-matching scheduler for LingBot-VA. + +Vendored verbatim from the upstream LingBot-VA repository +(https://github.com/Robbyant/lingbot-va, ``wan_va/utils/scheduler.py``). LingBot-VA uses +two independent instances of this scheduler at inference time — one for the video-latent +stream and one for the action stream — each with its own ``shift`` (signal-to-noise ratio +shift) and number of denoising steps. +""" + +import math + +import torch + +__all__ = ["FlowMatchScheduler"] + + +class FlowMatchScheduler: + def __init__( + self, + num_inference_steps=100, + num_train_timesteps=1000, + shift=3.0, + sigma_max=1.0, + sigma_min=0.003 / 1.002, + inverse_timesteps=False, + extra_one_step=False, + reverse_sigmas=False, + exponential_shift=False, + exponential_shift_mu=None, + shift_terminal=None, + ): + self.num_train_timesteps = num_train_timesteps + self.shift = shift + self.sigma_max = sigma_max + self.sigma_min = sigma_min + self.inverse_timesteps = inverse_timesteps + self.extra_one_step = extra_one_step + self.reverse_sigmas = reverse_sigmas + self.exponential_shift = exponential_shift + self.exponential_shift_mu = exponential_shift_mu + self.shift_terminal = shift_terminal + self.set_timesteps(num_inference_steps) + + def set_timesteps( + self, + num_inference_steps=100, + denoising_strength=1.0, + training=False, + shift=None, + dynamic_shift_len=None, + ): + if shift is not None: + self.shift = shift + sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength + if self.extra_one_step: + self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1] + else: + self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps) + if self.inverse_timesteps: + self.sigmas = torch.flip(self.sigmas, dims=[0]) + if self.exponential_shift: + mu = ( + self.calculate_shift(dynamic_shift_len) + if dynamic_shift_len is not None + else self.exponential_shift_mu + ) + self.sigmas = math.exp(mu) / (math.exp(mu) + (1 / self.sigmas - 1)) + else: + self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas) + if self.shift_terminal is not None: + one_minus_z = 1 - self.sigmas + scale_factor = one_minus_z[-1] / (1 - self.shift_terminal) + self.sigmas = 1 - (one_minus_z / scale_factor) + if self.reverse_sigmas: + self.sigmas = 1 - self.sigmas + self.timesteps = self.sigmas * self.num_train_timesteps + if training: + x = self.timesteps + y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2) + y_shifted = y - y.min() + bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum()) + self.linear_timesteps_weights = bsmntw_weighing + self.training = True + else: + self.training = False + + def step(self, model_output, timestep, sample, to_final=False, **kwargs): + if isinstance(timestep, torch.Tensor): + timestep = timestep.cpu() + timestep_id = torch.argmin((self.timesteps - timestep).abs()) + sigma = self.sigmas[timestep_id] + if to_final or timestep_id + 1 >= len(self.timesteps): + sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0 + else: + sigma_ = self.sigmas[timestep_id + 1] + prev_sample = sample + model_output * (sigma_ - sigma) + return prev_sample + + def return_to_timestep(self, timestep, sample, sample_stablized): + if isinstance(timestep, torch.Tensor): + timestep = timestep.cpu() + timestep_id = torch.argmin((self.timesteps - timestep).abs()) + sigma = self.sigmas[timestep_id] + model_output = (sample - sample_stablized) / sigma + return model_output + + def add_noise(self, original_samples, noise, timestep, t_dim=2): + if isinstance(timestep, torch.Tensor): + timestep = timestep.cpu() + timestep = timestep[None] + timestep_id = torch.argmin((self.timesteps[:, None] - timestep).abs(), dim=0) + shape = [1] * noise.ndim + shape[t_dim] = timestep_id.shape[0] + sigma = self.sigmas[timestep_id].to(original_samples).view(shape) + sample = (1 - sigma) * original_samples + sigma * noise + return sample + + def training_target(self, sample, noise, timestep): + target = noise - sample + return target + + def training_weight(self, timestep): + timestep_id = torch.argmin( + (self.timesteps[:, None].to(timestep.device) - timestep[None]).abs(), dim=0 + ) + weights = self.linear_timesteps_weights.to(timestep.device)[timestep_id].to(timestep.device) + return weights + + def calculate_shift( + self, + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 8192, + base_shift: float = 0.5, + max_shift: float = 0.9, + ): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu diff --git a/src/lerobot/policies/lingbot_va/wan_attention.py b/src/lerobot/policies/lingbot_va/wan_attention.py new file mode 100644 index 000000000..03cb93d4e --- /dev/null +++ b/src/lerobot/policies/lingbot_va/wan_attention.py @@ -0,0 +1,286 @@ +# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved. +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Attention and rotary-position-embedding modules for the LingBot-VA Wan transformer. + +Vendored and lightly adapted from the upstream LingBot-VA repository +(https://github.com/Robbyant/lingbot-va, ``wan_va/modules/model.py``). The ``torch`` +SDPA backend is the default and is always available; the ``flashattn`` and ``flex`` +backends are imported lazily and only required when the corresponding ``attn_mode`` is +selected. State-dict parameter names are preserved verbatim so that conversion from the +original diffusers-style checkpoint is near-identity. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# ``flash_attn`` and the flex-attention APIs are optional. We import them lazily inside the +# backends that need them so that the (default) ``torch`` SDPA path works on any platform, +# including CPU-only and macOS where neither package is available. + + +def custom_sdpa(q, k, v): + """Scaled-dot-product attention operating on ``(B, S, H, D)`` tensors.""" + out = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)) + return out.transpose(1, 2) + + +def _load_flash_attn_func(): + try: + from flash_attn_interface import flash_attn_func + except ImportError: + try: + from flash_attn import flash_attn_func + except ImportError as e: + raise ImportError( + "attn_mode='flashattn' requires the `flash_attn` package, which is not installed. " + "Install it, or use attn_mode='torch' (the default)." + ) from e + return flash_attn_func + + +class WanRotaryPosEmbed(nn.Module): + """Rotary position embedding with separate frequency bases for frame / height / width.""" + + def __init__( + self, + attention_head_dim: int, + patch_size, + max_seq_len: int, + theta: float = 10000.0, + ): + super().__init__() + + self.attention_head_dim = attention_head_dim + self.patch_size = patch_size + self.max_seq_len = max_seq_len + self.theta = theta + + self.f_dim = self.attention_head_dim - 2 * (self.attention_head_dim // 3) + self.h_dim = self.attention_head_dim // 3 + self.w_dim = self.attention_head_dim // 3 + + f_freqs_base, h_freqs_base, w_freqs_base = self._precompute_freqs_base() + self.f_freqs_base = f_freqs_base + self.h_freqs_base = h_freqs_base + self.w_freqs_base = w_freqs_base + + def _precompute_freqs_base(self): + # freqs_base = 1.0 / (theta ** (2k / dim)) + f_freqs_base = 1.0 / ( + self.theta ** (torch.arange(0, self.f_dim, 2)[: (self.f_dim // 2)].double() / self.f_dim) + ) + h_freqs_base = 1.0 / ( + self.theta ** (torch.arange(0, self.h_dim, 2)[: (self.h_dim // 2)].double() / self.h_dim) + ) + w_freqs_base = 1.0 / ( + self.theta ** (torch.arange(0, self.w_dim, 2)[: (self.w_dim // 2)].double() / self.w_dim) + ) + return f_freqs_base, h_freqs_base, w_freqs_base + + def forward(self, grid_ids): + with torch.no_grad(): + f_freqs = grid_ids[:, 0, :].unsqueeze(-1) * self.f_freqs_base.to(grid_ids.device) + h_freqs = grid_ids[:, 1, :].unsqueeze(-1) * self.h_freqs_base.to(grid_ids.device) + w_freqs = grid_ids[:, 2, :].unsqueeze(-1) * self.w_freqs_base.to(grid_ids.device) + freqs = torch.cat([f_freqs, h_freqs, w_freqs], dim=-1).float() + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + + return freqs_cis + + +class WanAttention(nn.Module): + """Self/cross attention with KV-caching for autoregressive streaming inference. + + Backends: + * ``torch`` (default): standard SDPA, available everywhere. + * ``flashattn``: FlashAttention kernels (optional dependency). + * ``flex``: PyTorch flex-attention (optional, used for block-causal training masks). + """ + + def __init__( + self, + dim, + heads=8, + dim_head=64, + eps=1e-5, + dropout=0.0, + cross_attention_dim_head=None, + attn_mode="torch", + ): + super().__init__() + if attn_mode == "torch": + self.attn_op = custom_sdpa + elif attn_mode == "flashattn": + self.attn_op = _load_flash_attn_func() + elif attn_mode == "flex": + # Imported lazily to avoid a hard dependency on torch flex-attention at import time. + from .wan_flex_attention import FlexAttnFunc + + self.attn_op = FlexAttnFunc(cross_attention_dim_head is not None) + else: + raise ValueError( + f"Unsupported attention mode: {attn_mode}, only support 'torch', 'flashattn' and 'flex'" + ) + + self.inner_dim = dim_head * heads + self.heads = heads + self.cross_attention_dim_head = cross_attention_dim_head + self.kv_inner_dim = ( + self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads + ) + + self.to_q = nn.Linear(dim, self.inner_dim, bias=True) + self.to_k = nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_v = nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_out = nn.ModuleList( + [ + nn.Linear(self.inner_dim, dim, bias=True), + nn.Dropout(dropout), + ] + ) + self.norm_q = nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) + self.norm_k = nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) + # KV cache only lives on self-attention modules (cross_attention_dim_head is None). + self.attn_caches = {} if cross_attention_dim_head is None else None + + def clear_pred_cache(self, cache_name): + if self.attn_caches is None: + return + cache = self.attn_caches[cache_name] + is_pred = cache["is_pred"] + cache["mask"][is_pred] = False + + def clear_cache(self, cache_name): + if self.attn_caches is None: + return + self.attn_caches[cache_name] = None + + def init_kv_cache(self, cache_name, total_tolen, num_head, head_dim, device, dtype, batch_size): + if self.attn_caches is None: + return + self.attn_caches[cache_name] = { + "k": torch.empty([batch_size, total_tolen, num_head, head_dim], device=device, dtype=dtype), + "v": torch.empty([batch_size, total_tolen, num_head, head_dim], device=device, dtype=dtype), + "id": torch.full((total_tolen,), -1, device=device), + "mask": torch.zeros((total_tolen,), dtype=torch.bool, device=device), + "is_pred": torch.zeros((total_tolen,), dtype=torch.bool, device=device), + } + + def allocate_slots(self, cache_name, key_size): + cache = self.attn_caches[cache_name] + mask = cache["mask"] + ids = cache["id"] + free = (~mask).nonzero(as_tuple=False).squeeze(-1) + + if free.numel() < key_size: + used = mask.nonzero(as_tuple=False).squeeze(-1) + + used_ids = ids[used] + order = torch.argsort(used_ids) + need = key_size - free.numel() + to_free = used[order[:need]] + + mask[to_free] = False + ids[to_free] = -1 + free = (~mask).nonzero(as_tuple=False).squeeze(-1) + + assert free.numel() >= key_size + return free[:key_size] + + def _next_cache_id(self, cache_name): + ids = self.attn_caches[cache_name]["id"] + mask = self.attn_caches[cache_name]["mask"] + + if mask.any(): + return ids[mask].max() + 1 + else: + return torch.tensor(0, device=ids.device, dtype=ids.dtype) + + def update_cache(self, cache_name, key, value, is_pred): + cache = self.attn_caches[cache_name] + + key_size = key.shape[1] + slots = self.allocate_slots(cache_name, key_size) + + new_id = self._next_cache_id(cache_name) + + cache["k"][:, slots] = key + cache["v"][:, slots] = value + cache["mask"][slots] = True + cache["id"][slots] = new_id + cache["is_pred"][slots] = is_pred + return slots + + def restore_cache(self, cache_name, slots): + self.attn_caches[cache_name]["mask"][slots] = False + + def forward( + self, + q, + k, + v, + rotary_emb, + update_cache=0, + cache_name="pos", + ): + kv_cache = ( + self.attn_caches[cache_name] + if (self.attn_caches is not None) and (cache_name in self.attn_caches) + else None + ) + + query, key, value = self.to_q(q), self.to_k(k), self.to_v(v) + query = self.norm_q(query) + query = query.unflatten(2, (self.heads, -1)) + key = self.norm_k(key) + key = key.unflatten(2, (self.heads, -1)) + value = value.unflatten(2, (self.heads, -1)) + if rotary_emb is not None: + + def apply_rotary_emb(x, freqs): + x_out = torch.view_as_complex( + x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2) + ) + x_out = torch.view_as_real(x_out * freqs).flatten(3) + return x_out.to(x.dtype) + + query = apply_rotary_emb(query, rotary_emb) + key = apply_rotary_emb(key, rotary_emb) + slots = None + if kv_cache is not None and kv_cache["k"] is not None: + slots = self.update_cache(cache_name, key, value, is_pred=(update_cache == 1)) + key_pool = self.attn_caches[cache_name]["k"] + value_pool = self.attn_caches[cache_name]["v"] + mask = self.attn_caches[cache_name]["mask"] + valid = mask.nonzero(as_tuple=False).squeeze(-1) + key = key_pool[:, valid] + value = value_pool[:, valid] + + hidden_states = self.attn_op(query, key, value) + + if update_cache == 0: + if kv_cache is not None and kv_cache["k"] is not None: + self.restore_cache(cache_name, slots) + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) + hidden_states = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + +__all__ = ["WanAttention", "WanRotaryPosEmbed", "custom_sdpa"] diff --git a/src/lerobot/policies/lingbot_va/wan_flex_attention.py b/src/lerobot/policies/lingbot_va/wan_flex_attention.py new file mode 100644 index 000000000..0b10dbf9b --- /dev/null +++ b/src/lerobot/policies/lingbot_va/wan_flex_attention.py @@ -0,0 +1,207 @@ +# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved. +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Flex-attention backend for the LingBot-VA Wan transformer (training only). + +This module is imported lazily and ONLY when ``attn_mode='flex'`` is requested. It builds +the block-causal / window / noise-vs-clean attention masks used during the dual-stream +flow-matching training described in the LingBot-VA paper. Inference uses the ``torch`` +SDPA backend (see :mod:`wan_attention`) which does not need flex-attention. + +``torch.nn.attention.flex_attention`` requires a recent PyTorch build with the relevant +inductor support; importing this module on an unsupported build raises ``ImportError``. +""" + +from collections.abc import Callable +from functools import partial +from typing import ClassVar + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.attention.flex_attention import ( + BlockMask, + and_masks, + create_block_mask, + flex_attention, + or_masks, +) + + +class FlexAttnFunc(nn.Module): + flex_attn: ClassVar[Callable] = torch.compile(flex_attention, dynamic=True) + compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask) + attention_mask: ClassVar[BlockMask] = None + cross_attention_mask: ClassVar[BlockMask] = None + + def __init__(self, is_cross=False) -> None: + super().__init__() + self.is_cross = is_cross + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dtype=torch.bfloat16, + ) -> torch.Tensor: + q_varlen = rearrange(query[0], "s n d -> 1 n s d") + k_varlen = rearrange(key[0], "s n d -> 1 n s d") + v_varlen = rearrange(value[0], "s n d -> 1 n s d") + + half_dtypes = (torch.float16, torch.bfloat16) + assert dtype in half_dtypes + + def half(x): + return x if x.dtype in half_dtypes else x.to(dtype) + + q_varlen = half(q_varlen) + k_varlen = half(k_varlen) + v_varlen = half(v_varlen) + q_varlen = q_varlen.to(v_varlen.dtype) + k_varlen = k_varlen.to(v_varlen.dtype) + + block_mask = FlexAttnFunc.cross_attention_mask if self.is_cross else FlexAttnFunc.attention_mask + + x_out = FlexAttnFunc.flex_attn( + q_varlen, + k_varlen, + v_varlen, + block_mask=block_mask, + kernel_options={ + "BLOCK_M": 64, + "BLOCK_N": 64, + "BLOCK_M1": 32, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 32, + }, + ) + + x_out = rearrange(x_out, "b n s d -> b s n d") + return x_out + + @staticmethod + @torch.no_grad() + def init_mask( + latent_shape, + action_shape, + padded_length, + chunk_size, + window_size, + patch_size, + device, + ): + torch._inductor.config.realize_opcount_threshold = 100 + B, _, L_F, L_H, L_W = latent_shape + _, _, A_F, A_H, A_W = action_shape + + latent_seq_id = ( + torch.arange(B)[:, None, None, None] + .expand(-1, L_F // patch_size[0], L_H // patch_size[1], L_W // patch_size[2]) + .flatten() + ) + action_seq_id = torch.arange(B)[:, None, None, None].expand(-1, A_F, A_H, A_W).flatten() + seq_ids = torch.cat([latent_seq_id] * 2 + [action_seq_id] * 2) + + latent_frame_id = ( + torch.arange(L_F)[None, :, None, None] + .expand(B, -1, L_H // patch_size[1], L_W // patch_size[2])[None] + .flatten() + ) + action_frame_id = torch.arange(A_F)[None, :, None, None].expand(B, -1, A_H, A_W)[None].flatten() + frame_ids = torch.cat( + [latent_frame_id // chunk_size * 2] * 2 + [action_frame_id // chunk_size * 2 + 1] * 2 + ) + + noise_ids = torch.cat( + [ + torch.zeros_like(latent_frame_id), + torch.ones_like(latent_frame_id), + torch.zeros_like(action_frame_id), + torch.ones_like(action_frame_id), + ] + ) + + seq_ids = F.pad(seq_ids, (0, padded_length), value=-1) + frame_ids = F.pad(frame_ids, (0, padded_length), value=-1) + noise_ids = F.pad(noise_ids, (0, padded_length), value=-1) + + mask_mod = FlexAttnFunc._get_mask_mod( + seq_ids.long().to(device), frame_ids.long().to(device), noise_ids.long().to(device), window_size + ) + block_mask = FlexAttnFunc.compiled_create_block_mask( + mask_mod, 1, 1, len(seq_ids), len(seq_ids), device=device, _compile=True + ) + FlexAttnFunc.attention_mask = block_mask + + text_seq_ids = torch.arange(B)[:, None].expand(-1, 512).flatten() + mask_mod_cross = FlexAttnFunc._get_cross_mask_mod( + seq_ids.long().to(device), text_seq_ids.long().to(device) + ) + block_mask_cross = FlexAttnFunc.compiled_create_block_mask( + mask_mod_cross, 1, 1, len(seq_ids), len(text_seq_ids), device=device, _compile=True + ) + FlexAttnFunc.cross_attention_mask = block_mask_cross + + @staticmethod + @torch.no_grad() + def _get_cross_mask_mod(seq_ids, text_seq_ids): + def seq_mask(b, h, q_idx, kv_idx): + return ( + (seq_ids[q_idx] == text_seq_ids[kv_idx]) & (seq_ids[q_idx] >= 0) & (text_seq_ids[kv_idx] >= 0) + ) + + return seq_mask + + @staticmethod + @torch.no_grad() + def _get_mask_mod(seq_ids, frame_ids, noise_ids, window_size): + def seq_mask(b, h, q_idx, kv_idx): + return (seq_ids[q_idx] == seq_ids[kv_idx]) & (seq_ids[q_idx] >= 0) & (seq_ids[kv_idx] >= 0) + + def block_causal_mask(b, h, q_idx, kv_idx): + return frame_ids[kv_idx] <= frame_ids[q_idx] + + def block_causal_mask_exclude_self(b, h, q_idx, kv_idx): + return frame_ids[kv_idx] < frame_ids[q_idx] + + def block_self_mask(b, h, q_idx, kv_idx): + return frame_ids[kv_idx] == frame_ids[q_idx] + + def clean2clean_mask(b, h, q_idx, kv_idx): + return (noise_ids[q_idx] == 1) & (noise_ids[kv_idx] == 1) + + def noise2clean_mask(b, h, q_idx, kv_idx): + return (noise_ids[q_idx] == 0) & (noise_ids[kv_idx] == 1) + + def noise2noise_mask(b, h, q_idx, kv_idx): + return (noise_ids[q_idx] == 0) & (noise_ids[kv_idx] == 0) + + def block_window_mask(b, h, q_idx, kv_idx, window_size: int): + return (frame_ids[q_idx] - frame_ids[kv_idx]).abs() <= window_size + + mask_list = [] + mask_list.append(and_masks(clean2clean_mask, block_causal_mask)) + mask_list.append(and_masks(noise2clean_mask, block_causal_mask_exclude_self)) + mask_list.append(and_masks(noise2noise_mask, block_self_mask)) + mask = or_masks(*mask_list) + mask = and_masks(mask, seq_mask) + mask = and_masks(mask, partial(block_window_mask, window_size=window_size)) + return mask + + +__all__ = ["FlexAttnFunc"] diff --git a/src/lerobot/policies/lingbot_va/wan_transformer.py b/src/lerobot/policies/lingbot_va/wan_transformer.py new file mode 100644 index 000000000..19f439c81 --- /dev/null +++ b/src/lerobot/policies/lingbot_va/wan_transformer.py @@ -0,0 +1,514 @@ +# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved. +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The dual-stream Wan2.2 video-action transformer backbone for LingBot-VA. + +Vendored and lightly adapted from the upstream LingBot-VA repository +(https://github.com/Robbyant/lingbot-va, ``wan_va/modules/model.py``). + +The model keeps the diffusers ``ModelMixin``/``ConfigMixin`` mixins so the original +sharded ``transformer/`` checkpoint can be loaded with ``from_pretrained`` during +conversion, but in LeRobot it is owned as a plain ``nn.Module`` sub-component of +:class:`~lerobot.policies.lingbot_va.modeling_lingbot_va.LingBotVAPolicy`. State-dict +parameter names are preserved verbatim so conversion is near-identity. +""" + +import math +from copy import deepcopy + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.attention import FeedForward +from diffusers.models.embeddings import ( + PixArtAlphaTextProjection, + TimestepEmbedding, + Timesteps, +) +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import FP32LayerNorm +from einops import rearrange + +from .wan_attention import WanAttention, WanRotaryPosEmbed + +__all__ = ["WanTransformer3DModel", "WanTransformerBlock", "WanTimeTextImageEmbedding"] + + +class WanTimeTextImageEmbedding(nn.Module): + def __init__( + self, + dim, + time_freq_dim, + time_proj_dim, + text_embed_dim, + pos_embed_seq_len, + ): + super().__init__() + + self.timesteps_proj = Timesteps( + num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0 + ) + self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) + self.act_fn = nn.SiLU() + self.time_proj = nn.Linear(dim, time_proj_dim) + self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh") + + def forward(self, timestep: torch.Tensor, dtype=None): + B, L = timestep.shape + timestep = timestep.reshape(-1) + timestep = self.timesteps_proj(timestep) + time_embedder_dtype = self.time_embedder.linear_1.weight.dtype + if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: + timestep = timestep.to(time_embedder_dtype) + temb = self.time_embedder(timestep).to(dtype=dtype) + timestep_proj = self.time_proj(self.act_fn(temb)) + return temb.reshape(B, L, -1), timestep_proj.reshape(B, L, -1) + + +class WanTransformerBlock(nn.Module): + def __init__( + self, + dim, + ffn_dim, + num_heads, + cross_attn_norm=False, + eps=1e-6, + attn_mode: str = "torch", + ): + super().__init__() + self.attn_mode = attn_mode + + # 1. Self-attention + self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.attn1 = WanAttention( + dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + eps=eps, + cross_attention_dim_head=None, + attn_mode=attn_mode, + ) + + # 2. Cross-attention + self.attn2 = WanAttention( + dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + eps=eps, + cross_attention_dim_head=dim // num_heads, + attn_mode=attn_mode, + ) + self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + + # 3. Feed-forward + self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate") + self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False) + + self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def forward( + self, + hidden_states, + encoder_hidden_states, + temb, + rotary_emb, + update_cache=0, + cache_name="pos", + ) -> torch.Tensor: + temb_scale_shift_table = self.scale_shift_table[None] + temb.float() + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = rearrange( + temb_scale_shift_table, "b l n c -> b n l c" + ).chunk(6, dim=1) + shift_msa = shift_msa.squeeze(1) + scale_msa = scale_msa.squeeze(1) + gate_msa = gate_msa.squeeze(1) + c_shift_msa = c_shift_msa.squeeze(1) + c_scale_msa = c_scale_msa.squeeze(1) + c_gate_msa = c_gate_msa.squeeze(1) + # 1. Self-attention + norm_hidden_states = (self.norm1(hidden_states.float()) * (1.0 + scale_msa) + shift_msa).type_as( + hidden_states + ) + attn_output = self.attn1( + norm_hidden_states, + norm_hidden_states, + norm_hidden_states, + rotary_emb, + update_cache=update_cache, + cache_name=cache_name, + ) + hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) + + # 2. Cross-attention + norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states, + encoder_hidden_states, + None, + update_cache=0, + cache_name=cache_name, + ) + hidden_states = hidden_states + attn_output + + # 3. Feed-forward + norm_hidden_states = (self.norm3(hidden_states.float()) * (1.0 + c_scale_msa) + c_shift_msa).type_as( + hidden_states + ) + + ff_output = self.ffn(norm_hidden_states) + + hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) + return hidden_states + + +class WanTransformer3DModel(ModelMixin, ConfigMixin): + """Dual-stream (video + action) Wan2.2 DiT backbone with autoregressive KV caching.""" + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = [ + "patch_embedding_mlp", + "condition_embedder", + "condition_embedder_action", + "norm", + ] + _no_split_modules = ["WanTransformerBlock"] + _keep_in_fp32_modules = [ + "time_embedder", + "scale_shift_table", + "scale_shift_table_action", + "norm1", + "action_norm1", + "text_norm1", + "norm2", + "action_norm2", + "text_norm2", + "norm3", + "action_norm3", + "text_norm3", + ] + _keys_to_ignore_on_load_unexpected = ["norm_added_q"] + _repeated_blocks = ["WanTransformerBlock"] + + @register_to_config + def __init__( + self, + patch_size=(1, 2, 2), + num_attention_heads=24, + attention_head_dim=128, + in_channels=48, + out_channels=48, + action_dim=30, + text_dim=4096, + freq_dim=256, + ffn_dim=14336, + num_layers=30, + cross_attn_norm=True, + eps=1e-06, + rope_max_seq_len=1024, + pos_embed_seq_len=None, + attn_mode="torch", + ): + super().__init__() + self.patch_size = patch_size + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) + self.patch_embedding_mlp = nn.Linear( + in_channels * patch_size[0] * patch_size[1] * patch_size[2], inner_dim + ) + self.action_embedder = nn.Linear(action_dim, inner_dim) + self.condition_embedder = WanTimeTextImageEmbedding( + dim=inner_dim, + time_freq_dim=freq_dim, + time_proj_dim=inner_dim * 6, + text_embed_dim=text_dim, + pos_embed_seq_len=pos_embed_seq_len, + ) + self.condition_embedder_action = deepcopy(self.condition_embedder) + + self.blocks = nn.ModuleList( + [ + WanTransformerBlock( + inner_dim, ffn_dim, num_attention_heads, cross_attn_norm, eps, attn_mode=attn_mode + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) + self.action_proj_out = nn.Linear(inner_dim, action_dim) + self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) + + # ------------------------------------------------------------------ + # KV-cache management for autoregressive streaming inference + # ------------------------------------------------------------------ + def clear_cache(self, cache_name): + for block in self.blocks: + block.attn1.clear_cache(cache_name) + + def clear_pred_cache(self, cache_name): + for block in self.blocks: + block.attn1.clear_pred_cache(cache_name) + + def create_empty_cache( + self, + cache_name, + attn_window, + latent_token_per_chunk, + action_token_per_chunk, + device, + dtype, + batch_size, + ): + total_tolen = (attn_window // 2) * latent_token_per_chunk + ( + attn_window // 2 + ) * action_token_per_chunk + for block in self.blocks: + block.attn1.init_kv_cache( + cache_name, + total_tolen, + self.num_attention_heads, + self.attention_head_dim, + device, + dtype, + batch_size, + ) + + # ------------------------------------------------------------------ + # Embedding helpers (shared by train + inference paths) + # ------------------------------------------------------------------ + def _input_embed(self, latents, input_type="latent"): + if input_type == "latent": + hidden_states = rearrange( + latents, + "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", + p1=self.patch_size[0], + p2=self.patch_size[1], + p3=self.patch_size[2], + ) + hidden_states = self.patch_embedding_mlp(hidden_states) + elif input_type == "action": + hidden_states = rearrange(latents, "b c f h w -> b (f h w) c") + hidden_states = self.action_embedder(hidden_states) + elif input_type == "text": + hidden_states = self.condition_embedder.text_embedder(latents) + else: + raise ValueError(f"Unsupported input type: {input_type}") + return hidden_states + + def _time_embed(self, timesteps, H, W, dtype, action_mode=False): + pach_scale_h, pach_scale_w = (1, 1) if action_mode else (self.patch_size[1], self.patch_size[2]) + latent_time_steps = torch.repeat_interleave( + timesteps, (H // pach_scale_h) * (W // pach_scale_w), dim=1 + ) + current_condition_embedder = ( + self.condition_embedder_action if action_mode else self.condition_embedder + ) + temb, timestep_proj = current_condition_embedder(latent_time_steps, dtype=dtype) + timestep_proj = timestep_proj.unflatten(2, (6, -1)) # B L 6 C + return temb, timestep_proj + + # ------------------------------------------------------------------ + # Dual-stream training forward (flow matching). Requires attn_mode='flex'. + # ------------------------------------------------------------------ + def forward_train(self, input_dict): + from .wan_flex_attention import FlexAttnFunc + + input_dict["latent_dict"]["noisy_latents"] = input_dict["latent_dict"]["noisy_latents"].to( + torch.bfloat16 + ) + input_dict["latent_dict"]["latent"] = input_dict["latent_dict"]["latent"].to(torch.bfloat16) + input_dict["action_dict"]["noisy_latents"] = input_dict["action_dict"]["noisy_latents"].to( + torch.bfloat16 + ) + input_dict["action_dict"]["latent"] = input_dict["action_dict"]["latent"].to(torch.bfloat16) + + latent_dict = input_dict["latent_dict"] + action_dict = input_dict["action_dict"] + batch_size = latent_dict["noisy_latents"].shape[0] + + latent_hidden_states = self._input_embed(latent_dict["noisy_latents"], input_type="latent").flatten( + 0, 1 + )[None] + action_hidden_states = self._input_embed(action_dict["noisy_latents"], input_type="action").flatten( + 0, 1 + )[None] + text_hidden_states = self._input_embed(latent_dict["text_emb"], input_type="text") + + text_hidden_states = text_hidden_states.flatten(0, 1)[None] + + condition_latent_hidden_states = self._input_embed( + latent_dict["latent"], input_type="latent" + ).flatten(0, 1)[None] + condition_action_hidden_states = self._input_embed( + action_dict["latent"], input_type="action" + ).flatten(0, 1)[None] + + hidden_states = torch.cat( + [ + latent_hidden_states, + condition_latent_hidden_states, + action_hidden_states, + condition_action_hidden_states, + ], + dim=1, + ) + + latent_grid_id = latent_dict["grid_id"].permute(1, 0, 2).flatten(1)[None] + action_grid_id = action_dict["grid_id"].permute(1, 0, 2).flatten(1)[None] + full_grid_id = torch.cat([latent_grid_id] * 2 + [action_grid_id] * 2, dim=2) + + rotary_emb = self.rope(full_grid_id)[:, :, None] + + latent_time_steps = torch.cat( + [latent_dict["timesteps"].flatten(0, 1), latent_dict["cond_timesteps"].flatten(0, 1)] + )[None] + action_time_steps = torch.cat( + [action_dict["timesteps"].flatten(0, 1), action_dict["cond_timesteps"].flatten(0, 1)] + )[None] + latent_temb, latent_timestep_proj = self._time_embed( + latent_time_steps, + latent_dict["noisy_latents"].shape[-2], + latent_dict["noisy_latents"].shape[-1], + dtype=hidden_states.dtype, + action_mode=False, + ) + action_temb, action_timestep_proj = self._time_embed( + action_time_steps, + action_dict["noisy_latents"].shape[-2], + action_dict["noisy_latents"].shape[-1], + dtype=hidden_states.dtype, + action_mode=True, + ) + temb = torch.cat([latent_temb, action_temb], dim=1) + timestep_proj = torch.cat([latent_timestep_proj, action_timestep_proj], dim=1) + + total_length = hidden_states.shape[1] + padded_length = (128 - total_length % 128) % 128 + hidden_states = F.pad(hidden_states, (0, 0, 0, padded_length)) + rotary_emb = F.pad(rotary_emb, (0, 0, 0, 0, 0, padded_length)) + temb = F.pad(temb, (0, 0, 0, padded_length)) + timestep_proj = F.pad(timestep_proj, (0, 0, 0, 0, 0, padded_length)) + + split_list = [ + latent_hidden_states.shape[1], + condition_latent_hidden_states.shape[1], + action_hidden_states.shape[1], + condition_action_hidden_states.shape[1], + padded_length, + ] + + FlexAttnFunc.init_mask( + latent_dict["noisy_latents"].shape, + action_dict["noisy_latents"].shape, + padded_length, + input_dict["chunk_size"], + window_size=input_dict["window_size"], + patch_size=self.patch_size, + device=hidden_states.device, + ) + + for block in self.blocks: + hidden_states = block( + hidden_states, text_hidden_states, timestep_proj, rotary_emb, update_cache=False + ) + temb_scale_shift_table = self.scale_shift_table[None] + temb[:, :, None, ...] + shift, scale = rearrange(temb_scale_shift_table, "b l n c -> b n l c").chunk(2, dim=1) + shift = shift.to(hidden_states.device).squeeze(1) + scale = scale.to(hidden_states.device).squeeze(1) + hidden_states = (self.norm_out(hidden_states.float()) * (1.0 + scale) + shift).type_as(hidden_states) + latent_hidden_states, _, action_hidden_states, _, _ = torch.split(hidden_states, split_list, dim=1) + latent_hidden_states = self.proj_out(latent_hidden_states) + latent_hidden_states = rearrange( + latent_hidden_states, "1 (b l) (n c) -> b (l n) c", n=math.prod(self.patch_size), b=batch_size + ) + action_hidden_states = self.action_proj_out(action_hidden_states) + action_hidden_states = rearrange(action_hidden_states, "1 (b l) c -> b l c", b=batch_size) + + return latent_hidden_states, action_hidden_states + + # ------------------------------------------------------------------ + # Single-stream inference forward (one denoising step for one stream) + # ------------------------------------------------------------------ + def forward( + self, + input_dict, + update_cache=0, + cache_name="pos", + action_mode=False, + train_mode=False, + ): + if train_mode: + return self.forward_train(input_dict) + if action_mode: # action input emb + latent_hidden_states = rearrange(input_dict["noisy_latents"], "b c f h w -> b (f h w) c") + latent_hidden_states = self.action_embedder(latent_hidden_states) # B L1 C + else: # latent input emb + latent_hidden_states = rearrange( + input_dict["noisy_latents"], + "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", + p1=self.patch_size[0], + p2=self.patch_size[1], + p3=self.patch_size[2], + ) + latent_hidden_states = self.patch_embedding_mlp(latent_hidden_states) + text_hidden_states = self.condition_embedder.text_embedder(input_dict["text_emb"]) # B L2 C + + latent_grid_id = input_dict["grid_id"] + rotary_emb = self.rope(latent_grid_id)[:, :, None] # 1 L 1 C + pach_scale_h, pach_scale_w = (1, 1) if action_mode else (self.patch_size[1], self.patch_size[2]) + + latent_time_steps = torch.repeat_interleave( + input_dict["timesteps"], + (input_dict["noisy_latents"].shape[-2] // pach_scale_h) + * (input_dict["noisy_latents"].shape[-1] // pach_scale_w), + dim=1, + ) # L + current_condition_embedder = ( + self.condition_embedder_action if action_mode else self.condition_embedder + ) + temb, timestep_proj = current_condition_embedder(latent_time_steps, dtype=latent_hidden_states.dtype) + timestep_proj = timestep_proj.unflatten(2, (6, -1)) # B L 6 C + + for block in self.blocks: + latent_hidden_states = block( + latent_hidden_states, + text_hidden_states, + timestep_proj, + rotary_emb, + update_cache=update_cache, + cache_name=cache_name, + ) + temb_scale_shift_table = self.scale_shift_table[None] + temb[:, :, None, ...] + shift, scale = rearrange(temb_scale_shift_table, "b l n c -> b n l c").chunk(2, dim=1) + shift = shift.to(latent_hidden_states.device).squeeze(1) + scale = scale.to(latent_hidden_states.device).squeeze(1) + latent_hidden_states = (self.norm_out(latent_hidden_states.float()) * (1.0 + scale) + shift).type_as( + latent_hidden_states + ) + + if action_mode: + latent_hidden_states = self.action_proj_out(latent_hidden_states) + else: + latent_hidden_states = self.proj_out(latent_hidden_states) + latent_hidden_states = rearrange( + latent_hidden_states, "b l (n c) -> b (l n) c", n=math.prod(self.patch_size) + ) + + return latent_hidden_states diff --git a/src/lerobot/policies/lingbot_va/wan_utils.py b/src/lerobot/policies/lingbot_va/wan_utils.py new file mode 100644 index 000000000..9292b519a --- /dev/null +++ b/src/lerobot/policies/lingbot_va/wan_utils.py @@ -0,0 +1,56 @@ +# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved. +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Grid-id / patch utilities for the LingBot-VA autoregressive inference loop. + +Vendored verbatim from the upstream LingBot-VA repository +(https://github.com/Robbyant/lingbot-va, ``wan_va/utils/utils.py``). +""" + +import torch + +__all__ = ["get_mesh_id", "data_seq_to_patch"] + + +def data_seq_to_patch(patch_size, data_seq, latent_num_frames, latent_height, latent_width, batch_size=1): + """Reshape a flattened patch sequence back into a ``(B, C, F, H, W)`` latent grid.""" + p_t, p_h, p_w = patch_size + post_patch_num_frames = latent_num_frames // p_t + post_patch_height = latent_height // p_h + post_patch_width = latent_width // p_w + + data_patch = data_seq.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + data_patch = data_patch.permute(0, 7, 1, 4, 2, 5, 3, 6) + data_patch = data_patch.flatten(6, 7).flatten(4, 5).flatten(2, 3) + return data_patch + + +def get_mesh_id(f, h, w, t, f_w=1, f_shift=0, action=False): + """Build the (frame, height, width, stream) grid ids used to index the rotary embedding.""" + f_idx = torch.arange(f_shift, f + f_shift) * f_w + h_idx = torch.arange(h) + w_idx = torch.arange(w) + ff, hh, ww = torch.meshgrid(f_idx, h_idx, w_idx, indexing="ij") + if action: + ff_offset = (torch.ones([h]).cumsum(0) / (h + 1)).view(1, -1, 1) + ff = ff + ff_offset + hh = torch.ones_like(hh) * -1 + ww = torch.ones_like(ww) * -1 + + grid_id = torch.cat([ff.unsqueeze(0), hh.unsqueeze(0), ww.unsqueeze(0)], dim=0).flatten(1) + grid_id = torch.cat([grid_id, torch.full_like(grid_id[:1], t)], dim=0) + return grid_id diff --git a/src/lerobot/policies/lingbot_va/wan_vae.py b/src/lerobot/policies/lingbot_va/wan_vae.py new file mode 100644 index 000000000..c1fff4886 --- /dev/null +++ b/src/lerobot/policies/lingbot_va/wan_vae.py @@ -0,0 +1,120 @@ +# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved. +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Thin helpers around the stock diffusers ``AutoencoderKLWan`` (Wan2.2, ``z_dim=48``). + +The VAE class itself is NOT vendored — it lives in ``diffusers>=0.36``. This module +provides: + * loaders for the VAE / text encoder / tokenizer / transformer sub-checkpoints, + * the streaming-encoder wrapper used for autoregressive frame-by-frame VAE encoding + (it caches the causal-conv state across chunks), + * latent (de)normalization helpers using the VAE's ``latents_mean`` / ``latents_std``. + +Vendored and adapted from ``wan_va/modules/utils.py`` upstream. +""" + +import torch + +__all__ = [ + "WanVAEStreamingWrapper", + "load_vae", + "load_text_encoder", + "load_tokenizer", + "normalize_latents", + "denormalize_latents", + "patchify", +] + + +def load_vae(vae_path, torch_dtype, torch_device): + from diffusers import AutoencoderKLWan + + vae = AutoencoderKLWan.from_pretrained(vae_path, torch_dtype=torch_dtype) + return vae.to(torch_device) + + +def load_text_encoder(text_encoder_path, torch_dtype, torch_device): + from transformers import UMT5EncoderModel + + text_encoder = UMT5EncoderModel.from_pretrained(text_encoder_path, torch_dtype=torch_dtype) + return text_encoder.to(torch_device) + + +def load_tokenizer(tokenizer_path): + from transformers import T5TokenizerFast + + return T5TokenizerFast.from_pretrained(tokenizer_path) + + +def patchify(x, patch_size): + if patch_size is None or patch_size == 1: + return x + batch_size, channels, frames, height, width = x.shape + x = x.view( + batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size + ) + x = x.permute(0, 1, 6, 4, 2, 3, 5).contiguous() + x = x.view( + batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size + ) + return x + + +def normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor +) -> torch.Tensor: + """Apply ``(x - mean) * std`` channel-wise (note: upstream passes ``1/std`` as ``latents_std``).""" + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device) + latents = ((latents.float() - latents_mean) * latents_std).to(latents) + return latents + + +def denormalize_latents(latents: torch.Tensor, latents_mean, latents_std, z_dim) -> torch.Tensor: + """Inverse of the normalization applied at encode time, for VAE decoding of predicted latents.""" + mean = torch.tensor(latents_mean).view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype) + inv_std = 1.0 / torch.tensor(latents_std).view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype) + return latents / inv_std + mean + + +class WanVAEStreamingWrapper: + """Wraps an ``AutoencoderKLWan`` encoder to support causal streaming encoding across chunks.""" + + def __init__(self, vae_model): + self.vae = vae_model + self.encoder = vae_model.encoder + self.quant_conv = vae_model.quant_conv + + if hasattr(self.vae, "_cached_conv_counts"): + self.enc_conv_num = self.vae._cached_conv_counts["encoder"] + else: + count = 0 + for m in self.encoder.modules(): + if m.__class__.__name__ == "WanCausalConv3d": + count += 1 + self.enc_conv_num = count + + self.clear_cache() + + def clear_cache(self): + self.feat_cache = [None] * self.enc_conv_num + + def encode_chunk(self, x_chunk): + if hasattr(self.vae.config, "patch_size") and self.vae.config.patch_size is not None: + x_chunk = patchify(x_chunk, self.vae.config.patch_size) + feat_idx = [0] + out = self.encoder(x_chunk, feat_cache=self.feat_cache, feat_idx=feat_idx) + enc = self.quant_conv(out) + return enc diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index d45483d21..2f16515b2 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -105,6 +105,7 @@ def rollout( seeds: list[int] | None = None, return_observations: bool = False, render_callback: Callable[[gym.vector.VectorEnv], None] | None = None, + predicted_frames_callback: Callable[[PreTrainedPolicy], None] | None = None, ) -> dict: """Run a batched policy rollout once through a batch of environments. @@ -134,6 +135,9 @@ def rollout( are returned optionally because they typically take more memory to cache. Defaults to False. render_callback: Optional rendering callback to be used after the environments are reset, and after every step. + predicted_frames_callback: Optional callback invoked after every ``select_action`` with the policy + itself. World-model policies (e.g. LingBot-VA) stash their decoded predicted video frames on + ``policy.last_predicted_frames``; this lets the caller collect them to save predicted-video MP4s. Returns: The dictionary described above. """ @@ -184,6 +188,8 @@ def rollout( observation = preprocessor(observation) with torch.inference_mode(): action = policy.select_action(observation) + if predicted_frames_callback is not None: + predicted_frames_callback(policy) action = postprocessor(action) action_transition = {ACTION: action} @@ -273,6 +279,7 @@ def eval_policy( videos_dir: Path | None = None, return_episode_data: bool = False, start_seed: int | None = None, + save_predicted_video: bool = False, ) -> dict: """ Args: @@ -291,6 +298,11 @@ def eval_policy( if max_episodes_rendered > 0 and not videos_dir: raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.") + # World-model policies (e.g. LingBot-VA) opt into predicted-video saving via their config. + save_predicted_video = save_predicted_video or bool( + getattr(getattr(policy, "config", None), "save_predicted_video", False) + ) + if not isinstance(policy, PreTrainedPolicy): exc = ValueError( f"Policy of type 'PreTrainedPolicy' is expected, but type '{type(policy)}' was provided." @@ -334,6 +346,21 @@ def eval_policy( if max_episodes_rendered > 0: video_paths: list[str] = [] + if save_predicted_video: + if not videos_dir: + raise ValueError("If save_predicted_video is True, videos_dir must be provided.") + predicted_video_paths: list[str] = [] + n_predicted_rendered = 0 + + # Collects the policy's decoded predicted-video frames across a rollout (world-model policies only). + def collect_predicted_frames(policy: PreTrainedPolicy): + frames = getattr(policy, "last_predicted_frames", None) + if frames is not None: + pred_frames.append( + np.asarray(frames.detach().to("cpu")) if hasattr(frames, "detach") else np.asarray(frames) + ) + policy.last_predicted_frames = None + if return_episode_data: episode_data: dict | None = None @@ -345,6 +372,9 @@ def eval_policy( if max_episodes_rendered > 0: ep_frames: list[np.ndarray] = [] + if save_predicted_video: + pred_frames: list[np.ndarray] = [] + if start_seed is None: seeds = None else: @@ -361,6 +391,7 @@ def eval_policy( seeds=list(seeds) if seeds else None, return_observations=return_episode_data, render_callback=render_frame if max_episodes_rendered > 0 else None, + predicted_frames_callback=collect_predicted_frames if save_predicted_video else None, ) # Figure out where in each rollout sequence the first done condition was encountered (results after @@ -426,6 +457,25 @@ def eval_policy( threads.append(thread) n_episodes_rendered += 1 + # Maybe save the policy's predicted (imagined) video for this batch's rollout. + if save_predicted_video and len(pred_frames) > 0: + # pred_frames is a list of [F, H, W, C] uint8 stacks emitted on chunk refills; concat over time. + predicted_video = np.concatenate(pred_frames, axis=0) + videos_dir.mkdir(parents=True, exist_ok=True) + predicted_video_path = videos_dir / f"pred_episode_{n_predicted_rendered}.mp4" + predicted_video_paths.append(str(predicted_video_path)) + thread = threading.Thread( + target=write_video, + args=( + str(predicted_video_path), + predicted_video, + env.unwrapped.metadata["render_fps"], + ), + ) + thread.start() + threads.append(thread) + n_predicted_rendered += 1 + progbar.set_postfix( {"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"} ) @@ -469,6 +519,9 @@ def eval_policy( if max_episodes_rendered > 0: info["video_paths"] = video_paths + if save_predicted_video: + info["predicted_video_paths"] = predicted_video_paths + return info diff --git a/tests/policies/lingbot_va/test_configuration.py b/tests/policies/lingbot_va/test_configuration.py new file mode 100644 index 000000000..5eb77dd1e --- /dev/null +++ b/tests/policies/lingbot_va/test_configuration.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import pytest + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig +from lerobot.utils.constants import ACTION, OBS_IMAGES + + +def make_config(**overrides) -> LingBotVAConfig: + kwargs = {"device": "cpu"} + kwargs.update(overrides) + return LingBotVAConfig(**kwargs) + + +def test_registered_in_choice_registry() -> None: + assert "lingbot_va" in PreTrainedConfig.get_known_choices() + assert PreTrainedConfig.get_choice_class("lingbot_va") is LingBotVAConfig + + +def test_type_property() -> None: + assert make_config().type == "lingbot_va" + + +def test_chunk_size_and_action_steps() -> None: + cfg = make_config(frame_chunk_size=4, action_per_frame=4) + assert cfg.chunk_size == 16 + assert cfg.n_action_steps == 16 + assert cfg.action_delta_indices == list(range(16)) + assert cfg.observation_delta_indices is None + assert cfg.reward_delta_indices is None + + +def test_optimizer_and_scheduler_presets() -> None: + cfg = make_config() + opt = cfg.get_optimizer_preset() + assert opt.lr == cfg.optimizer_lr + sched = cfg.get_scheduler_preset() + assert sched.num_warmup_steps == cfg.scheduler_warmup_steps + + +def test_validate_features_sets_action_feature() -> None: + cfg = make_config() + cfg.input_features = {f"{OBS_IMAGES}.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128))} + cfg.output_features = {} + cfg.validate_features() + assert ACTION in cfg.output_features + assert cfg.output_features[ACTION].shape == (len(cfg.used_action_channel_ids),) + + +def test_validate_features_no_visual_raises() -> None: + cfg = make_config() + cfg.input_features = {} + cfg.output_features = {} + with pytest.raises(ValueError, match="at least one visual input feature"): + cfg.validate_features() + + +def test_invalid_attn_mode_raises() -> None: + with pytest.raises(ValueError, match="attn_mode"): + make_config(attn_mode="banana") + + +def test_quantile_length_mismatch_raises() -> None: + with pytest.raises(ValueError, match="action_q01"): + make_config(used_action_channel_ids=[0, 1, 2], action_q01=[0.0, 0.0], action_q99=[1.0, 1.0, 1.0]) diff --git a/tests/policies/lingbot_va/test_factory.py b/tests/policies/lingbot_va/test_factory.py new file mode 100644 index 000000000..4b96008c7 --- /dev/null +++ b/tests/policies/lingbot_va/test_factory.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import pytest + +from lerobot.policies.factory import make_policy_config +from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig + + +def test_make_policy_config_returns_lingbot_va() -> None: + cfg = make_policy_config("lingbot_va", device="cpu") + assert isinstance(cfg, LingBotVAConfig) + + +def test_get_policy_class_resolves_lazily() -> None: + # Importing the policy class pulls in diffusers (Wan2.2 stack); skip if unavailable. + pytest.importorskip("diffusers") + pytest.importorskip("transformers") + from lerobot.policies.factory import get_policy_class + + cls = get_policy_class("lingbot_va") + assert cls.name == "lingbot_va" + assert cls.config_class is LingBotVAConfig + + +def test_convert_build_config_libero() -> None: + pytest.importorskip("diffusers") + from lerobot.policies.lingbot_va.convert_lingbot_va_checkpoints import build_config + + cfg = build_config("libero", wan_pretrained_path="dummy/path", dtype="float32") + assert cfg.height == 128 and cfg.width == 128 + assert cfg.used_action_channel_ids == list(range(7)) + # validate_features (called inside build_config) must have populated the action feature. + from lerobot.utils.constants import ACTION + + assert cfg.output_features[ACTION].shape == (7,) + assert len(cfg.obs_cam_keys) == 2 diff --git a/tests/policies/lingbot_va/test_modules.py b/tests/policies/lingbot_va/test_modules.py new file mode 100644 index 000000000..048bd57b3 --- /dev/null +++ b/tests/policies/lingbot_va/test_modules.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pure-torch unit tests for the vendored LingBot-VA helper modules (no diffusers needed).""" + +from __future__ import annotations + +import torch + +from lerobot.policies.lingbot_va.schedulers import FlowMatchScheduler +from lerobot.policies.lingbot_va.wan_utils import data_seq_to_patch, get_mesh_id + + +def test_flow_match_scheduler_timesteps_monotone_decreasing() -> None: + sch = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True) + sch.set_timesteps(20) + assert sch.timesteps.shape == (20,) + diffs = sch.timesteps[1:] - sch.timesteps[:-1] + assert torch.all(diffs <= 0) # decreasing + + +def test_flow_match_scheduler_step_preserves_shape() -> None: + sch = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True) + sch.set_timesteps(20) + sample = torch.zeros(1, 48, 4, 8, 16) + out = sch.step(torch.ones_like(sample), sch.timesteps[0], sample) + assert out.shape == sample.shape + + +def test_flow_match_scheduler_add_noise() -> None: + sch = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True) + sch.set_timesteps(20) + sample = torch.randn(1, 48, 4, 8, 16) + noise = torch.randn_like(sample) + noisy = sch.add_noise(sample, noise, sch.timesteps[:4], t_dim=2) + assert noisy.shape == sample.shape + + +def test_get_mesh_id_latent_shape() -> None: + grid = get_mesh_id(4, 8, 16, 0, 1, 0) + assert grid.shape == (4, 4 * 8 * 16) # (f, h, w, stream) x tokens + + +def test_get_mesh_id_action_shape() -> None: + grid = get_mesh_id(4, 4, 1, 1, 1, 0, action=True) + assert grid.shape == (4, 4 * 4 * 1) + # Action rows for h/w are sentinel -1. + assert torch.all(grid[1] < 0) + assert torch.all(grid[2] < 0) + + +def test_data_seq_to_patch_roundtrip_shape() -> None: + b, f, h, w, c = 1, 4, 8, 16, 48 + seq = torch.arange(b * f * h * w * c, dtype=torch.float32).reshape(b, f * h * w, c) + out = data_seq_to_patch((1, 2, 2), seq, f, h, w, batch_size=b) + assert out.shape == (b, c, f, h, w) diff --git a/tests/policies/lingbot_va/test_processor.py b/tests/policies/lingbot_va/test_processor.py new file mode 100644 index 000000000..4bf0b493c --- /dev/null +++ b/tests/policies/lingbot_va/test_processor.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch + +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig +from lerobot.policies.lingbot_va.processor_lingbot_va import ( + LingBotVAActionUnnormalizeStep, + make_lingbot_va_pre_post_processors, +) +from lerobot.utils.constants import ( + OBS_IMAGES, + POLICY_POSTPROCESSOR_DEFAULT_NAME, + POLICY_PREPROCESSOR_DEFAULT_NAME, +) + + +def _make_config() -> LingBotVAConfig: + cfg = LingBotVAConfig(device="cpu") + cfg.input_features = {f"{OBS_IMAGES}.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128))} + cfg.output_features = {} + cfg.validate_features() + return cfg + + +def test_action_unnormalize_inverts_quantile_norm() -> None: + q01 = [-1.0, -0.5, 0.0] + q99 = [1.0, 0.5, 2.0] + step = LingBotVAActionUnnormalizeStep(action_q01=q01, action_q99=q99) + + # Forward (the policy-side) quantile normalization: (x - q01) / (q99 - q01 + eps) * 2 - 1. + q01_t = torch.tensor(q01) + q99_t = torch.tensor(q99) + raw = torch.tensor([[0.3, 0.1, 1.0]]) + normed = (raw - q01_t) / (q99_t - q01_t + 1e-6) * 2.0 - 1.0 + + recovered = step.action(normed) + assert torch.allclose(recovered, raw, atol=1e-4) + + +def test_action_unnormalize_config_roundtrip() -> None: + step = LingBotVAActionUnnormalizeStep(action_q01=[0.0, 1.0], action_q99=[2.0, 3.0]) + cfg = step.get_config() + assert cfg == {"action_q01": [0.0, 1.0], "action_q99": [2.0, 3.0]} + rebuilt = LingBotVAActionUnnormalizeStep(**cfg) + assert rebuilt.action_q01 == step.action_q01 + assert rebuilt.action_q99 == step.action_q99 + + +def test_make_pre_post_processors_names_and_steps() -> None: + cfg = _make_config() + pre, post = make_lingbot_va_pre_post_processors(cfg, dataset_stats=None) + assert pre.name == POLICY_PREPROCESSOR_DEFAULT_NAME + assert post.name == POLICY_POSTPROCESSOR_DEFAULT_NAME + # The postprocessor must contain the dedicated quantile unnormalize step. + assert any(isinstance(s, LingBotVAActionUnnormalizeStep) for s in post.steps) + + +def test_postprocessor_applies_unnormalization() -> None: + cfg = _make_config() + _, post = make_lingbot_va_pre_post_processors(cfg, dataset_stats=None) + # A normalized action of all -1 should map back to q01. + normed = torch.full((1, len(cfg.used_action_channel_ids)), -1.0) + out = post(normed) + assert torch.allclose(out, torch.tensor(cfg.action_q01).unsqueeze(0), atol=1e-4)