mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-01 07:07:08 +00:00
make tokenizer/text-encoder model ids configurable + some nits
This commit is contained in:
@@ -28,7 +28,9 @@ from lerobot.optim import AdamWConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
WAN22_MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B"
|
||||
WAN22_DIFFUSERS_MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
|
||||
FASTWAM_BASE_MODEL_ID = "lerobot/fastwam_base"
|
||||
WAN_T5_TOKENIZER_ID = "google/umt5-xxl"
|
||||
|
||||
|
||||
_FASTWAM_VIDEO_BASE_COMPAT_KEYS = (
|
||||
@@ -99,8 +101,11 @@ def _coerce_enum(enum_cls: type, value: Any) -> Any:
|
||||
return value
|
||||
try:
|
||||
return enum_cls(value)
|
||||
except (TypeError, ValueError):
|
||||
return getattr(enum_cls, str(value), value)
|
||||
except (TypeError, ValueError) as exc:
|
||||
member = getattr(enum_cls, str(value), None)
|
||||
if member is None:
|
||||
raise ValueError(f"Cannot coerce {value!r} into {enum_cls.__name__}.") from exc
|
||||
return member
|
||||
|
||||
|
||||
def _coerce_policy_features(features: dict[str, Any] | None) -> dict[str, PolicyFeature] | None:
|
||||
@@ -184,7 +189,8 @@ class FastWAMConfig(PreTrainedConfig):
|
||||
image_size: tuple[int, int] = (224, 448)
|
||||
context_len: int = 128
|
||||
model_id: str = WAN22_MODEL_ID
|
||||
tokenizer_model_id: str = WAN22_MODEL_ID
|
||||
tokenizer_model_id: str = WAN_T5_TOKENIZER_ID
|
||||
text_encoder_model_id: str = WAN22_DIFFUSERS_MODEL_ID
|
||||
base_model_id: str | None = FASTWAM_BASE_MODEL_ID
|
||||
tokenizer_max_len: int = 128
|
||||
load_text_encoder: bool = True
|
||||
@@ -229,7 +235,6 @@ class FastWAMConfig(PreTrainedConfig):
|
||||
super().__post_init__()
|
||||
self.image_size = tuple(self.image_size)
|
||||
self.model_id = _validate_wan_model_id(self.model_id, "model_id")
|
||||
self.tokenizer_model_id = _validate_wan_model_id(self.tokenizer_model_id, "tokenizer_model_id")
|
||||
self.input_features = _coerce_policy_features(self.input_features)
|
||||
self.output_features = _coerce_policy_features(self.output_features)
|
||||
self.toggle_action_dimensions = [int(dim) for dim in self.toggle_action_dimensions]
|
||||
|
||||
@@ -253,7 +253,9 @@ class FastWAMPolicy(PreTrainedPolicy):
|
||||
mot_checkpoint_mixed_attn=config.mot_checkpoint_mixed_attn,
|
||||
)
|
||||
text_encoder = (
|
||||
load_pretrained_wan_text_encoder(torch_dtype=dtype, device=device)
|
||||
load_pretrained_wan_text_encoder(
|
||||
model_id=config.text_encoder_model_id, torch_dtype=dtype, device=device
|
||||
)
|
||||
if config.load_text_encoder
|
||||
else None
|
||||
)
|
||||
@@ -263,7 +265,9 @@ class FastWAMPolicy(PreTrainedPolicy):
|
||||
mot=mot,
|
||||
vae=load_pretrained_wan_vae(torch_dtype=dtype, device=device),
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=build_wan_tokenizer(tokenizer_max_len=config.tokenizer_max_len),
|
||||
tokenizer=build_wan_tokenizer(
|
||||
model_id=config.tokenizer_model_id, tokenizer_max_len=config.tokenizer_max_len
|
||||
),
|
||||
text_dim=int(config.video_dit_config["text_dim"]),
|
||||
proprio_dim=config.proprio_dim,
|
||||
device=device,
|
||||
@@ -279,6 +283,11 @@ class FastWAMPolicy(PreTrainedPolicy):
|
||||
)
|
||||
|
||||
|
||||
def _scalar(value: Any) -> Any:
|
||||
"""Unwrap a 0-/1-element tensor (e.g. from DataLoader collation) to a Python scalar."""
|
||||
return value.item() if isinstance(value, Tensor) else value
|
||||
|
||||
|
||||
def _batch_to_infer_kwargs(batch: dict[str, Tensor], config: FastWAMConfig) -> dict[str, Any]:
|
||||
return {
|
||||
"prompt": _prompt_from_batch(batch=batch, config=config),
|
||||
@@ -288,8 +297,8 @@ def _batch_to_infer_kwargs(batch: dict[str, Tensor], config: FastWAMConfig) -> d
|
||||
"context": batch.get("context"),
|
||||
"context_mask": batch.get("context_mask"),
|
||||
"negative_prompt": batch.get("negative_prompt", config.negative_prompt),
|
||||
"text_cfg_scale": float(batch.get("text_cfg_scale", config.text_cfg_scale)),
|
||||
"num_inference_steps": int(batch.get("num_inference_steps", config.num_inference_steps)),
|
||||
"text_cfg_scale": float(_scalar(batch.get("text_cfg_scale", config.text_cfg_scale))),
|
||||
"num_inference_steps": int(_scalar(batch.get("num_inference_steps", config.num_inference_steps))),
|
||||
"sigma_shift": batch.get("sigma_shift", config.sigma_shift),
|
||||
"seed": batch.get("seed", config.inference_seed),
|
||||
"rand_device": batch.get("rand_device", config.rand_device),
|
||||
|
||||
@@ -26,6 +26,8 @@ import torch.nn.functional as functional
|
||||
from PIL import Image
|
||||
|
||||
from .wan_components import (
|
||||
WAN22_DIFFUSERS_MODEL_ID,
|
||||
WAN_T5_TOKENIZER,
|
||||
build_wan_tokenizer,
|
||||
load_pretrained_wan_text_encoder,
|
||||
load_pretrained_wan_vae,
|
||||
@@ -938,7 +940,8 @@ class FastWAM(torch.nn.Module):
|
||||
device: str = "cuda",
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
model_id: str = "Wan-AI/Wan2.2-TI2V-5B",
|
||||
tokenizer_model_id: str = "Wan-AI/Wan2.2-TI2V-5B",
|
||||
tokenizer_model_id: str = WAN_T5_TOKENIZER,
|
||||
text_encoder_model_id: str = WAN22_DIFFUSERS_MODEL_ID,
|
||||
tokenizer_max_len: int = 512,
|
||||
load_text_encoder: bool = True,
|
||||
proprio_dim: int | None = None,
|
||||
@@ -958,7 +961,6 @@ class FastWAM(torch.nn.Module):
|
||||
raise ValueError("`video_dit_config` is required for FastWAM.from_wan22_pretrained().")
|
||||
if "text_dim" not in video_dit_config:
|
||||
raise ValueError("`video_dit_config['text_dim']` is required for FastWAM.")
|
||||
del tokenizer_model_id # tokenizer is the stock UMT5 one (google/umt5-xxl)
|
||||
|
||||
# Custom MoT video DiT from the original Wan2.2 repo; frozen VAE / UMT5 from
|
||||
# the diffusers conversion. This is the offline base-creation path; the
|
||||
@@ -984,11 +986,13 @@ class FastWAM(torch.nn.Module):
|
||||
|
||||
vae = load_pretrained_wan_vae(torch_dtype=torch_dtype, device=device)
|
||||
text_encoder = (
|
||||
load_pretrained_wan_text_encoder(torch_dtype=torch_dtype, device=device)
|
||||
load_pretrained_wan_text_encoder(
|
||||
model_id=text_encoder_model_id, torch_dtype=torch_dtype, device=device
|
||||
)
|
||||
if load_text_encoder
|
||||
else None
|
||||
)
|
||||
tokenizer = build_wan_tokenizer(tokenizer_max_len=tokenizer_max_len)
|
||||
tokenizer = build_wan_tokenizer(model_id=tokenizer_model_id, tokenizer_max_len=tokenizer_max_len)
|
||||
|
||||
return cls(
|
||||
video_expert=video_expert,
|
||||
|
||||
@@ -36,10 +36,6 @@ if TYPE_CHECKING or _diffusers_available:
|
||||
else:
|
||||
AutoencoderKLWan = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .wan_adapters import WanVideoVAE38
|
||||
from .wan_video_dit import WanVideoDiT
|
||||
|
||||
from .wan_adapters import WanVideoVAE38
|
||||
from .wan_video_dit import WanVideoDiT
|
||||
|
||||
@@ -109,8 +105,8 @@ class WanTokenizer:
|
||||
return out.input_ids
|
||||
|
||||
|
||||
def build_wan_tokenizer(*, model_id: str = WAN_T5_TOKENIZER, tokenizer_max_len: int) -> WanTokenizer:
    """Create the `WanTokenizer` wrapper used to tokenize prompts.

    Args:
        model_id: Tokenizer identifier forwarded as the wrapper's ``name``;
            defaults to `WAN_T5_TOKENIZER`.
        tokenizer_max_len: Sequence length for the tokenizer; coerced to ``int``.

    Returns:
        A `WanTokenizer` configured with the given name and sequence length.
    """
    seq_len = int(tokenizer_max_len)
    return WanTokenizer(name=model_id, seq_len=seq_len)
|
||||
|
||||
|
||||
def load_pretrained_wan_vae(*, torch_dtype: torch.dtype, device: str) -> WanVideoVAE38:
|
||||
@@ -120,12 +116,20 @@ def load_pretrained_wan_vae(*, torch_dtype: torch.dtype, device: str) -> WanVide
|
||||
return WanVideoVAE38(dtype=torch_dtype, device=device, pretrained=vae)
|
||||
|
||||
|
||||
def load_pretrained_wan_text_encoder(
    *,
    model_id: str = WAN22_DIFFUSERS_MODEL_ID,
    subfolder: str | None = "text_encoder",
    torch_dtype: torch.dtype,
    device: str,
) -> WanTextEncoder:
    """Load UMT5-XXL encoder weights (defaults to the Wan2.2 diffusers repo).

    Must stay compatible with the tokenizer (see `build_wan_tokenizer`): the encoder's
    embedding table is indexed by the tokenizer's vocabulary.

    Args:
        model_id: Identifier passed to ``UMT5EncoderModel.from_pretrained``.
        subfolder: Folder inside ``model_id`` that holds the encoder weights.
            ``None`` (or ``""``) means the weights sit at the repository root.
        torch_dtype: Dtype used to load the encoder and forwarded to the wrapper.
        device: Device forwarded to the `WanTextEncoder` wrapper.

    Returns:
        A `WanTextEncoder` wrapping the loaded `UMT5EncoderModel`.
    """
    require_package("transformers", extra="fastwam")
    # `from_pretrained` documents `subfolder` as a string defaulting to "";
    # normalise None so a root-level checkpoint never passes None into path joins.
    encoder = UMT5EncoderModel.from_pretrained(
        model_id, subfolder=subfolder or "", torch_dtype=torch_dtype
    )
    return WanTextEncoder(dtype=torch_dtype, device=device, pretrained=encoder)
|
||||
|
||||
|
||||
|
||||
@@ -73,8 +73,9 @@ def test_config_validates_features_model_ids_and_saved_auto_route(tmp_path):
|
||||
assert cfg.robot_state_feature.shape == (8,)
|
||||
with pytest.raises(ValueError, match="image feature"):
|
||||
FastWAMConfig(input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,))})
|
||||
with pytest.raises(ValueError, match="tokenizer_model_id"):
|
||||
FastWAMConfig(tokenizer_model_id="somebody/other-tokenizer")
|
||||
assert FastWAMConfig(tokenizer_model_id="somebody/other-tokenizer").tokenizer_model_id == (
|
||||
"somebody/other-tokenizer"
|
||||
)
|
||||
|
||||
|
||||
def test_preprocessor_passes_images_through_and_postprocessor_toggles_actions(tmp_path):
|
||||
|
||||
Reference in New Issue
Block a user