[MAX] Add AutoencoderKL VAE implementation for Flux.2 pipeline #5889
byungchul-sqzb wants to merge 15 commits into modular:main
Conversation
katelyncaldwell
left a comment
A few comments to address, but then good to merge on my end!
    @@ -0,0 +1,325 @@
    # ===----------------------------------------------------------------------=== #
    # Copyright (c) 2025, Modular Inc. All rights reserved.
    ...
        return converted_weights
    ...
    def load_model(self) -> Any:
A lot of this function seems directly copied from BaseAutoencoderModel. Would it make sense to refactor and avoid the duplication?
I’ve now refactored it to eliminate the redundancy as suggested:
- Weight Handling: Moved weight dtype check and conversion logic into the BaseAutoencoderModel so it’s handled centrally.
- Encoder Logic: Integrated quant_conv logic into the common VAE Encoder class, controlled by a boolean flag (see the sketch below).
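For readers following the thread, a minimal sketch of the shape of that refactor. BaseAutoencoderModel and quant_conv come from the discussion above; everything else (method and flag names, layer sizes) is illustrative, not the actual MAX code:

```python
import torch
from torch import nn

class BaseAutoencoderModel(nn.Module):
    """Shared base: weight dtype handling now lives here, not in subclasses."""

    def convert_weights_to_target_dtype(self, state_dict, target_dtype):
        # Centralized cast; every autoencoder variant reuses this loop
        # instead of carrying its own copy.
        return {
            name: w.to(target_dtype) if w.is_floating_point() else w
            for name, w in state_dict.items()
        }

class Encoder(nn.Module):
    """Common VAE encoder; the quant_conv step is gated by a boolean flag."""

    def __init__(self, latent_channels: int, use_quant_conv: bool = True):
        super().__init__()
        # Build the 1x1 quantization conv only for variants that need it.
        self.quant_conv = (
            nn.Conv2d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
            if use_quant_conv
            else None
        )

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        if self.quant_conv is not None:
            h = self.quant_conv(h)
        return h
```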
        return self.decoder(z, temb)
    ...
    class BatchNormStats:
Should this be a @dataclass?
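If it became a @dataclass, it would look roughly like this (field names and types are assumptions; the actual fields are whatever BatchNormStats holds in the PR):

```python
from dataclasses import dataclass
import torch

@dataclass
class BatchNormStats:
    # Running statistics carried alongside a batch-norm layer;
    # @dataclass generates __init__/__repr__/__eq__ for free.
    running_mean: torch.Tensor
    running_var: torch.Tensor
    eps: float = 1e-5
```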
        autoencoder_class=AutoencoderKLFlux2,
    )
    ...
    def convert_weights_to_target_dtype(
Consider adding more dtype checks to ensure we're only casting between float types (e.g. if we want to support a quantized dtype, this weight adapter should not be a source of subtle dtype conversion bugs)
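A sketch of the kind of guard being suggested, written against torch dtypes (the real adapter may work with MAX's own dtype representation):

```python
import torch

def convert_weights_to_target_dtype(state_dict, target_dtype: torch.dtype):
    # Refuse non-float targets up front, and only cast float -> float;
    # ints, bools, or a future quantized dtype pass through untouched,
    # so the adapter cannot silently corrupt them.
    if not target_dtype.is_floating_point:
        raise ValueError(f"target dtype must be floating point, got {target_dtype}")
    return {
        name: w.to(target_dtype) if w.is_floating_point() else w
        for name, w in state_dict.items()
    }
```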
    Args:
        hidden_states: Input tensor of shape [N, C, H, W].
        *args: Additional positional arguments (ignored, kept for compatibility).
I would prefer to avoid this pattern if possible. I am assuming this is for compatibility with diffusers? Can we remove it or will that lead to challenges?
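For concreteness, the two signatures being weighed (a sketch, not the actual module):

```python
import torch
from torch import nn

class LayerWithCatchAll(nn.Module):
    def forward(self, hidden_states: torch.Tensor, *args) -> torch.Tensor:
        # diffusers-style: extra positional args are accepted and ignored,
        # so call sites written against diffusers keep working.
        return hidden_states

class LayerStrict(nn.Module):
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # strict: an unexpected extra argument raises TypeError at the
        # call site, surfacing the mismatch instead of hiding it.
        return hidden_states
```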
Force-pushed from 0950704 to 999f8dc
katelyncaldwell
left a comment
looks great! I am going to merge 😄
!sync
Overview
This PR adds the VAE (Variational Autoencoder) model stack required for the Flux2 pipeline. It introduces a new AutoencoderKLFlux2 implementation with Flux2-specific configuration, encoder/decoder components, and supporting layers for image-to-latent encoding and latent-to-image decoding.
It also implements the VAE encoder functionality used for image-conditioning logic in the Flux2 pipeline.
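Conceptually, the encode/decode round trip this stack provides looks like the following. This is a PyTorch-flavored illustration of the KL-autoencoder flow with stand-in single-layer encoder/decoder modules and an assumed latent channel count; it is not the MAX API:

```python
import torch
from torch import nn

# Stand-ins for the real (much deeper) Flux2 VAE encoder/decoder.
latent_channels = 16  # assumed for illustration
encoder = nn.Conv2d(3, 2 * latent_channels, kernel_size=8, stride=8)
decoder = nn.ConvTranspose2d(latent_channels, 3, kernel_size=8, stride=8)

x = torch.randn(1, 3, 256, 256)  # image batch [N, C, H, W]

# Encode: predict a diagonal Gaussian over latents, then sample from it
# (the reparameterization trick) -- this is the image-conditioning path.
mean, logvar = encoder(x).chunk(2, dim=1)
z = mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)

# Decode: map latents back to image space.
x_rec = decoder(z)
print(z.shape, x_rec.shape)  # [1, 16, 32, 32] and [1, 3, 256, 256]
```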
What's included
- New AutoencoderKLFlux2 architecture (autoencoder_kl_flux2.py)
- Enhanced VAE components (vae.py)
- New downsampling layer (layers/downsampling.py): Downsample2D module for spatial resolution reduction (see the sketch below)

NOTES
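For illustration, a Downsample2D block in this family of VAEs is typically just a strided convolution. A minimal sketch (constructor arguments and padding are assumptions, not taken from this PR):

```python
import torch
from torch import nn

class Downsample2D(nn.Module):
    """Halve spatial resolution with a stride-2 convolution (sketch)."""

    def __init__(self, channels: int):
        super().__init__()
        # 3x3 conv, stride 2: [N, C, H, W] -> [N, C, H/2, W/2]
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.conv(hidden_states)

# A 64x64 feature map becomes 32x32.
x = torch.randn(1, 128, 64, 64)
print(Downsample2D(128)(x).shape)  # torch.Size([1, 128, 32, 32])
```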