From 343ab5f99ca29f8d1db7edd65432f1e475122065 Mon Sep 17 00:00:00 2001 From: Maxime Ellerbach Date: Tue, 23 Jun 2026 09:07:16 +0000 Subject: [PATCH] make tokenizer/text-encoder model ids configurable + some nits --- .../policies/fastwam/configuration_fastwam.py | 13 +++++++--- .../policies/fastwam/modeling_fastwam.py | 17 +++++++++--- .../policies/fastwam/modular_fastwam.py | 12 ++++++--- .../policies/fastwam/wan_components.py | 26 +++++++++++-------- tests/policies/fastwam/test_fastwam_policy.py | 5 ++-- 5 files changed, 48 insertions(+), 25 deletions(-) diff --git a/src/lerobot/policies/fastwam/configuration_fastwam.py b/src/lerobot/policies/fastwam/configuration_fastwam.py index 0e28efb37..a3ef4f602 100644 --- a/src/lerobot/policies/fastwam/configuration_fastwam.py +++ b/src/lerobot/policies/fastwam/configuration_fastwam.py @@ -28,7 +28,9 @@ from lerobot.optim import AdamWConfig from lerobot.utils.constants import ACTION, OBS_STATE WAN22_MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B" +WAN22_DIFFUSERS_MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B-Diffusers" FASTWAM_BASE_MODEL_ID = "lerobot/fastwam_base" +WAN_T5_TOKENIZER_ID = "google/umt5-xxl" _FASTWAM_VIDEO_BASE_COMPAT_KEYS = ( @@ -99,8 +101,11 @@ def _coerce_enum(enum_cls: type, value: Any) -> Any: return value try: return enum_cls(value) - except (TypeError, ValueError): - return getattr(enum_cls, str(value), value) + except (TypeError, ValueError) as exc: + member = getattr(enum_cls, str(value), None) + if member is None: + raise ValueError(f"Cannot coerce {value!r} into {enum_cls.__name__}.") from exc + return member def _coerce_policy_features(features: dict[str, Any] | None) -> dict[str, PolicyFeature] | None: @@ -184,7 +189,8 @@ class FastWAMConfig(PreTrainedConfig): image_size: tuple[int, int] = (224, 448) context_len: int = 128 model_id: str = WAN22_MODEL_ID - tokenizer_model_id: str = WAN22_MODEL_ID + tokenizer_model_id: str = WAN_T5_TOKENIZER_ID + text_encoder_model_id: str = WAN22_DIFFUSERS_MODEL_ID base_model_id: str | None = FASTWAM_BASE_MODEL_ID tokenizer_max_len: int = 128 load_text_encoder: bool = True @@ -229,7 +235,6 @@ class FastWAMConfig(PreTrainedConfig): super().__post_init__() self.image_size = tuple(self.image_size) self.model_id = _validate_wan_model_id(self.model_id, "model_id") - self.tokenizer_model_id = _validate_wan_model_id(self.tokenizer_model_id, "tokenizer_model_id") self.input_features = _coerce_policy_features(self.input_features) self.output_features = _coerce_policy_features(self.output_features) self.toggle_action_dimensions = [int(dim) for dim in self.toggle_action_dimensions] diff --git a/src/lerobot/policies/fastwam/modeling_fastwam.py b/src/lerobot/policies/fastwam/modeling_fastwam.py index 292d90bd8..9822c9834 100644 --- a/src/lerobot/policies/fastwam/modeling_fastwam.py +++ b/src/lerobot/policies/fastwam/modeling_fastwam.py @@ -253,7 +253,9 @@ class FastWAMPolicy(PreTrainedPolicy): mot_checkpoint_mixed_attn=config.mot_checkpoint_mixed_attn, ) text_encoder = ( - load_pretrained_wan_text_encoder(torch_dtype=dtype, device=device) + load_pretrained_wan_text_encoder( + model_id=config.text_encoder_model_id, torch_dtype=dtype, device=device + ) if config.load_text_encoder else None ) @@ -263,7 +265,9 @@ class FastWAMPolicy(PreTrainedPolicy): mot=mot, vae=load_pretrained_wan_vae(torch_dtype=dtype, device=device), text_encoder=text_encoder, - tokenizer=build_wan_tokenizer(tokenizer_max_len=config.tokenizer_max_len), + tokenizer=build_wan_tokenizer( + model_id=config.tokenizer_model_id, tokenizer_max_len=config.tokenizer_max_len + ), text_dim=int(config.video_dit_config["text_dim"]), proprio_dim=config.proprio_dim, device=device, @@ -279,6 +283,11 @@ class FastWAMPolicy(PreTrainedPolicy): ) +def _scalar(value: Any) -> Any: + """Unwrap a 0-/1-element tensor (e.g. from DataLoader collation) to a Python scalar.""" + return value.item() if isinstance(value, Tensor) else value + + def _batch_to_infer_kwargs(batch: dict[str, Tensor], config: FastWAMConfig) -> dict[str, Any]: return { "prompt": _prompt_from_batch(batch=batch, config=config), @@ -288,8 +297,8 @@ def _batch_to_infer_kwargs(batch: dict[str, Tensor], config: FastWAMConfig) -> d "context": batch.get("context"), "context_mask": batch.get("context_mask"), "negative_prompt": batch.get("negative_prompt", config.negative_prompt), - "text_cfg_scale": float(batch.get("text_cfg_scale", config.text_cfg_scale)), - "num_inference_steps": int(batch.get("num_inference_steps", config.num_inference_steps)), + "text_cfg_scale": float(_scalar(batch.get("text_cfg_scale", config.text_cfg_scale))), + "num_inference_steps": int(_scalar(batch.get("num_inference_steps", config.num_inference_steps))), "sigma_shift": batch.get("sigma_shift", config.sigma_shift), "seed": batch.get("seed", config.inference_seed), "rand_device": batch.get("rand_device", config.rand_device), diff --git a/src/lerobot/policies/fastwam/modular_fastwam.py b/src/lerobot/policies/fastwam/modular_fastwam.py index 8d3df9c91..d82c8b69d 100644 --- a/src/lerobot/policies/fastwam/modular_fastwam.py +++ b/src/lerobot/policies/fastwam/modular_fastwam.py @@ -26,6 +26,8 @@ import torch.nn.functional as functional from PIL import Image from .wan_components import ( + WAN22_DIFFUSERS_MODEL_ID, + WAN_T5_TOKENIZER, build_wan_tokenizer, load_pretrained_wan_text_encoder, load_pretrained_wan_vae, @@ -938,7 +940,8 @@ class FastWAM(torch.nn.Module): device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16, model_id: str = "Wan-AI/Wan2.2-TI2V-5B", - tokenizer_model_id: str = "Wan-AI/Wan2.2-TI2V-5B", + tokenizer_model_id: str = WAN_T5_TOKENIZER, + text_encoder_model_id: str = WAN22_DIFFUSERS_MODEL_ID, tokenizer_max_len: int = 512, load_text_encoder: bool = True, proprio_dim: int | None = None, @@ -958,7 +961,6 @@ class FastWAM(torch.nn.Module): raise ValueError("`video_dit_config` is required for FastWAM.from_wan22_pretrained().") if "text_dim" not in video_dit_config: raise ValueError("`video_dit_config['text_dim']` is required for FastWAM.") - del tokenizer_model_id # tokenizer is the stock UMT5 one (google/umt5-xxl) # Custom MoT video DiT from the original Wan2.2 repo; frozen VAE / UMT5 from # the diffusers conversion. This is the offline base-creation path; the @@ -984,11 +986,13 @@ class FastWAM(torch.nn.Module): vae = load_pretrained_wan_vae(torch_dtype=torch_dtype, device=device) text_encoder = ( - load_pretrained_wan_text_encoder(torch_dtype=torch_dtype, device=device) + load_pretrained_wan_text_encoder( + model_id=text_encoder_model_id, torch_dtype=torch_dtype, device=device + ) if load_text_encoder else None ) - tokenizer = build_wan_tokenizer(tokenizer_max_len=tokenizer_max_len) + tokenizer = build_wan_tokenizer(model_id=tokenizer_model_id, tokenizer_max_len=tokenizer_max_len) return cls( video_expert=video_expert, diff --git a/src/lerobot/policies/fastwam/wan_components.py b/src/lerobot/policies/fastwam/wan_components.py index 6d16cf5c7..a69f21fe0 100644 --- a/src/lerobot/policies/fastwam/wan_components.py +++ b/src/lerobot/policies/fastwam/wan_components.py @@ -36,10 +36,6 @@ if TYPE_CHECKING or _diffusers_available: else: AutoencoderKLWan = None -if TYPE_CHECKING: - from .wan_adapters import WanVideoVAE38 - from .wan_video_dit import WanVideoDiT - from .wan_adapters import WanVideoVAE38 from .wan_video_dit import WanVideoDiT @@ -109,8 +105,8 @@ class WanTokenizer: return out.input_ids -def build_wan_tokenizer(*, tokenizer_max_len: int) -> WanTokenizer: - return WanTokenizer(name=WAN_T5_TOKENIZER, seq_len=int(tokenizer_max_len)) +def build_wan_tokenizer(*, model_id: str = WAN_T5_TOKENIZER, tokenizer_max_len: int) -> WanTokenizer: + return WanTokenizer(name=model_id, seq_len=int(tokenizer_max_len)) def load_pretrained_wan_vae(*, torch_dtype: torch.dtype, device: str) -> WanVideoVAE38: @@ -120,12 +116,20 @@ def load_pretrained_wan_vae(*, torch_dtype: torch.dtype, device: str) -> WanVide return WanVideoVAE38(dtype=torch_dtype, device=device, pretrained=vae) -def load_pretrained_wan_text_encoder(*, torch_dtype: torch.dtype, device: str) -> WanTextEncoder: - """Load real UMT5-XXL encoder weights from the diffusers repo (offline base creation).""" +def load_pretrained_wan_text_encoder( + *, + model_id: str = WAN22_DIFFUSERS_MODEL_ID, + subfolder: str | None = "text_encoder", + torch_dtype: torch.dtype, + device: str, +) -> WanTextEncoder: + """Load UMT5-XXL encoder weights (defaults to the Wan2.2 diffusers repo). + + Must stay compatible with the tokenizer (see `build_wan_tokenizer`): the encoder's + embedding table is indexed by the tokenizer's vocabulary. + """ require_package("transformers", extra="fastwam") - encoder = UMT5EncoderModel.from_pretrained( - WAN22_DIFFUSERS_MODEL_ID, subfolder="text_encoder", torch_dtype=torch_dtype - ) + encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder=subfolder, torch_dtype=torch_dtype) return WanTextEncoder(dtype=torch_dtype, device=device, pretrained=encoder) diff --git a/tests/policies/fastwam/test_fastwam_policy.py b/tests/policies/fastwam/test_fastwam_policy.py index bbffdf973..2d132e985 100644 --- a/tests/policies/fastwam/test_fastwam_policy.py +++ b/tests/policies/fastwam/test_fastwam_policy.py @@ -73,8 +73,9 @@ def test_config_validates_features_model_ids_and_saved_auto_route(tmp_path): assert cfg.robot_state_feature.shape == (8,) with pytest.raises(ValueError, match="image feature"): FastWAMConfig(input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,))}) - with pytest.raises(ValueError, match="tokenizer_model_id"): - FastWAMConfig(tokenizer_model_id="somebody/other-tokenizer") + assert FastWAMConfig(tokenizer_model_id="somebody/other-tokenizer").tokenizer_model_id == ( + "somebody/other-tokenizer" + ) def test_preprocessor_passes_images_through_and_postprocessor_toggles_actions(tmp_path):