
[MAX] Add AutoencoderKL VAE implementation for Flux.2 pipeline #5889

Open
byungchul-sqzb wants to merge 15 commits into modular:main from SqueezeBits:add/flux2-pipeline/models-vae

Conversation

@byungchul-sqzb
Contributor

Overview

This PR adds the VAE (Variational Autoencoder) model stack required for the Flux2 pipeline. It introduces a new AutoencoderKLFlux2 implementation with Flux2-specific configuration, encoder/decoder components, and supporting layers for image-to-latent encoding and latent-to-image decoding.

In particular, it implements the VAE encoder functionality needed for image-conditioned generation in the Flux2 pipeline.

What's included

  • New AutoencoderKLFlux2 architecture (autoencoder_kl_flux2.py):

    • Support for BatchNorm statistics for latent patchification
    • Image encoding to latents and decoding from latents
  • Enhanced VAE components (vae.py):

    • Extended Encoder and Decoder implementations
    • Support for Flux2's latent patchification process
    • Image conditional encoding logic
  • New downsampling layer (layers/downsampling.py):

    • Downsample2D module for spatial resolution reduction
    • Optional convolution-based downsampling
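The latent patchification mentioned above can be sketched as a pure-array reshape. This is an illustrative sketch only; the function name and the patch size of 2 are assumptions, not the PR's actual API:

```python
import numpy as np

def patchify_latents(latents: np.ndarray, patch_size: int = 2) -> np.ndarray:
    """Fold spatial patches into channels: [N, C, H, W] -> [N, C*p*p, H//p, W//p]."""
    n, c, h, w = latents.shape
    p = patch_size
    # Split H and W into (H//p, p) and (W//p, p) blocks...
    x = latents.reshape(n, c, h // p, p, w // p, p)
    # ...then move the per-patch axes next to the channel axis.
    x = x.transpose(0, 1, 3, 5, 2, 4)
    return x.reshape(n, c * p * p, h // p, w // p)
```

With a 2x2 patch, a 16-channel latent becomes a 64-channel tensor at half the spatial resolution, which is the general shape transformation Flux-style transformers expect.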

NOTES

  • This model implementation does NOT serve as a standalone executable model. It is expected to be executed within the Flux2 pipeline.
  • The VAE encoder supports image-conditioning logic for text-to-image generation workflows.
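As a rough illustration of what the Downsample2D module does, here is a self-contained sketch. The real module presumably uses a learned stride-2 convolution for its conv path; this stand-in substitutes average pooling so the sketch runs without weights, and the class and flag names are assumptions:

```python
import numpy as np

class Downsample2D:
    """Sketch of a 2x spatial downsampler with an optional conv-style path."""

    def __init__(self, use_conv: bool = False):
        self.use_conv = use_conv

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # x: [N, C, H, W] with even H and W.
        if self.use_conv:
            # Placeholder for a learned stride-2 convolution; approximated
            # here by 2x2 average pooling to keep the sketch self-contained.
            n, c, h, w = x.shape
            return x.reshape(n, c, h // 2, 2, w // 2, 2).mean(axis=(3, 5))
        # Non-conv path: plain stride-2 subsampling.
        return x[:, :, ::2, ::2]
```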

@byungchul-sqzb byungchul-sqzb marked this pull request as ready for review February 3, 2026 07:13
Contributor

@katelyncaldwell katelyncaldwell left a comment


A few comments to address, but then good to merge on my end!

@@ -0,0 +1,325 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
Contributor


2026 😄

Contributor Author


Done! Commit


return converted_weights

def load_model(self) -> Any:
Contributor


A lot of this function seems directly copied from BaseAutoencoderModel. Would it make sense to refactor and avoid the duplication?

Contributor Author


I’ve now refactored it to eliminate the redundancy as suggested:

  • Weight Handling: Moved weight dtype check and conversion logic into the BaseAutoencoderModel so it’s handled centrally.
  • Encoder Logic: Integrated quant_conv logic into the common VAE Encoder class, controlled by a boolean flag.
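The flag-controlled pattern described in the second bullet might look roughly like this sketch. The class and method names are placeholders standing in for the PR's actual code, with the layer internals stubbed out:

```python
class Encoder:
    """Sketch of a shared VAE encoder with an optional quant_conv stage."""

    def __init__(self, use_quant_conv: bool = False):
        self.use_quant_conv = use_quant_conv
        self.quant_conv_called = False  # instrumentation for this sketch only

    def __call__(self, x):
        h = self.encode_blocks(x)
        # The Flux2-specific variant skips this projection; the classic
        # AutoencoderKL-style variant enables it via the constructor flag.
        if self.use_quant_conv:
            h = self.quant_conv(h)
        return h

    def encode_blocks(self, x):
        # Placeholder for the shared conv/attention encoder stack.
        return x

    def quant_conv(self, h):
        # Placeholder for the optional 1x1 projection some VAEs apply.
        self.quant_conv_called = True
        return h
```

Pushing the branch into a constructor flag keeps a single Encoder class serving both model families instead of two near-identical subclasses.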

return self.decoder(z, temb)


class BatchNormStats:
Contributor


Should this be a @dataclass?

Contributor Author


Done! Commit
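The @dataclass form the review settles on might look like this; the field names are assumptions, since the PR's actual statistics fields aren't shown here:

```python
from dataclasses import dataclass

@dataclass
class BatchNormStats:
    """Running statistics used to normalize patchified latents (sketch)."""

    mean: float
    var: float
    eps: float = 1e-6

    def normalize(self, x: float) -> float:
        # Standard batch-norm style normalization with numerical epsilon.
        return (x - self.mean) / (self.var + self.eps) ** 0.5
```

A dataclass buys an auto-generated `__init__`, `__repr__`, and `__eq__` for what is essentially a plain statistics container.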

autoencoder_class=AutoencoderKLFlux2,
)

def convert_weights_to_target_dtype(
Contributor


Consider adding more dtype checks to ensure we're only casting between float types (e.g. if we want to support a quantized dtype, this weight adapter should not be a source of subtle dtype conversion bugs)

Contributor Author

@byungchul-sqzb byungchul-sqzb Feb 5, 2026


Done! Commit
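A float-only cast guard along the lines the reviewer suggests could look like this sketch (the function signature and variable names are assumptions, not the PR's actual implementation):

```python
import numpy as np

# Float dtypes between which casting is considered safe.
_FLOAT_TYPES = (np.float16, np.float32, np.float64)

def convert_weights_to_target_dtype(weights: dict, target_dtype) -> dict:
    """Cast float weights to target_dtype; leave non-float weights untouched."""
    target_is_float = np.dtype(target_dtype).type in _FLOAT_TYPES
    converted = {}
    for name, w in weights.items():
        if target_is_float and w.dtype.type in _FLOAT_TYPES:
            converted[name] = w.astype(target_dtype)
        else:
            # e.g. quantized int8 weights pass through unchanged, so the
            # adapter can never silently dequantize or re-quantize them.
            converted[name] = w
    return converted
```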


Args:
hidden_states: Input tensor of shape [N, C, H, W].
*args: Additional positional arguments (ignored, kept for compatibility).
Contributor


I would prefer to avoid this pattern if possible. I am assuming this is for compatibility with diffusers? Can we remove it or will that lead to challenges?

Contributor Author


Done! Commit

@byungchul-sqzb byungchul-sqzb force-pushed the add/flux2-pipeline/models-vae branch from 0950704 to 999f8dc Compare February 5, 2026 04:07
Contributor

@katelyncaldwell katelyncaldwell left a comment


looks great! I am going to merge 😄

@katelyncaldwell
Contributor

!sync

@modularbot modularbot added the imported-internally Signals that a given pull request has been imported internally. label Feb 5, 2026