From 8283e1badeb7b690726976abe0fe6239522ef888 Mon Sep 17 00:00:00 2001 From: leejet Date: Sun, 11 Jan 2026 17:35:01 +0800 Subject: [PATCH] enable flash attn by default --- README.md | 1 + conditioner.hpp | 39 ++++++++++++++++++++++++++++++++++++++ diffusion_model.hpp | 14 +++++++------- docs/flux2.md | 2 +- docs/ovis_image.md | 2 +- docs/performance.md | 19 ------------------- docs/qwen_image.md | 2 +- docs/qwen_image_edit.md | 6 +++--- docs/troubleshooting.md | 3 +++ docs/z_image.md | 2 +- examples/cli/README.md | 2 +- examples/common/common.hpp | 12 ++++++------ examples/server/README.md | 2 +- ggml_extend.hpp | 2 +- stable-diffusion.cpp | 34 +++++++++++++++++++++------------ stable-diffusion.h | 2 +- vae.hpp | 2 +- wan.hpp | 4 ++-- 18 files changed, 92 insertions(+), 58 deletions(-) create mode 100644 docs/troubleshooting.md diff --git a/README.md b/README.md index 4e536880d..cba808683 100644 --- a/README.md +++ b/README.md @@ -144,6 +144,7 @@ If you want to improve performance or reduce VRAM/RAM usage, please refer to [pe - [Docker](./docs/docker.md) - [Quantization and GGUF](./docs/quantization_and_gguf.md) - [Inference acceleration via caching](./docs/caching.md) +- [Troubleshooting](./docs/troubleshooting.md) ## Bindings diff --git a/conditioner.hpp b/conditioner.hpp index b6d5646a7..41b2e3489 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -34,6 +34,7 @@ struct Conditioner { virtual void free_params_buffer() = 0; virtual void get_param_tensors(std::map& tensors) = 0; virtual size_t get_params_buffer_size() = 0; + virtual void set_flash_attention_enabled(bool enabled) = 0; virtual void set_weight_adapter(const std::shared_ptr& adapter) {} virtual std::tuple> get_learned_condition_with_trigger(ggml_context* work_ctx, int n_threads, @@ -115,6 +116,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { return buffer_size; } + void set_flash_attention_enabled(bool enabled) override { + text_model->set_flash_attention_enabled(enabled); + if (sd_version_is_sdxl(version)) { + text_model2->set_flash_attention_enabled(enabled); + } + } + void set_weight_adapter(const std::shared_ptr& adapter) override { text_model->set_weight_adapter(adapter); if (sd_version_is_sdxl(version)) { @@ -783,6 +791,18 @@ struct SD3CLIPEmbedder : public Conditioner { return buffer_size; } + void set_flash_attention_enabled(bool enabled) override { + if (clip_l) { + clip_l->set_flash_attention_enabled(enabled); + } + if (clip_g) { + clip_g->set_flash_attention_enabled(enabled); + } + if (t5) { + t5->set_flash_attention_enabled(enabled); + } + } + void set_weight_adapter(const std::shared_ptr& adapter) override { if (clip_l) { clip_l->set_weight_adapter(adapter); @@ -1191,6 +1211,15 @@ struct FluxCLIPEmbedder : public Conditioner { return buffer_size; } + void set_flash_attention_enabled(bool enabled) override { + if (clip_l) { + clip_l->set_flash_attention_enabled(enabled); + } + if (t5) { + t5->set_flash_attention_enabled(enabled); + } + } + void set_weight_adapter(const std::shared_ptr& adapter) { if (clip_l) { clip_l->set_weight_adapter(adapter); @@ -1440,6 +1469,12 @@ struct T5CLIPEmbedder : public Conditioner { return buffer_size; } + void set_flash_attention_enabled(bool enabled) override { + if (t5) { + t5->set_flash_attention_enabled(enabled); + } + } + void set_weight_adapter(const std::shared_ptr& adapter) override { if (t5) { t5->set_weight_adapter(adapter); @@ -1650,6 +1685,10 @@ struct LLMEmbedder : public Conditioner { return buffer_size; } + void set_flash_attention_enabled(bool enabled) override { + llm->set_flash_attention_enabled(enabled); + } + void set_weight_adapter(const std::shared_ptr& adapter) override { if (llm) { llm->set_weight_adapter(adapter); diff --git a/diffusion_model.hpp b/diffusion_model.hpp index 06cbecc28..3293ba9b7 100644 --- a/diffusion_model.hpp +++ b/diffusion_model.hpp @@ -38,7 +38,7 @@ struct DiffusionModel { virtual size_t get_params_buffer_size() = 0; virtual void set_weight_adapter(const std::shared_ptr& adapter){}; virtual int64_t get_adm_in_channels() = 0; - virtual void set_flash_attn_enabled(bool enabled) = 0; + virtual void set_flash_attention_enabled(bool enabled) = 0; virtual void set_circular_axes(bool circular_x, bool circular_y) = 0; }; @@ -84,7 +84,7 @@ struct UNetModel : public DiffusionModel { return unet.unet.adm_in_channels; } - void set_flash_attn_enabled(bool enabled) { + void set_flash_attention_enabled(bool enabled) { unet.set_flash_attention_enabled(enabled); } @@ -149,7 +149,7 @@ struct MMDiTModel : public DiffusionModel { return 768 + 1280; } - void set_flash_attn_enabled(bool enabled) { + void set_flash_attention_enabled(bool enabled) { mmdit.set_flash_attention_enabled(enabled); } @@ -215,7 +215,7 @@ struct FluxModel : public DiffusionModel { return 768; } - void set_flash_attn_enabled(bool enabled) { + void set_flash_attention_enabled(bool enabled) { flux.set_flash_attention_enabled(enabled); } @@ -286,7 +286,7 @@ struct WanModel : public DiffusionModel { return 768; } - void set_flash_attn_enabled(bool enabled) { + void set_flash_attention_enabled(bool enabled) { wan.set_flash_attention_enabled(enabled); } @@ -357,7 +357,7 @@ struct QwenImageModel : public DiffusionModel { return 768; } - void set_flash_attn_enabled(bool enabled) { + void set_flash_attention_enabled(bool enabled) { qwen_image.set_flash_attention_enabled(enabled); } @@ -424,7 +424,7 @@ struct ZImageModel : public DiffusionModel { return 768; } - void set_flash_attn_enabled(bool enabled) { + void set_flash_attention_enabled(bool enabled) { z_image.set_flash_attention_enabled(enabled); } diff --git a/docs/flux2.md b/docs/flux2.md index 0c2c6d2b7..111cba8e1 100644 --- a/docs/flux2.md +++ b/docs/flux2.md @@ -12,7 +12,7 @@ ## Examples ``` -.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux2-dev-Q4_K_S.gguf --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\Mistral-Small-3.2-24B-Instruct-2506-Q4_K_M.gguf -r .\kontext_input.png -p "change 'flux.cpp' to 'flux2-dev.cpp'" --cfg-scale 1.0 --sampling-method euler -v --diffusion-fa --offload-to-cpu +.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux2-dev-Q4_K_S.gguf --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\Mistral-Small-3.2-24B-Instruct-2506-Q4_K_M.gguf -r .\kontext_input.png -p "change 'flux.cpp' to 'flux2-dev.cpp'" --cfg-scale 1.0 --sampling-method euler -v --offload-to-cpu ``` flux2 example diff --git a/docs/ovis_image.md b/docs/ovis_image.md index 5bd3e8ea3..711a7ddcf 100644 --- a/docs/ovis_image.md +++ b/docs/ovis_image.md @@ -13,7 +13,7 @@ ## Examples ``` -.\bin\Release\sd-cli.exe --diffusion-model ovis_image-Q4_0.gguf --vae ..\..\ComfyUI\models\vae\ae.sft --llm ..\..\ComfyUI\models\text_encoders\ovis_2.5.safetensors -p "a lovely cat" --cfg-scale 5.0 -v --offload-to-cpu --diffusion-fa +.\bin\Release\sd-cli.exe --diffusion-model ovis_image-Q4_0.gguf --vae ..\..\ComfyUI\models\vae\ae.sft --llm ..\..\ComfyUI\models\text_encoders\ovis_2.5.safetensors -p "a lovely cat" --cfg-scale 5.0 -v --offload-to-cpu ``` ovis image example \ No newline at end of file diff --git a/docs/performance.md b/docs/performance.md index 0c4735e0b..cf1d0c3d2 100644 --- a/docs/performance.md +++ b/docs/performance.md @@ -1,22 +1,3 @@ -## Use Flash Attention to save memory and improve speed. - -Enabling flash attention for the diffusion model reduces memory usage by varying amounts of MB. -eg.: - - flux 768x768 ~600mb - - SD2 768x768 ~1400mb - -For most backends, it slows things down, but for cuda it generally speeds it up too. -At the moment, it is only supported for some models and some backends (like cpu, cuda/rocm, metal). - -Run by adding `--diffusion-fa` to the arguments and watch for: -``` -[INFO ] stable-diffusion.cpp:312 - Using flash attention in the diffusion model -``` -and the compute buffer shrink in the debug log: -``` -[DEBUG] ggml_extend.hpp:1004 - flux compute buffer size: 650.00 MB(VRAM) -``` - ## Offload weights to the CPU to save VRAM without reducing generation speed. Using `--offload-to-cpu` allows you to offload weights to the CPU, saving VRAM without reducing generation speed. diff --git a/docs/qwen_image.md b/docs/qwen_image.md index f12421f47..3df4978d6 100644 --- a/docs/qwen_image.md +++ b/docs/qwen_image.md @@ -14,7 +14,7 @@ ## Examples ``` -.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\qwen-image-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --llm ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf -p '一个穿着"QWEN"标志的T恤的中国美女正拿着黑色的马克笔面相镜头微笑。她身后的玻璃板上手写体写着 “一、Qwen-Image的技术路线: 探索视觉生成基础模型的极限,开创理解与生成一体化的未来。二、Qwen-Image的模型特色:1、复杂文字渲染。支持中英渲染、自动布局; 2、精准图像编辑。支持文字编辑、物体增减、风格变换。三、Qwen-Image的未来愿景:赋能专业内容创作、助力生成式AI发展。”' --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu -H 1024 -W 1024 --diffusion-fa --flow-shift 3 +.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\qwen-image-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --llm ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf -p '一个穿着"QWEN"标志的T恤的中国美女正拿着黑色的马克笔面相镜头微笑。她身后的玻璃板上手写体写着 “一、Qwen-Image的技术路线: 探索视觉生成基础模型的极限,开创理解与生成一体化的未来。二、Qwen-Image的模型特色:1、复杂文字渲染。支持中英渲染、自动布局; 2、精准图像编辑。支持文字编辑、物体增减、风格变换。三、Qwen-Image的未来愿景:赋能专业内容创作、助力生成式AI发展。”' --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu -H 1024 -W 1024 --flow-shift 3 ``` qwen example diff --git a/docs/qwen_image_edit.md b/docs/qwen_image_edit.md index 4a8b01728..16e470882 100644 --- a/docs/qwen_image_edit.md +++ b/docs/qwen_image_edit.md @@ -23,7 +23,7 @@ ### Qwen Image Edit ``` -.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\Qwen_Image_Edit-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_2.5_vl_7b.safetensors --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --diffusion-fa --flow-shift 3 -r ..\assets\flux\flux1-dev-q8_0.png -p "change 'flux.cpp' to 'edit.cpp'" --seed 1118877715456453 +.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\Qwen_Image_Edit-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_2.5_vl_7b.safetensors --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --flow-shift 3 -r ..\assets\flux\flux1-dev-q8_0.png -p "change 'flux.cpp' to 'edit.cpp'" --seed 1118877715456453 ``` qwen_image_edit @@ -32,7 +32,7 @@ ### Qwen Image Edit 2509 ``` -.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\Qwen-Image-Edit-2509-Q4_K_S.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --llm ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf --llm_vision ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct.mmproj-Q8_0.gguf --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --diffusion-fa --flow-shift 3 -r ..\assets\flux\flux1-dev-q8_0.png -p "change 'flux.cpp' to 'Qwen Image Edit 2509'" +.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\Qwen-Image-Edit-2509-Q4_K_S.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --llm ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf --llm_vision ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct.mmproj-Q8_0.gguf --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --flow-shift 3 -r ..\assets\flux\flux1-dev-q8_0.png -p "change 'flux.cpp' to 'Qwen Image Edit 2509'" ``` qwen_image_edit_2509 @@ -42,7 +42,7 @@ To use the new Qwen Image Edit 2511 mode, the `--qwen-image-zero-cond-t` flag must be enabled; otherwise, image editing quality will degrade significantly. ``` -.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\qwen-image-edit-2511-Q4_K_M.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_2.5_vl_7b.safetensors --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --diffusion-fa --flow-shift 3 -r ..\assets\flux\flux1-dev-q8_0.png -p "change 'flux.cpp' to 'edit.cpp'" --qwen-image-zero-cond-t +.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\qwen-image-edit-2511-Q4_K_M.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_2.5_vl_7b.safetensors --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --flow-shift 3 -r ..\assets\flux\flux1-dev-q8_0.png -p "change 'flux.cpp' to 'edit.cpp'" --qwen-image-zero-cond-t ``` qwen_image_edit_2509 \ No newline at end of file diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md new file mode 100644 index 000000000..51ff33f8d --- /dev/null +++ b/docs/troubleshooting.md @@ -0,0 +1,3 @@ +## Try `--disable-fa` + +By default, **stable-diffusion.cpp** uses Flash Attention to improve generation speed and optimize GPU memory usage. However, on some backends, Flash Attention may cause unexpected issues, such as generating completely black images. In such cases, you can try disabling Flash Attention by using `--disable-fa`. \ No newline at end of file diff --git a/docs/z_image.md b/docs/z_image.md index 122f1f205..6ba8ba974 100644 --- a/docs/z_image.md +++ b/docs/z_image.md @@ -16,7 +16,7 @@ You can run Z-Image with stable-diffusion.cpp on GPUs with 4GB of VRAM — or ev ## Examples ``` -.\bin\Release\sd-cli.exe --diffusion-model z_image_turbo-Q3_K.gguf --vae ..\..\ComfyUI\models\vae\ae.sft --llm ..\..\ComfyUI\models\text_encoders\Qwen3-4B-Instruct-2507-Q4_K_M.gguf -p "A cinematic, melancholic photograph of a solitary hooded figure walking through a sprawling, rain-slicked metropolis at night. The city lights are a chaotic blur of neon orange and cool blue, reflecting on the wet asphalt. The scene evokes a sense of being a single component in a vast machine. Superimposed over the image in a sleek, modern, slightly glitched font is the philosophical quote: 'THE CITY IS A CIRCUIT BOARD, AND I AM A BROKEN TRANSISTOR.' -- moody, atmospheric, profound, dark academic" --cfg-scale 1.0 -v --offload-to-cpu --diffusion-fa -H 1024 -W 512 +.\bin\Release\sd-cli.exe --diffusion-model z_image_turbo-Q3_K.gguf --vae ..\..\ComfyUI\models\vae\ae.sft --llm ..\..\ComfyUI\models\text_encoders\Qwen3-4B-Instruct-2507-Q4_K_M.gguf -p "A cinematic, melancholic photograph of a solitary hooded figure walking through a sprawling, rain-slicked metropolis at night. The city lights are a chaotic blur of neon orange and cool blue, reflecting on the wet asphalt. The scene evokes a sense of being a single component in a vast machine. Superimposed over the image in a sleek, modern, slightly glitched font is the philosophical quote: 'THE CITY IS A CIRCUIT BOARD, AND I AM A BROKEN TRANSISTOR.' -- moody, atmospheric, profound, dark academic" --cfg-scale 1.0 -v --offload-to-cpu -H 1024 -W 512 ``` z-image example diff --git a/examples/cli/README.md b/examples/cli/README.md index 84dd5c716..0ec39bd95 100644 --- a/examples/cli/README.md +++ b/examples/cli/README.md @@ -52,7 +52,7 @@ Context Options: --control-net-cpu keep controlnet in cpu (for low vram) --clip-on-cpu keep clip in cpu (for low vram) --vae-on-cpu keep vae in cpu (for low vram) - --diffusion-fa use flash attention in the diffusion model + --disable-fa disable flash attention --diffusion-conv-direct use ggml_conv2d_direct in the diffusion model --vae-conv-direct use ggml_conv2d_direct in the vae model --circular enable circular padding for convolutions diff --git a/examples/common/common.hpp b/examples/common/common.hpp index 82328bccb..8880c7689 100644 --- a/examples/common/common.hpp +++ b/examples/common/common.hpp @@ -457,7 +457,7 @@ struct SDContextParams { bool control_net_cpu = false; bool clip_on_cpu = false; bool vae_on_cpu = false; - bool diffusion_flash_attn = false; + bool flash_attn = true; bool diffusion_conv_direct = false; bool vae_conv_direct = false; @@ -616,9 +616,9 @@ struct SDContextParams { "keep vae in cpu (for low vram)", true, &vae_on_cpu}, {"", - "--diffusion-fa", - "use flash attention in the diffusion model", - true, &diffusion_flash_attn}, + "--disable-fa", + "disable flash attention", + false, &flash_attn}, {"", "--diffusion-conv-direct", "use ggml_conv2d_direct in the diffusion model", @@ -904,7 +904,7 @@ struct SDContextParams { << " control_net_cpu: " << (control_net_cpu ? "true" : "false") << ",\n" << " clip_on_cpu: " << (clip_on_cpu ? "true" : "false") << ",\n" << " vae_on_cpu: " << (vae_on_cpu ? "true" : "false") << ",\n" - << " diffusion_flash_attn: " << (diffusion_flash_attn ? "true" : "false") << ",\n" + << " flash_attn: " << (flash_attn ? "true" : "false") << ",\n" << " diffusion_conv_direct: " << (diffusion_conv_direct ? "true" : "false") << ",\n" << " vae_conv_direct: " << (vae_conv_direct ? "true" : "false") << ",\n" << " circular: " << (circular ? "true" : "false") << ",\n" @@ -968,7 +968,7 @@ struct SDContextParams { clip_on_cpu, control_net_cpu, vae_on_cpu, - diffusion_flash_attn, + flash_attn, taesd_preview, diffusion_conv_direct, vae_conv_direct, diff --git a/examples/server/README.md b/examples/server/README.md index 7e6681570..99328361e 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -44,7 +44,7 @@ Context Options: --clip-on-cpu keep clip in cpu (for low vram) --vae-on-cpu keep vae in cpu (for low vram) --mmap whether to memory-map model - --diffusion-fa use flash attention in the diffusion model + --disable-fa disable flash attention --diffusion-conv-direct use ggml_conv2d_direct in the diffusion model --vae-conv-direct use ggml_conv2d_direct in the vae model --circular enable circular padding for convolutions diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 6f498ffa7..b84ff4427 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -2594,7 +2594,7 @@ class MultiheadAttention : public GGMLBlock { v = v_proj->forward(ctx, x); } - x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, mask); // [N, n_token, embed_dim] + x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, mask, false, false); // [N, n_token, embed_dim] x = out_proj->forward(ctx, x); // [N, n_token, embed_dim] return x; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 60bcba4d3..ba80d5aa1 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -435,7 +435,7 @@ class StableDiffusionGGML { } } if (is_chroma) { - if (sd_ctx_params->diffusion_flash_attn && sd_ctx_params->chroma_use_dit_mask) { + if (sd_ctx_params->flash_attn && sd_ctx_params->chroma_use_dit_mask) { LOG_WARN( "!!!It looks like you are using Chroma with flash attention. " "This is currently unsupported. " @@ -561,14 +561,6 @@ class StableDiffusionGGML { } } - if (sd_ctx_params->diffusion_flash_attn) { - LOG_INFO("Using flash attention in the diffusion model"); - diffusion_model->set_flash_attn_enabled(true); - if (high_noise_diffusion_model) { - high_noise_diffusion_model->set_flash_attn_enabled(true); - } - } - cond_stage_model->alloc_params_buffer(); cond_stage_model->get_param_tensors(tensors); @@ -712,6 +704,24 @@ class StableDiffusionGGML { pmid_model->get_param_tensors(tensors, "pmid"); } + if (sd_ctx_params->flash_attn) { + LOG_INFO("Using flash attention"); + diffusion_model->set_flash_attention_enabled(true); + if (high_noise_diffusion_model) { + high_noise_diffusion_model->set_flash_attention_enabled(true); + } + cond_stage_model->set_flash_attention_enabled(true); + if (clip_vision) { + clip_vision->set_flash_attention_enabled(true); + } + if (first_stage_model) { + first_stage_model->set_flash_attention_enabled(true); + } + if (tae_first_stage) { + tae_first_stage->set_flash_attention_enabled(true); + } + } + diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y); if (high_noise_diffusion_model) { high_noise_diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y); @@ -2884,7 +2894,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) { sd_ctx_params->keep_clip_on_cpu = false; sd_ctx_params->keep_control_net_on_cpu = false; sd_ctx_params->keep_vae_on_cpu = false; - sd_ctx_params->diffusion_flash_attn = false; + sd_ctx_params->flash_attn = false; sd_ctx_params->circular_x = false; sd_ctx_params->circular_y = false; sd_ctx_params->chroma_use_dit_mask = true; @@ -2925,7 +2935,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { "keep_clip_on_cpu: %s\n" "keep_control_net_on_cpu: %s\n" "keep_vae_on_cpu: %s\n" - "diffusion_flash_attn: %s\n" + "flash_attn: %s\n" "circular_x: %s\n" "circular_y: %s\n" "chroma_use_dit_mask: %s\n" @@ -2956,7 +2966,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { BOOL_STR(sd_ctx_params->keep_clip_on_cpu), BOOL_STR(sd_ctx_params->keep_control_net_on_cpu), BOOL_STR(sd_ctx_params->keep_vae_on_cpu), - BOOL_STR(sd_ctx_params->diffusion_flash_attn), + BOOL_STR(sd_ctx_params->flash_attn), BOOL_STR(sd_ctx_params->circular_x), BOOL_STR(sd_ctx_params->circular_y), BOOL_STR(sd_ctx_params->chroma_use_dit_mask), diff --git a/stable-diffusion.h b/stable-diffusion.h index 8f040d2bd..e0fb8ae04 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -186,7 +186,7 @@ typedef struct { bool keep_clip_on_cpu; bool keep_control_net_on_cpu; bool keep_vae_on_cpu; - bool diffusion_flash_attn; + bool flash_attn; bool tae_preview_only; bool diffusion_conv_direct; bool vae_conv_direct; diff --git a/vae.hpp b/vae.hpp index 232500295..b69282ae3 100644 --- a/vae.hpp +++ b/vae.hpp @@ -141,7 +141,7 @@ class AttnBlock : public UnaryBlock { v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n); // [N, h * w, in_channels] } - h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, true, false); + h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, true, ctx->flash_attn_enabled); if (use_linear) { h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels] diff --git a/wan.hpp b/wan.hpp index 3ade14bfe..d6c11bff9 100644 --- a/wan.hpp +++ b/wan.hpp @@ -572,8 +572,8 @@ namespace WAN { auto v = qkv_vec[2]; v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [t, c, h * w] - v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [t, h * w, c] - x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, true, false); // [t, h * w, c] + v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [t, h * w, c] + x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, true, ctx->flash_attn_enabled); // [t, h * w, c] x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [t, c, h * w] x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n); // [t, c, h, w]