Make tokenizer/text-encoder model IDs configurable + some nits (stricter enum coercion, tensor-to-scalar unwrapping for inference kwargs, import cleanup)

This commit is contained in:
Maxime Ellerbach
2026-06-23 09:07:16 +00:00
parent 5f6ddab629
commit 343ab5f99c
5 changed files with 48 additions and 25 deletions
@@ -28,7 +28,9 @@ from lerobot.optim import AdamWConfig
from lerobot.utils.constants import ACTION, OBS_STATE
# Model id of the Wan2.2 TI2V-5B release; default for `FastWAMConfig.model_id`.
WAN22_MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B"
# Model id of the diffusers variant; default for `FastWAMConfig.text_encoder_model_id`.
WAN22_DIFFUSERS_MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
# Default for `FastWAMConfig.base_model_id` (that field may also be set to None).
FASTWAM_BASE_MODEL_ID = "lerobot/fastwam_base"
# Default for `FastWAMConfig.tokenizer_model_id`.
WAN_T5_TOKENIZER_ID = "google/umt5-xxl"
_FASTWAM_VIDEO_BASE_COMPAT_KEYS = (
@@ -99,8 +101,11 @@ def _coerce_enum(enum_cls: type, value: Any) -> Any:
return value
try:
return enum_cls(value)
except (TypeError, ValueError):
return getattr(enum_cls, str(value), value)
except (TypeError, ValueError) as exc:
member = getattr(enum_cls, str(value), None)
if member is None:
raise ValueError(f"Cannot coerce {value!r} into {enum_cls.__name__}.") from exc
return member
def _coerce_policy_features(features: dict[str, Any] | None) -> dict[str, PolicyFeature] | None:
@@ -184,7 +189,8 @@ class FastWAMConfig(PreTrainedConfig):
image_size: tuple[int, int] = (224, 448)
context_len: int = 128
model_id: str = WAN22_MODEL_ID
tokenizer_model_id: str = WAN22_MODEL_ID
tokenizer_model_id: str = WAN_T5_TOKENIZER_ID
text_encoder_model_id: str = WAN22_DIFFUSERS_MODEL_ID
base_model_id: str | None = FASTWAM_BASE_MODEL_ID
tokenizer_max_len: int = 128
load_text_encoder: bool = True
@@ -229,7 +235,6 @@ class FastWAMConfig(PreTrainedConfig):
super().__post_init__()
self.image_size = tuple(self.image_size)
self.model_id = _validate_wan_model_id(self.model_id, "model_id")
self.tokenizer_model_id = _validate_wan_model_id(self.tokenizer_model_id, "tokenizer_model_id")
self.input_features = _coerce_policy_features(self.input_features)
self.output_features = _coerce_policy_features(self.output_features)
self.toggle_action_dimensions = [int(dim) for dim in self.toggle_action_dimensions]
@@ -253,7 +253,9 @@ class FastWAMPolicy(PreTrainedPolicy):
mot_checkpoint_mixed_attn=config.mot_checkpoint_mixed_attn,
)
text_encoder = (
load_pretrained_wan_text_encoder(torch_dtype=dtype, device=device)
load_pretrained_wan_text_encoder(
model_id=config.text_encoder_model_id, torch_dtype=dtype, device=device
)
if config.load_text_encoder
else None
)
@@ -263,7 +265,9 @@ class FastWAMPolicy(PreTrainedPolicy):
mot=mot,
vae=load_pretrained_wan_vae(torch_dtype=dtype, device=device),
text_encoder=text_encoder,
tokenizer=build_wan_tokenizer(tokenizer_max_len=config.tokenizer_max_len),
tokenizer=build_wan_tokenizer(
model_id=config.tokenizer_model_id, tokenizer_max_len=config.tokenizer_max_len
),
text_dim=int(config.video_dit_config["text_dim"]),
proprio_dim=config.proprio_dim,
device=device,
@@ -279,6 +283,11 @@ class FastWAMPolicy(PreTrainedPolicy):
)
def _scalar(value: Any) -> Any:
"""Unwrap a 0-/1-element tensor (e.g. from DataLoader collation) to a Python scalar."""
return value.item() if isinstance(value, Tensor) else value
def _batch_to_infer_kwargs(batch: dict[str, Tensor], config: FastWAMConfig) -> dict[str, Any]:
return {
"prompt": _prompt_from_batch(batch=batch, config=config),
@@ -288,8 +297,8 @@ def _batch_to_infer_kwargs(batch: dict[str, Tensor], config: FastWAMConfig) -> d
"context": batch.get("context"),
"context_mask": batch.get("context_mask"),
"negative_prompt": batch.get("negative_prompt", config.negative_prompt),
"text_cfg_scale": float(batch.get("text_cfg_scale", config.text_cfg_scale)),
"num_inference_steps": int(batch.get("num_inference_steps", config.num_inference_steps)),
"text_cfg_scale": float(_scalar(batch.get("text_cfg_scale", config.text_cfg_scale))),
"num_inference_steps": int(_scalar(batch.get("num_inference_steps", config.num_inference_steps))),
"sigma_shift": batch.get("sigma_shift", config.sigma_shift),
"seed": batch.get("seed", config.inference_seed),
"rand_device": batch.get("rand_device", config.rand_device),
@@ -26,6 +26,8 @@ import torch.nn.functional as functional
from PIL import Image
from .wan_components import (
WAN22_DIFFUSERS_MODEL_ID,
WAN_T5_TOKENIZER,
build_wan_tokenizer,
load_pretrained_wan_text_encoder,
load_pretrained_wan_vae,
@@ -938,7 +940,8 @@ class FastWAM(torch.nn.Module):
device: str = "cuda",
torch_dtype: torch.dtype = torch.bfloat16,
model_id: str = "Wan-AI/Wan2.2-TI2V-5B",
tokenizer_model_id: str = "Wan-AI/Wan2.2-TI2V-5B",
tokenizer_model_id: str = WAN_T5_TOKENIZER,
text_encoder_model_id: str = WAN22_DIFFUSERS_MODEL_ID,
tokenizer_max_len: int = 512,
load_text_encoder: bool = True,
proprio_dim: int | None = None,
@@ -958,7 +961,6 @@ class FastWAM(torch.nn.Module):
raise ValueError("`video_dit_config` is required for FastWAM.from_wan22_pretrained().")
if "text_dim" not in video_dit_config:
raise ValueError("`video_dit_config['text_dim']` is required for FastWAM.")
del tokenizer_model_id # tokenizer is the stock UMT5 one (google/umt5-xxl)
# Custom MoT video DiT from the original Wan2.2 repo; frozen VAE / UMT5 from
# the diffusers conversion. This is the offline base-creation path; the
@@ -984,11 +986,13 @@ class FastWAM(torch.nn.Module):
vae = load_pretrained_wan_vae(torch_dtype=torch_dtype, device=device)
text_encoder = (
load_pretrained_wan_text_encoder(torch_dtype=torch_dtype, device=device)
load_pretrained_wan_text_encoder(
model_id=text_encoder_model_id, torch_dtype=torch_dtype, device=device
)
if load_text_encoder
else None
)
tokenizer = build_wan_tokenizer(tokenizer_max_len=tokenizer_max_len)
tokenizer = build_wan_tokenizer(model_id=tokenizer_model_id, tokenizer_max_len=tokenizer_max_len)
return cls(
video_expert=video_expert,
+15 -11
View File
@@ -36,10 +36,6 @@ if TYPE_CHECKING or _diffusers_available:
else:
AutoencoderKLWan = None
if TYPE_CHECKING:
from .wan_adapters import WanVideoVAE38
from .wan_video_dit import WanVideoDiT
from .wan_adapters import WanVideoVAE38
from .wan_video_dit import WanVideoDiT
@@ -109,8 +105,8 @@ class WanTokenizer:
return out.input_ids
def build_wan_tokenizer(*, tokenizer_max_len: int) -> WanTokenizer:
return WanTokenizer(name=WAN_T5_TOKENIZER, seq_len=int(tokenizer_max_len))
def build_wan_tokenizer(*, model_id: str = WAN_T5_TOKENIZER, tokenizer_max_len: int) -> WanTokenizer:
    """Create a `WanTokenizer` for `model_id` with sequence length `tokenizer_max_len`.

    Args:
        model_id: Passed to `WanTokenizer` as `name`; defaults to `WAN_T5_TOKENIZER`.
        tokenizer_max_len: Coerced with `int()` and passed to `WanTokenizer` as `seq_len`.

    Returns:
        The constructed `WanTokenizer`.
    """
    seq_len = int(tokenizer_max_len)
    tokenizer = WanTokenizer(name=model_id, seq_len=seq_len)
    return tokenizer
def load_pretrained_wan_vae(*, torch_dtype: torch.dtype, device: str) -> WanVideoVAE38:
@@ -120,12 +116,20 @@ def load_pretrained_wan_vae(*, torch_dtype: torch.dtype, device: str) -> WanVide
return WanVideoVAE38(dtype=torch_dtype, device=device, pretrained=vae)
def load_pretrained_wan_text_encoder(*, torch_dtype: torch.dtype, device: str) -> WanTextEncoder:
"""Load real UMT5-XXL encoder weights from the diffusers repo (offline base creation)."""
def load_pretrained_wan_text_encoder(
    *,
    model_id: str = WAN22_DIFFUSERS_MODEL_ID,
    subfolder: str | None = "text_encoder",
    torch_dtype: torch.dtype,
    device: str,
) -> WanTextEncoder:
    """Load UMT5-XXL encoder weights (defaults to the Wan2.2 diffusers repo).

    Must stay compatible with the tokenizer (see `build_wan_tokenizer`): the encoder's
    embedding table is indexed by the tokenizer's vocabulary.

    Args:
        model_id: Model id handed to `UMT5EncoderModel.from_pretrained`.
        subfolder: Sub-directory of `model_id` holding the encoder weights. Pass
            `None` (or `""`) when the weights sit at the root of `model_id`.
        torch_dtype: Dtype forwarded to `from_pretrained` and to `WanTextEncoder`.
        device: Device forwarded to `WanTextEncoder`.

    Returns:
        A `WanTextEncoder` wrapping the loaded encoder.
    """
    require_package("transformers", extra="fastwam")
    # transformers documents `subfolder` as a string defaulting to ""; translate the
    # `None` "no subfolder" value instead of forwarding it as-is.
    encoder = UMT5EncoderModel.from_pretrained(
        model_id, subfolder=subfolder or "", torch_dtype=torch_dtype
    )
    return WanTextEncoder(dtype=torch_dtype, device=device, pretrained=encoder)
@@ -73,8 +73,9 @@ def test_config_validates_features_model_ids_and_saved_auto_route(tmp_path):
assert cfg.robot_state_feature.shape == (8,)
with pytest.raises(ValueError, match="image feature"):
FastWAMConfig(input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,))})
with pytest.raises(ValueError, match="tokenizer_model_id"):
FastWAMConfig(tokenizer_model_id="somebody/other-tokenizer")
assert FastWAMConfig(tokenizer_model_id="somebody/other-tokenizer").tokenizer_model_id == (
"somebody/other-tokenizer"
)
def test_preprocessor_passes_images_through_and_postprocessor_toggles_actions(tmp_path):