diff --git a/Cargo.lock b/Cargo.lock index b17cae6440..998d60d1d8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3769,8 +3769,10 @@ dependencies = [ "impl-trait-for-tuples", "indexmap 2.9.0", "indicatif", + "inkwell", "inventory", "itertools 0.14.0", + "jit", "libc", "metrics", "mimalloc-rust-sys", @@ -6311,6 +6313,30 @@ dependencies = [ "str_stack", ] +[[package]] +name = "inkwell" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39457e8611219cf690f862a470575f5c06862910d03ea3c3b187ad7abc44b4e2" +dependencies = [ + "inkwell_internals", + "libc", + "llvm-sys", + "once_cell", + "thiserror 2.0.12", +] + +[[package]] +name = "inkwell_internals" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad9a7dd586b00f2b20e0b9ae3c6faa351fbfd56d15d63bbce35b13bece682eda" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "inout" version = "0.1.4" @@ -6457,6 +6483,16 @@ dependencies = [ "tracing", ] +[[package]] +name = "jit" +version = "0.1.0" +dependencies = [ + "anyhow", + "dbsp", + "inkwell", + "thiserror 1.0.69", +] + [[package]] name = "jobserver" version = "0.1.33" @@ -6720,6 +6756,20 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956" +[[package]] +name = "llvm-sys" +version = "211.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "108b3ad2b2eaf2a561fc74196273b20e3436e4a688b8b44e250d83974dc1b2e2" +dependencies = [ + "anyhow", + "cc", + "lazy_static", + "libc", + "regex-lite", + "semver", +] + [[package]] name = "local-channel" version = "0.1.5" diff --git a/Cargo.toml b/Cargo.toml index 1ffc7bf85a..a8f37fa773 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ members = [ "crates/storage", "crates/rest-api", "crates/ir", - "crates/fxp", + "crates/fxp", "crates/jit", ] exclude = [ "sql-to-dbsp-compiler/temp", diff --git a/crates/dbsp/Cargo.toml b/crates/dbsp/Cargo.toml index 145684b542..874b6a7660 100644 --- a/crates/dbsp/Cargo.toml +++ b/crates/dbsp/Cargo.toml @@ -96,6 +96,9 @@ feldera-ir = { workspace = true } smallvec = { workspace = true } async-stream = { workspace = true } futures-util = { workspace = true } +jit = { path = "../jit" } +inkwell = { version = "0.7.1", default-features = false, features = ["llvm21-1"] } + [dev-dependencies] rand = { workspace = true } diff --git a/crates/dbsp/src/jit/mod.rs b/crates/dbsp/src/jit/mod.rs new file mode 100644 index 0000000000..72545c0fe2 --- /dev/null +++ b/crates/dbsp/src/jit/mod.rs @@ -0,0 +1,70 @@ +//! JIT operator for DBSP circuits. + +use crate::{ + circuit::{ + metadata::OperatorLocation, + operator_traits::{Operator, UnaryOperator}, + Circuit, Stream, + }, + Scope, +}; +use jit::{JitFunction, RawJitBatch}; +use std::{borrow::Cow, panic::Location}; + +/// Operator that forwards a `RawJitBatch` into JIT compiled code. +pub struct JitInvokeOperator { + func: JitFunction, + name: Cow<'static, str>, + location: &'static Location<'static>, +} + +impl JitInvokeOperator { + fn new( + func: JitFunction, + name: Cow<'static, str>, + location: &'static Location<'static>, + ) -> Self { + Self { + func, + name, + location, + } + } +} + +impl Operator for JitInvokeOperator { + fn name(&self) -> Cow<'static, str> { + self.name.clone() + } + + fn location(&self) -> OperatorLocation { + Some(self.location) + } + + fn fixedpoint(&self, _scope: Scope) -> bool { + true + } +} + +impl UnaryOperator for JitInvokeOperator { + async fn eval(&mut self, input: &RawJitBatch) -> RawJitBatch { + self.func.invoke(*input) + } +} + +impl Stream +where + C: Circuit, +{ + /// Insert a JIT invocation operator that mutates the pointer batch in place. + #[track_caller] + pub fn invoke_jit(&self, name: N, func: JitFunction) -> Stream + where + N: Into>, + { + self.circuit().add_unary_operator( + JitInvokeOperator::new(func, name.into(), Location::caller()), + self, + ) + } +} \ No newline at end of file diff --git a/crates/dbsp/src/lib.rs b/crates/dbsp/src/lib.rs index 6687a23525..21c3767dc0 100644 --- a/crates/dbsp/src/lib.rs +++ b/crates/dbsp/src/lib.rs @@ -73,6 +73,8 @@ pub mod typed_batch; pub mod circuit; pub mod algebra; pub mod ir; +pub mod jit; + pub mod mimalloc; pub mod monitor; pub mod operator; @@ -99,6 +101,7 @@ pub use circuit::{ ChildCircuit, Circuit, CircuitHandle, DBSPHandle, NestedCircuit, RootCircuit, Runtime, RuntimeError, SchedulerError, Stream, WeakRuntime, }; + #[cfg(not(feature = "backend-mode"))] pub use operator::FilterMap; pub use operator::{ diff --git a/crates/jit/Cargo.toml b/crates/jit/Cargo.toml new file mode 100644 index 0000000000..7b0042acc2 --- /dev/null +++ b/crates/jit/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "jit" +version = "0.1.0" +edition = "2021" + +[lib] + +[dependencies] +inkwell = { version = "0.7.1", default-features = false, features = ["llvm21-1"] } +thiserror = "1.0" +anyhow = { workspace = true } + +[dev-dependencies] +dbsp = { path = "../dbsp" } + +[[example]] +name = "jit" diff --git a/crates/jit/dataflow.json b/crates/jit/dataflow.json new file mode 100644 index 0000000000..6dd77a59a8 --- /dev/null +++ b/crates/jit/dataflow.json @@ -0,0 +1,62 @@ +{ + "calcite_plan": { + "error_view": { + "rels": [ + { + "id": 0, + "relOp": "LogicalTableScan", + "table": [ + "schema", + "feldera_error_table" + ], + "inputs": [] + } + ] + } + }, + "mir": { + "s0": { + "operation": "constant", + "inputs": [], + "calcite": { + "partial": 0 + }, + "positions": [ + {"start_line_number":4,"start_column":1,"end_line_number":4,"end_column":59} + ], + "persistent_id": "dabdc0517fb639de8ebd480cb4350e3b2054b584a8c9c42515f268f99294f72c" + }, "s1": { + "operation": "inspect", + "inputs": [ + { "node": "s0", "output": 0 } + ], + "view": "error_view", + "calcite": { + "final": 0 + }, + "positions": [ + {"start_line_number":4,"start_column":1,"end_line_number":4,"end_column":59}, + {"start_line_number":4,"start_column":1,"end_line_number":4,"end_column":59} + ], + "persistent_id": "42bb786d35654abe048f3d4cedb08308c085b652304d4902765d81d8afda529c" + }, "s2": { + "operation": "source_multiset", + "inputs": [], + "table": "test", + "calcite": { + "and": [ + ] + }, + "positions": [ + {"start_line_number":1,"start_column":14,"end_line_number":1,"end_column":17}, + {"start_line_number":1,"start_column":14,"end_line_number":1,"end_column":17} + ], + "persistent_id": "e755d32bfc7ee015cbeb0a99e513b8098a82b6cc83f7d523a850283dac102e61" + } + }, + "sources": [ + "CREATE TABLE test(", + " id INT", + ");" + ] +} \ No newline at end of file diff --git a/crates/jit/examples/jit.rs b/crates/jit/examples/jit.rs new file mode 100644 index 0000000000..3c432ab0b6 --- /dev/null +++ b/crates/jit/examples/jit.rs @@ -0,0 +1,124 @@ +use dbsp::{circuit::Stream, operator::InputHandle, RootCircuit, Runtime}; +use jit::{JitFunction, LlvmCircuitJit, RawJitBatch}; +use std::{ffi::c_void, mem}; + +/// A C-compatible struct to pass a stream and a handle across the FFI boundary. +#[repr(C)] +struct StreamHandlePair { + stream: *mut c_void, + handle: *mut c_void, +} + +/// A C-compatible struct to pass a stream and a JIT function across the FFI boundary. +#[repr(C)] +struct StreamAndJitFunction { + stream: *mut c_void, + jit_function: *mut c_void, +} + +/// An `extern "C"` function that wraps the `add_input_stream` call. +/// +/// # Safety +/// +/// This function is highly unsafe. It casts the `circuit` pointer and leaks the +/// created stream and handle to be passed back to the JIT-compiled caller. +/// The caller is responsible for reconstituting and managing the memory of +/// these objects. +#[no_mangle] +unsafe extern "C" fn add_input_stream_helper(circuit: *mut c_void) -> StreamHandlePair { + let circuit = &mut *(circuit as *mut RootCircuit); + let (stream, handle): (Stream, InputHandle) = + circuit.add_input_stream::(); + + StreamHandlePair { + stream: Box::into_raw(Box::new(stream)).cast(), + handle: Box::into_raw(Box::new(handle)).cast(), + } +} + +/// An `extern "C"` function that wraps the `invoke_jit` call. +/// +/// # Safety +/// +/// This function is highly unsafe. It casts pointers and leaks the created stream. +#[no_mangle] +unsafe extern "C" fn invoke_jit_helper(input: *mut StreamAndJitFunction) -> *mut c_void { + let stream = &*(*(input)).stream.cast::>(); + let jit_function = &*(*(input)).jit_function.cast::(); + let output_stream = stream.invoke_jit("llvm-add", jit_function.clone()); + Box::into_raw(Box::new(output_stream)).cast() +} + +fn main() -> anyhow::Result<()> { + // ---- Part 1: JIT-compile a simple data-parallel pipeline ---- + let jit = LlvmCircuitJit::new("jit_demo"); + let data_pipeline = jit.compile_add_pipeline(&[1, 2, 3])?; + + // ---- Part 2: JIT-compile a circuit construction function ---- + let context = jit.context(); + let ptr_type = context.ptr_type(inkwell::AddressSpace::default()); + let pair_struct_type = context.struct_type(&[ptr_type.into(), ptr_type.into()], false); + + let circuit_builder = unsafe { + jit.compile_circuit_builder( + add_input_stream_helper as usize, + invoke_jit_helper as usize, + Some(pair_struct_type), + )? + }; + + type CircuitBuilderFn = unsafe extern "C" fn(*mut c_void, *mut c_void) -> StreamHandlePair; + let circuit_builder_fn: CircuitBuilderFn = unsafe { mem::transmute(circuit_builder.raw()) }; + + // ---- Part 3: Build and run the DBSP circuit ---- + let (mut dbsp, (input_handle, scratch_handle)) = Runtime::init_circuit(1, move |circuit| { + // Use the JIT-compiled function to create the input stream and attach the JIT-compiled + // pipeline. + let circuit_ptr = circuit as *mut _ as *mut c_void; + let data_pipeline_ptr = &data_pipeline as *const _ as *mut c_void; + let pair = unsafe { circuit_builder_fn(circuit_ptr, data_pipeline_ptr) }; + + // Reconstitute the stream and handle. This is the counterpart to the memory + // leak in the helper function. + let stream: Box> = + unsafe { Box::from_raw(pair.stream.cast()) }; + let handle: Box> = unsafe { Box::from_raw(pair.handle.cast()) }; + + // Now, use the reconstituted stream to continue building the circuit. + stream.inspect(|batch| unsafe { + let ptr = batch.as_ptr() as *const i32; + if !ptr.is_null() { + println!("JIT produced value: {}", *ptr); + } + }); + + // Create another input for testing. + let (scratch_stream, scratch_handle) = circuit.add_input_stream::(); + scratch_stream.inspect(|val| println!("Scratch input received: {val}")); + + Ok((*handle, scratch_handle)) + })?; + + // ---- Part 4: Execute the circuit ---- + let mut first = 10_i32; + let mut second = -5_i32; + let batches = [ + unsafe { RawJitBatch::from_raw((&mut first as *mut i32).cast()) }, + unsafe { RawJitBatch::from_raw((&mut second as *mut i32).cast()) }, + ]; + + for batch in batches { + input_handle.set_for_all(batch); + dbsp.transaction().expect("circuit transaction succeeds"); + } + println!("Host view after JIT execution: first = {first}, second = {second}"); + + // Demonstrate that the circuit is still live. + input_handle.set_for_all(batches[0]); + scratch_handle.set_for_all(100); + dbsp.transaction().expect("circuit transaction succeeds"); + + dbsp.kill().unwrap(); + + Ok(()) +} diff --git a/crates/jit/output_schema.json b/crates/jit/output_schema.json new file mode 100644 index 0000000000..62262f62a6 --- /dev/null +++ b/crates/jit/output_schema.json @@ -0,0 +1,50 @@ +{ + "inputs" : [ { + "name" : "test", + "case_sensitive" : false, + "fields" : [ { + "name" : "id", + "case_sensitive" : false, + "columntype" : { + "nullable" : true, + "type" : "INTEGER" + }, + "unused" : false + } ], + "materialized" : false, + "foreign_keys" : [ ] + } ], + "outputs" : [ { + "name" : "error_view", + "case_sensitive" : false, + "fields" : [ { + "name" : "table_or_view_name", + "case_sensitive" : false, + "columntype" : { + "nullable" : false, + "precision" : -1, + "type" : "VARCHAR" + }, + "unused" : false + }, { + "name" : "message", + "case_sensitive" : false, + "columntype" : { + "nullable" : false, + "precision" : -1, + "type" : "VARCHAR" + }, + "unused" : false + }, { + "name" : "metadata", + "case_sensitive" : false, + "columntype" : { + "nullable" : false, + "precision" : -1, + "type" : "VARCHAR" + }, + "unused" : false + } ], + "materialized" : false + } ] +} diff --git a/crates/jit/sql-to-dbsp.bash b/crates/jit/sql-to-dbsp.bash new file mode 100644 index 0000000000..e31fdae02a --- /dev/null +++ b/crates/jit/sql-to-dbsp.bash @@ -0,0 +1 @@ +java -jar ../../sql-to-dbsp-compiler/SQL-compiler/target/sql2dbsp-jar-with-dependencies.jar test.sql -js output_schema.json -o out.rs --dataflow dataflow.json -i -je --alltables --ignoreOrder diff --git a/crates/jit/src/lib.rs b/crates/jit/src/lib.rs new file mode 100644 index 0000000000..510df8c99e --- /dev/null +++ b/crates/jit/src/lib.rs @@ -0,0 +1,485 @@ +//! Prototype LLVM-backed JIT helpers. + +use std::{ + ffi::c_void, + fmt, ptr, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, +}; + +use inkwell::{ + context::Context, + execution_engine::ExecutionEngine, + module::Module, + types::{BasicMetadataTypeEnum, BasicType}, + AddressSpace, OptimizationLevel, +}; +use thiserror::Error; + +/// Signature of the generated JIT functions. +pub type JitFn = unsafe extern "C" fn(*mut c_void) -> *mut c_void; + +/// A minimal raw batch wrapper. The pointer is opaque to DBSP; JIT'd code owns +/// the layout and interpretation. +#[derive(Copy, Clone, Eq, PartialEq)] +#[repr(transparent)] +pub struct RawJitBatch { + ptr: *mut c_void, +} + +impl RawJitBatch { + /// Create a batch that points to `ptr`. + /// + /// # Safety + /// + /// `ptr` must remain valid for as long as the batch is used by the circuit. + pub const unsafe fn from_raw(ptr: *mut c_void) -> Self { + Self { ptr } + } + + /// A null pointer batch (the default). + pub const fn null() -> Self { + Self { + ptr: ptr::null_mut(), + } + } + + /// Expose the inner pointer. + pub const fn as_ptr(self) -> *mut c_void { + self.ptr + } +} + +unsafe impl Send for RawJitBatch {} +unsafe impl Sync for RawJitBatch {} + +impl Default for RawJitBatch { + fn default() -> Self { + Self::null() + } +} + +impl fmt::Debug for RawJitBatch { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("RawJitBatch") + .field(&(self.ptr as usize)) + .finish() + } +} + +/// Errors surfaced while building or running JIT code. +#[derive(Debug, Error)] +pub enum JitError { + #[error("LLVM module verification failed: {0}")] + Verify(String), + #[error("LLVM instruction builder failed: {0}")] + Build(String), + #[error("failed to create execution engine: {0}")] + Engine(String), + #[error("failed to find symbol `{symbol}`: {error}")] + Lookup { symbol: String, error: String }, +} + +/// Lightweight owner for LLVM artifacts created per compiled function. +#[doc(hidden)] +pub struct SharedJitModule { + module: *mut Module<'static>, + execution_engine: *mut ExecutionEngine<'static>, +} + +impl SharedJitModule { + fn new(module: Module<'static>, execution_engine: ExecutionEngine<'static>) -> Self { + Self { + module: Box::into_raw(Box::new(module)), + execution_engine: Box::into_raw(Box::new(execution_engine)), + } + } +} + +unsafe impl Send for SharedJitModule {} +unsafe impl Sync for SharedJitModule {} + +impl Drop for SharedJitModule { + fn drop(&mut self) { + unsafe { + if !self.execution_engine.is_null() { + drop(Box::from_raw(self.execution_engine)); + self.execution_engine = ptr::null_mut(); + } + if !self.module.is_null() { + drop(Box::from_raw(self.module)); + self.module = ptr::null_mut(); + } + } + } +} + +/// Handle to a compiled JIT function. +#[derive(Clone)] +pub struct JitFunction { + symbol: Arc, + func: JitFn, + _keepalive: Arc, +} + +impl JitFunction { + /// This is not meant to be a public API, but `LlvmCircuitJit` needs it. + #[doc(hidden)] + pub fn new(symbol: Arc, func: JitFn, keepalive: SharedJitModule) -> Self { + Self { + symbol, + func, + _keepalive: Arc::new(keepalive), + } + } + + /// Human-readable symbol of the compiled function. + pub fn symbol(&self) -> &str { + &self.symbol + } + + /// Obtain the raw function pointer. + pub fn raw(&self) -> JitFn { + self.func + } + + /// Invoke the compiled function on `batch`. + pub fn invoke(&self, batch: RawJitBatch) -> RawJitBatch { + let ptr = unsafe { (self.func)(batch.as_ptr()) }; + unsafe { RawJitBatch::from_raw(ptr) } + } +} + +/// Microscopic LLVM JIT wrapper that fabricates pointer-manipulating kernels. +pub struct LlvmCircuitJit { + context: &'static Context, + module_prefix: Arc, + counter: AtomicUsize, + optimization: OptimizationLevel, +} + +impl LlvmCircuitJit { + /// Create a new JIT with modules prefixed by `module_prefix`. + pub fn new(module_prefix: impl Into) -> Self { + let ctx = Box::leak(Box::new(Context::create())); + Self { + context: ctx, + module_prefix: module_prefix.into().into(), + counter: AtomicUsize::new(0), + optimization: OptimizationLevel::None, + } + } + + pub fn context(&self) -> &'static Context { + self.context + } + + /// Compile a function that interprets the raw pointer as an `i32` buffer and + /// adds all `increments` to the pointed-to value in sequence. + pub fn compile_add_pipeline(&self, increments: &[i32]) -> Result { + let symbol_name = format!( + "{}_fn_{}", + self.module_prefix, + self.counter.fetch_add(1, Ordering::Relaxed) + ); + let module_name = format!("{}_module", symbol_name); + let module = self.context.create_module(&module_name); + let builder = self.context.create_builder(); + + let ptr_ty = self.context.ptr_type(AddressSpace::default()); + let fn_ty = ptr_ty.fn_type(&[ptr_ty.into()], false); + let function = module.add_function(&symbol_name, fn_ty, None); + let entry = self.context.append_basic_block(function, "entry"); + builder.position_at_end(entry); + + let raw_arg = function.get_first_param().unwrap().into_pointer_value(); + let i32_ty = self.context.i32_type(); + let typed_ptr = builder + .build_bit_cast( + raw_arg, + self.context.ptr_type(AddressSpace::default()), + "batch_ptr", + ) + .map_err(|e| JitError::Build(e.to_string()))? + .into_pointer_value(); + + let mut acc = builder + .build_load(i32_ty, typed_ptr, "current") + .map_err(|e| JitError::Build(e.to_string()))? + .into_int_value(); + + for (idx, increment) in increments.iter().copied().enumerate() { + let llvm_inc = i32_ty.const_int(increment as u64, true); + acc = builder + .build_int_add(acc, llvm_inc, &format!("add_{idx}")) + .map_err(|e| JitError::Build(e.to_string()))?; + } + + builder + .build_store(typed_ptr, acc) + .map_err(|e| JitError::Build(e.to_string()))?; + builder + .build_return(Some(&raw_arg)) + .map_err(|e| JitError::Build(e.to_string()))?; + + module + .verify() + .map_err(|e| JitError::Verify(e.to_string()))?; + + let execution_engine = module + .create_jit_execution_engine(self.optimization) + .map_err(|e| JitError::Engine(e.to_string()))?; + + let address = execution_engine + .get_function_address(&symbol_name) + .map_err(|e| JitError::Lookup { + symbol: symbol_name.clone(), + error: e.to_string(), + })?; + + let func: JitFn = unsafe { std::mem::transmute(address) }; + Ok(JitFunction::new( + Arc::::from(symbol_name), + func, + SharedJitModule::new(module, execution_engine), + )) + } + + /// Compile a JIT function that wraps a call to an arbitrary function pointer. + /// + /// # Safety + /// + /// The caller must ensure that `target_fn_ptr` is a valid function pointer + /// that will remain valid for the lifetime of the returned `JitFunction`. + /// The signature of the `target_fn_ptr` must match what the generated LLVM + /// IR expects. This is a sharp tool for advanced FFI scenarios. + pub unsafe fn compile_ffi_call( + &self, + target_fn_ptr: usize, + param_types: &[inkwell::types::BasicTypeEnum<'static>], + return_type: inkwell::types::BasicTypeEnum<'static>, + return_struct_type: Option>, + ) -> Result { + let symbol_name = format!( + "{}_fn_{}", + self.module_prefix, + self.counter.fetch_add(1, Ordering::Relaxed) + ); + let module_name = format!("{}_module", symbol_name); + let module = self.context.create_module(&module_name); + let builder = self.context.create_builder(); + + let metadata_param_types: Vec = + param_types.iter().map(|&t| t.into()).collect(); + + // Create the JIT function's signature. + let fn_ty = if let Some(struct_type) = return_struct_type { + struct_type.fn_type(&metadata_param_types, false) + } else { + return_type.fn_type(&metadata_param_types, false) + }; + let function = module.add_function(&symbol_name, fn_ty, None); + let entry = self.context.append_basic_block(function, "entry"); + builder.position_at_end(entry); + + // Get the function pointer for the target C function. + let ptr_ty = self.context.ptr_type(AddressSpace::default()); + let target_fn_ptr_val = self + .context + .i64_type() + .const_int(target_fn_ptr as u64, false); + let target_fn = builder + .build_int_to_ptr(target_fn_ptr_val, ptr_ty, "target_fn_ptr") + .unwrap(); + + // Build the call instruction. + let params: Vec<_> = function + .get_param_iter() + .map(|param| param.into()) + .collect(); + let call = builder + .build_indirect_call(fn_ty, target_fn, ¶ms, "call") + .unwrap(); + + // Build the return instruction. A function returns void if `get_return_type()` is `None`. + if fn_ty.get_return_type().is_some() { + let return_value = call.try_as_basic_value().unwrap_basic(); + builder.build_return(Some(&return_value)).unwrap(); + } else { + builder.build_return(None).unwrap(); + } + + module + .verify() + .map_err(|e| JitError::Verify(e.to_string()))?; + + let execution_engine = module + .create_jit_execution_engine(self.optimization) + .map_err(|e| JitError::Engine(e.to_string()))?; + + let address = execution_engine + .get_function_address(&symbol_name) + .map_err(|e| JitError::Lookup { + symbol: symbol_name.clone(), + error: e.to_string(), + })?; + + let func: JitFn = std::mem::transmute(address); + Ok(JitFunction::new( + Arc::::from(symbol_name), + func, + SharedJitModule::new(module, execution_engine), + )) + } + + /// This is a temporary function to compile a circuit builder. + /// + /// # Safety + /// + /// The caller must ensure that `add_input_stream_fn_ptr` and `invoke_jit_fn_ptr` are valid + /// function pointers that will remain valid for the lifetime of the returned `JitFunction`. + /// The signatures of these function pointers must match what the generated LLVM IR expects. + pub unsafe fn compile_circuit_builder( + &self, + add_input_stream_fn_ptr: usize, + invoke_jit_fn_ptr: usize, + return_struct_type: Option>, + ) -> Result { + let symbol_name = format!( + "{}_fn_{}", + self.module_prefix, + self.counter.fetch_add(1, Ordering::Relaxed) + ); + let module_name = format!("{}_module", symbol_name); + let module = self.context.create_module(&module_name); + let builder = self.context.create_builder(); + + let ptr_type = self.context.ptr_type(AddressSpace::default()); + + // JIT function signature + let fn_type = return_struct_type.unwrap().fn_type( + &[ptr_type.into(), ptr_type.into()], + false, + ); + let function = module.add_function(&symbol_name, fn_type, None); + let entry = self.context.append_basic_block(function, "entry"); + builder.position_at_end(entry); + + // Call `add_input_stream_helper` + let add_input_stream_fn_type = + return_struct_type.unwrap().fn_type(&[ptr_type.into()], false); + let add_input_stream_fn_ptr_val = self + .context + .i64_type() + .const_int(add_input_stream_fn_ptr as u64, false); + let add_input_stream_fn = builder + .build_int_to_ptr( + add_input_stream_fn_ptr_val, + ptr_type, + "add_input_stream_fn_ptr", + ) + .unwrap(); + let circuit_arg = function.get_first_param().unwrap(); + let stream_handle_pair = builder + .build_indirect_call( + add_input_stream_fn_type, + add_input_stream_fn, + &[circuit_arg.into()], + "call_add_input_stream", + ) + .unwrap() + .try_as_basic_value() + .unwrap_basic() + .into_struct_value(); + + // Call `invoke_jit_helper` + let invoke_jit_fn_type = ptr_type.fn_type(&[ptr_type.into()], false); + let invoke_jit_fn_ptr_val = self + .context + .i64_type() + .const_int(invoke_jit_fn_ptr as u64, false); + let invoke_jit_fn = builder + .build_int_to_ptr(invoke_jit_fn_ptr_val, ptr_type, "invoke_jit_fn_ptr") + .unwrap(); + + let stream_ptr = builder + .build_extract_value(stream_handle_pair, 0, "stream_ptr") + .unwrap(); + let handle_ptr = builder + .build_extract_value(stream_handle_pair, 1, "handle_ptr") + .unwrap(); + let data_pipeline_ptr = function.get_nth_param(1).unwrap(); + + // Create `StreamAndJitFunction` struct + let stream_and_jit_function_type = + self.context.struct_type(&[ptr_type.into(), ptr_type.into()], false); + let stream_and_jit_function_ptr = builder + .build_alloca(stream_and_jit_function_type, "arg_struct") + .unwrap(); + let stream_field_ptr = builder + .build_struct_gep( + stream_and_jit_function_type, + stream_and_jit_function_ptr, + 0, + "stream_field_ptr", + ) + .unwrap(); + let jit_function_field_ptr = builder + .build_struct_gep( + stream_and_jit_function_type, + stream_and_jit_function_ptr, + 1, + "jit_function_field_ptr", + ) + .unwrap(); + builder.build_store(stream_field_ptr, stream_ptr).unwrap(); + builder + .build_store(jit_function_field_ptr, data_pipeline_ptr) + .unwrap(); + + let output_stream_ptr = builder + .build_indirect_call( + invoke_jit_fn_type, + invoke_jit_fn, + &[stream_and_jit_function_ptr.into()], + "call_invoke_jit", + ) + .unwrap() + .try_as_basic_value() + .unwrap_basic(); + + // Return the final stream and handle + let final_pair = return_struct_type.unwrap().const_zero(); + let final_pair = builder + .build_insert_value(final_pair, output_stream_ptr, 0, "final_pair") + .unwrap(); + let final_pair = builder + .build_insert_value(final_pair, handle_ptr, 1, "final_pair") + .unwrap(); + builder.build_return(Some(&final_pair)).unwrap(); + + module + .verify() + .map_err(|e| JitError::Verify(e.to_string()))?; + + let execution_engine = module + .create_jit_execution_engine(self.optimization) + .map_err(|e| JitError::Engine(e.to_string()))?; + + let address = execution_engine + .get_function_address(&symbol_name) + .map_err(|e| JitError::Lookup { + symbol: symbol_name.clone(), + error: e.to_string(), + })?; + + let func: JitFn = std::mem::transmute(address); + Ok(JitFunction::new( + Arc::::from(symbol_name), + func, + SharedJitModule::new(module, execution_engine), + )) + } +} diff --git a/crates/jit/stubs.rs b/crates/jit/stubs.rs new file mode 100644 index 0000000000..bc19d6c9e4 --- /dev/null +++ b/crates/jit/stubs.rs @@ -0,0 +1,10 @@ +// Compiler-generated file. +// This file contains stubs for user-defined functions declared in the SQL program. +// Each stub defines a function prototype that must be implemented in `udf.rs`. +// Copy these stubs to `udf.rs`, replacing their bodies with the actual UDF implementation. +// See detailed documentation in https://docs.feldera.com/sql/udf. + +#![allow(non_snake_case)] + +use feldera_sqllib::*; +use crate::*; diff --git a/crates/jit/test.sql b/crates/jit/test.sql new file mode 100644 index 0000000000..6aee340f2b --- /dev/null +++ b/crates/jit/test.sql @@ -0,0 +1,3 @@ +CREATE TABLE test( + id INT +);