From c205042b49b36a0d12878649984397357347e7d9 Mon Sep 17 00:00:00 2001
From: Kira Selby
Date: Sat, 30 Aug 2025 08:22:37 -0400
Subject: [PATCH] Add support for tensor_buft_overrides for more fine-grained
 control of which layers are offloaded to GPU, and add n_cpu_moe parameter

Signed-off-by: Kira Selby
---
 llama_cpp/llama.py     | 28 ++++++++++++++++++++++++++++
 llama_cpp/llama_cpp.py | 30 +++++++++++++++++++++++++++++-
 2 files changed, 57 insertions(+), 1 deletion(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 71d94ebd8..d3fb379f8 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -70,6 +70,7 @@ def __init__(
         use_mmap: bool = True,
         use_mlock: bool = False,
         kv_overrides: Optional[Dict[str, Union[bool, int, float, str]]] = None,
+        n_cpu_moe: Optional[int] = None,
         # Context Params
         seed: int = llama_cpp.LLAMA_DEFAULT_SEED,
         n_ctx: int = 512,
@@ -155,6 +156,7 @@ def __init__(
             use_mmap: Use mmap if possible.
             use_mlock: Force the system to keep the model in RAM.
             kv_overrides: Key-value overrides for the model.
+            n_cpu_moe: Number of MoE (Mixture of Experts) layers to keep on CPU. If None, MoE layers follow normal GPU/CPU distribution.
             seed: RNG seed, -1 for random
             n_ctx: Text context, 0 = from model
             n_batch: Prompt processing maximum batch size
@@ -245,6 +247,32 @@ def __init__(
         self.model_params.use_mmap = use_mmap if lora_path is None else False
         self.model_params.use_mlock = use_mlock
 
+        # Handle n_cpu_moe parameter - configure tensor buffer overrides for MoE layers on CPU
+        self._tensor_buft_overrides = None
+        if n_cpu_moe is not None:
+            if n_cpu_moe < 0:
+                raise ValueError("n_cpu_moe must be non-negative")
+            if n_cpu_moe > 0:
+                # Create tensor buffer overrides for the first n_cpu_moe MoE layers
+                override_count = n_cpu_moe + 1  # +1 for null terminator
+                self._tensor_buft_overrides = (llama_cpp.llama_model_tensor_buft_override * override_count)()
+
+                # Get CPU buffer type
+                cpu_buft = llama_cpp.ggml_backend_cpu_buffer_type()
+
+                # Configure overrides for each layer
+                for i in range(n_cpu_moe):
+                    pattern = f"blk.{i}.ffn_(up|down|gate)_exps".encode('utf-8')
+                    self._tensor_buft_overrides[i].pattern = pattern
+                    self._tensor_buft_overrides[i].buft = cpu_buft
+
+                # Null terminator
+                self._tensor_buft_overrides[n_cpu_moe].pattern = None
+                self._tensor_buft_overrides[n_cpu_moe].buft = None
+
+                # Set the overrides in model params
+                self.model_params.tensor_buft_overrides = self._tensor_buft_overrides
+
         # kv_overrides is the original python dict
         self.kv_overrides = kv_overrides
         if kv_overrides is not None:
diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py
index 711d42a6a..8ff2ea4f9 100644
--- a/llama_cpp/llama_cpp.py
+++ b/llama_cpp/llama_cpp.py
@@ -659,6 +659,22 @@ class llama_model_kv_override(ctypes.Structure):
 #     const char * pattern;
 #     ggml_backend_buffer_type_t buft;
 # };
+class llama_model_tensor_buft_override(ctypes.Structure):
+    """Buffer type override for model tensors
+
+    Attributes:
+        pattern (ctypes.c_char_p): regex pattern to match tensor names
+        buft (ctypes.c_void_p): buffer type to use for matching tensors
+    """
+
+    if TYPE_CHECKING:
+        pattern: bytes
+        buft: ctypes.c_void_p
+
+    _fields_ = [
+        ("pattern", ctypes.c_char_p),
+        ("buft", ctypes.c_void_p),
+    ]
 
 
 # struct llama_model_params {
@@ -732,7 +748,7 @@ class llama_model_params(ctypes.Structure):
 
     _fields_ = [
         ("devices", ctypes.c_void_p),  # NOTE: unnused
-        ("tensor_buft_overrides", ctypes.c_void_p),  # NOTE: unused
+        ("tensor_buft_overrides", ctypes.POINTER(llama_model_tensor_buft_override)),
         ("n_gpu_layers", ctypes.c_int32),
         ("split_mode", ctypes.c_int),
         ("main_gpu", ctypes.c_int32),
@@ -4372,3 +4388,15 @@ def llama_opt_epoch(
     /,
 ):
     ...
+
+
+# // GGML backend buffer types
+# GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void);
+@ctypes_function(
+    "ggml_backend_cpu_buffer_type",
+    [],
+    ctypes.c_void_p,
+)
+def ggml_backend_cpu_buffer_type() -> ctypes.c_void_p:
+    """Get the CPU backend buffer type"""
+    ...