
[MAX] Add UMT5 Encoder Architecture #5936

Draft
jglee-sqbits wants to merge 1 commit into modular:main from SqueezeBits:add/wan-pipeline/umt5-encoder

Conversation

@jglee-sqbits

Summary

This PR adds a new umt5 architecture under max/python/max/pipelines/architectures/umt5 for Wan text encoding use cases, aligned with Hugging Face UMT5 behavior and Wan2.2 text-encoder defaults.

It includes:

  • A MAX UMT5 encoder implementation.
  • UMT5 config defaults matching Wan-AI/Wan2.2-T2V-A14B-Diffusers text encoder config.
  • Weight adapter utilities for shared/encoder embedding compatibility.
  • Integration tests comparing MAX outputs against Transformers UMT5 outputs.
  • A diffusers-like prompt embedding parity test path (mirroring _get_t5_prompt_embeds post-processing behavior).

Motivation

Wan pipelines rely on UMT5 text encoding and specific prompt embedding post-processing.
To support Wan-compatible behavior in MAX, we need a dedicated UMT5 architecture with validated Hugging Face parity and checks on the Wan-style embedding flow.

What Changed

New architecture: architectures/umt5

  • Added:
    • max/python/max/pipelines/architectures/umt5/__init__.py
    • max/python/max/pipelines/architectures/umt5/model.py
    • max/python/max/pipelines/architectures/umt5/model_config.py
    • max/python/max/pipelines/architectures/umt5/umt5.py
    • max/python/max/pipelines/architectures/umt5/weight_adapters.py

UMT5 config defaults (Wan-aligned)

  • model_config.py defaults were set to Wan text-encoder-compatible values, including:
    • vocab_size=256384
    • d_model=4096
    • d_ff=10240
    • num_layers=24
    • num_heads=64
    • feed_forward_proj="gated-gelu"
    • tokenizer_class="T5Tokenizer"
    • tie_word_embeddings=False
    • plus other matching fields from the Wan text encoder config (a hedged config sketch follows this list).
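
The following is a minimal, hypothetical sketch of what these Wan-aligned defaults could look like as a dataclass-style config. The field names follow the Hugging Face UMT5Config convention; the actual layout of model_config.py may differ, and d_kv here is an assumption derived from d_model / num_heads.

```python
# Hypothetical sketch; the real model_config.py may structure this differently.
from dataclasses import dataclass


@dataclass
class UMT5Config:
    vocab_size: int = 256384
    d_model: int = 4096                    # hidden size
    d_ff: int = 10240                      # gated feed-forward inner dimension
    num_layers: int = 24                   # number of encoder blocks
    num_heads: int = 64
    d_kv: int = 64                         # per-head dim (assumed d_model // num_heads)
    feed_forward_proj: str = "gated-gelu"  # gated GELU activation in the FFN
    tokenizer_class: str = "T5Tokenizer"
    tie_word_embeddings: bool = False      # encoder-only; embeddings are not tied
```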

HF compatibility and behavior

  • Implemented UMT5 encoder components (attention, FFN, layer norm, stack, encoder model).
  • Added support for the expected input signature (input_ids, attention_mask).
  • Added embedding-key adaptation (a sketch follows this list) for checkpoints that include only one of:
    • shared.weight
    • encoder.embed_tokens.weight
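
Since shared.weight and encoder.embed_tokens.weight refer to the same (tied) embedding table in UMT5 checkpoints, the adapter can simply mirror whichever key is present. A minimal sketch, assuming a flat state-dict of weights; the function name and exact handling in weight_adapters.py are illustrative:

```python
# Hypothetical sketch of the embedding-key adaptation in weight_adapters.py.
SHARED_KEY = "shared.weight"
EMBED_KEY = "encoder.embed_tokens.weight"


def adapt_embedding_keys(state_dict: dict) -> dict:
    """Mirror the embedding weight so both expected keys resolve."""
    if SHARED_KEY in state_dict and EMBED_KEY not in state_dict:
        state_dict[EMBED_KEY] = state_dict[SHARED_KEY]
    elif EMBED_KEY in state_dict and SHARED_KEY not in state_dict:
        state_dict[SHARED_KEY] = state_dict[EMBED_KEY]
    return state_dict
```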

Tests

./bazelw test --cache_test_results=no //max/tests/integration/architectures/umt5:umt5

  • Added:
    • max/tests/integration/architectures/umt5/BUILD.bazel
    • max/tests/integration/architectures/umt5/test_weight_adapters.py
    • max/tests/integration/architectures/umt5/test_encoder.py
  • test_encoder.py validates:
    • MAX UMT5 vs HF transformers UMT5 hidden-state parity.
    • NaN/Inf safety checks.
    • diffusers-like prompt embedding path parity (sketched after this list):
      • trim by attention-mask length
      • pad to max sequence length
      • repeat for num_videos_per_prompt
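
The post-processing mirrors diffusers' _get_t5_prompt_embeds. Below is a minimal sketch of the three steps the test checks, assuming torch tensors; the function and argument names are illustrative, not the actual test helpers:

```python
import torch


# Hypothetical sketch of the diffusers-like post-processing validated by
# test_encoder.py; names are illustrative.
def postprocess_prompt_embeds(
    hidden_states: torch.Tensor,   # (batch, seq, d_model) encoder output
    attention_mask: torch.Tensor,  # (batch, seq), 1 for real tokens
    max_sequence_length: int,
    num_videos_per_prompt: int,
) -> torch.Tensor:
    # 1. Trim each prompt to its attention-mask length.
    seq_lens = attention_mask.sum(dim=1)
    trimmed = [u[: int(n)] for u, n in zip(hidden_states, seq_lens)]
    # 2. Zero-pad each prompt back up to the max sequence length.
    padded = torch.stack(
        [
            torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))])
            for u in trimmed
        ]
    )
    # 3. Repeat each embedding once per requested video.
    return padded.repeat_interleave(num_videos_per_prompt, dim=0)
```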

Notes

  • The current implementation targets encoder behavior for Wan text encoding workflows.
  • Decoder/caching paths are intentionally not implemented in this change.

@jglee-sqbits requested a review from a team as a code owner February 11, 2026 04:38
@jglee-sqbits marked this pull request as draft February 11, 2026 04:39
@jglee-sqbits changed the title from "[MAX] Add UMT5 Encoder Architecture for Wan Pipeline" to "[MAX] Add UMT5 Encoder Architecture" on February 11, 2026