From 0a27e76bc164e0c20a0e8a392a8de61ec56f7b70 Mon Sep 17 00:00:00 2001 From: Andreas Marek Date: Sun, 30 Mar 2025 18:21:55 +1000 Subject: [PATCH 1/3] introducing of a special CompletableFuture enabling scheduling of nested/chained DataLoader calls --- .../execution/DataLoaderDispatchStrategy.java | 4 +- .../java/graphql/execution/Execution.java | 6 +- .../graphql/execution/ExecutionStrategy.java | 2 +- .../dataloader/DataLoaderCF.java | 99 ++++++ .../PerLevelDataLoaderDispatchStrategy.java | 121 +++++++- ...spatchStrategyWithDeferAlwaysDispatch.java | 4 +- .../groovy/graphql/DataLoaderCFTest.groovy | 287 ++++++++++++++++++ 7 files changed, 516 insertions(+), 7 deletions(-) create mode 100644 src/main/java/graphql/execution/instrumentation/dataloader/DataLoaderCF.java create mode 100644 src/test/groovy/graphql/DataLoaderCFTest.groovy diff --git a/src/main/java/graphql/execution/DataLoaderDispatchStrategy.java b/src/main/java/graphql/execution/DataLoaderDispatchStrategy.java index 5101ae3a56..bbc0f18640 100644 --- a/src/main/java/graphql/execution/DataLoaderDispatchStrategy.java +++ b/src/main/java/graphql/execution/DataLoaderDispatchStrategy.java @@ -2,8 +2,10 @@ import graphql.Internal; import graphql.schema.DataFetcher; +import graphql.schema.DataFetchingEnvironment; import java.util.List; +import java.util.function.Supplier; @Internal public interface DataLoaderDispatchStrategy { @@ -44,7 +46,7 @@ default void executeObjectOnFieldValuesException(Throwable t, ExecutionStrategyP default void fieldFetched(ExecutionContext executionContext, ExecutionStrategyParameters executionStrategyParameters, DataFetcher dataFetcher, - Object fetchedValue) { + Object fetchedValue, Supplier dataFetchingEnvironment) { } diff --git a/src/main/java/graphql/execution/Execution.java b/src/main/java/graphql/execution/Execution.java index 634837f987..6911c7087a 100644 --- a/src/main/java/graphql/execution/Execution.java +++ b/src/main/java/graphql/execution/Execution.java @@ -29,8 +29,8 @@ import graphql.schema.GraphQLObjectType; import graphql.schema.GraphQLSchema; import graphql.schema.impl.SchemaUtil; -import org.jspecify.annotations.NonNull; import graphql.util.FpKit; +import org.jspecify.annotations.NonNull; import org.reactivestreams.Publisher; import java.util.Collections; @@ -58,6 +58,8 @@ public class Execution { private final ValueUnboxer valueUnboxer; private final boolean doNotAutomaticallyDispatchDataLoader; + public static final String EXECUTION_CONTEXT_KEY = "__GraphQL_Java_ExecutionContext"; + public Execution(ExecutionStrategy queryStrategy, ExecutionStrategy mutationStrategy, ExecutionStrategy subscriptionStrategy, @@ -114,6 +116,8 @@ public CompletableFuture execute(Document document, GraphQLSche .build(); executionContext.getGraphQLContext().put(ResultNodesInfo.RESULT_NODES_INFO, executionContext.getResultNodesInfo()); + executionContext.getGraphQLContext().put(EXECUTION_CONTEXT_KEY, executionContext); + InstrumentationExecutionParameters parameters = new InstrumentationExecutionParameters( executionInput, graphQLSchema diff --git a/src/main/java/graphql/execution/ExecutionStrategy.java b/src/main/java/graphql/execution/ExecutionStrategy.java index 3b45786533..d647f650d1 100644 --- a/src/main/java/graphql/execution/ExecutionStrategy.java +++ b/src/main/java/graphql/execution/ExecutionStrategy.java @@ -496,7 +496,7 @@ private Object fetchField(GraphQLFieldDefinition fieldDef, ExecutionContext exec dataFetcher = instrumentation.instrumentDataFetcher(dataFetcher, instrumentationFieldFetchParams, executionContext.getInstrumentationState()); dataFetcher = executionContext.getDataLoaderDispatcherStrategy().modifyDataFetcher(dataFetcher); Object fetchedObject = invokeDataFetcher(executionContext, parameters, fieldDef, dataFetchingEnvironment, dataFetcher); - executionContext.getDataLoaderDispatcherStrategy().fieldFetched(executionContext, parameters, dataFetcher, fetchedObject); + executionContext.getDataLoaderDispatcherStrategy().fieldFetched(executionContext, parameters, dataFetcher, fetchedObject, dataFetchingEnvironment); fetchCtx.onDispatched(); fetchCtx.onFetchedValue(fetchedObject); // if it's a subscription, leave any reactive objects alone diff --git a/src/main/java/graphql/execution/instrumentation/dataloader/DataLoaderCF.java b/src/main/java/graphql/execution/instrumentation/dataloader/DataLoaderCF.java new file mode 100644 index 0000000000..e72090c7e0 --- /dev/null +++ b/src/main/java/graphql/execution/instrumentation/dataloader/DataLoaderCF.java @@ -0,0 +1,99 @@ +package graphql.execution.instrumentation.dataloader; + +import graphql.ExperimentalApi; +import graphql.Internal; +import graphql.execution.DataLoaderDispatchStrategy; +import graphql.execution.ExecutionContext; +import graphql.schema.DataFetchingEnvironment; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.function.Supplier; + +import static graphql.execution.Execution.EXECUTION_CONTEXT_KEY; + +@Internal +public class DataLoaderCF extends CompletableFuture { + final DataFetchingEnvironment dfe; + final String dataLoaderName; + final Object key; + final CompletableFuture dataLoaderCF; + + volatile CountDownLatch latch; + + public DataLoaderCF(DataFetchingEnvironment dfe, String dataLoaderName, Object key) { + this.dfe = dfe; + this.dataLoaderName = dataLoaderName; + this.key = key; + if (dataLoaderName != null) { + dataLoaderCF = dfe.getDataLoaderRegistry().getDataLoader(dataLoaderName).load(key); + dataLoaderCF.whenComplete((value, throwable) -> { + System.out.println("underlying DataLoader completed"); + if (throwable != null) { + completeExceptionally(throwable); + } else { + complete((T) value); + } + // post completion hook + if (latch != null) { + latch.countDown(); + } + }); + } else { + dataLoaderCF = null; + } + } + + DataLoaderCF() { + this.dfe = null; + this.dataLoaderName = null; + this.key = null; + dataLoaderCF = null; + } + + @Override + public CompletableFuture newIncompleteFuture() { + return new DataLoaderCF<>(); + } + + public static boolean isDataLoaderCF(Object object) { + return object instanceof DataLoaderCF; + } + + @ExperimentalApi + public static CompletableFuture newDataLoaderCF(DataFetchingEnvironment dfe, String dataLoaderName, Object key) { + DataLoaderCF result = new DataLoaderCF<>(dfe, dataLoaderName, key); + ExecutionContext executionContext = dfe.getGraphQlContext().get(EXECUTION_CONTEXT_KEY); + DataLoaderDispatchStrategy dataLoaderDispatcherStrategy = executionContext.getDataLoaderDispatcherStrategy(); + if (dataLoaderDispatcherStrategy instanceof PerLevelDataLoaderDispatchStrategy) { + ((PerLevelDataLoaderDispatchStrategy) dataLoaderDispatcherStrategy).newDataLoaderCF(result); + } + return result; + } + + + @ExperimentalApi + public static CompletableFuture supplyAsyncDataLoaderCF(DataFetchingEnvironment env, Supplier supplier) { + DataLoaderCF d = new DataLoaderCF<>(env, null, null); + d.defaultExecutor().execute(() -> { + d.complete(supplier.get()); + }); + return d; + + } + + @ExperimentalApi + public static CompletableFuture wrap(DataFetchingEnvironment env, CompletableFuture completableFuture) { + DataLoaderCF d = new DataLoaderCF<>(env, null, null); + completableFuture.whenComplete((u, ex) -> { + if (ex != null) { + d.completeExceptionally(ex); + } else { + d.complete(u); + } + }); + return d; + } + + +} diff --git a/src/main/java/graphql/execution/instrumentation/dataloader/PerLevelDataLoaderDispatchStrategy.java b/src/main/java/graphql/execution/instrumentation/dataloader/PerLevelDataLoaderDispatchStrategy.java index 0d1903eaab..deb00a0b79 100644 --- a/src/main/java/graphql/execution/instrumentation/dataloader/PerLevelDataLoaderDispatchStrategy.java +++ b/src/main/java/graphql/execution/instrumentation/dataloader/PerLevelDataLoaderDispatchStrategy.java @@ -7,12 +7,23 @@ import graphql.execution.ExecutionStrategyParameters; import graphql.execution.FieldValueInfo; import graphql.schema.DataFetcher; +import graphql.schema.DataFetchingEnvironment; import graphql.util.LockKit; import org.dataloader.DataLoaderRegistry; +import java.util.ArrayList; import java.util.LinkedHashSet; import java.util.List; +import java.util.Map; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; @Internal public class PerLevelDataLoaderDispatchStrategy implements DataLoaderDispatchStrategy { @@ -20,6 +31,9 @@ public class PerLevelDataLoaderDispatchStrategy implements DataLoaderDispatchStr private final CallStack callStack; private final ExecutionContext executionContext; + static final ScheduledExecutorService isolatedDLCFBatchWindowScheduler = Executors.newSingleThreadScheduledExecutor(); + static final int BATCH_WINDOW_NANO_SECONDS = 500_000; + private static class CallStack { @@ -34,10 +48,27 @@ private static class CallStack { private final Set dispatchedLevels = new LinkedHashSet<>(); + // fields only relevant when a DataLoaderCF is involved + private final List> allDataLoaderCF = new CopyOnWriteArrayList<>(); + //TODO: maybe this should be cleaned up once the CF returned by these fields are completed + // otherwise this will stick around until the whole request is finished + private final Set fieldsFinishedDispatching = ConcurrentHashMap.newKeySet(); + private final Map> levelToDFEWithDataLoaderCF = new ConcurrentHashMap<>(); + + private final Set batchWindowOfIsolatedDfeToDispatch = ConcurrentHashMap.newKeySet(); + + private boolean batchWindowOpen = false; + + public CallStack() { expectedExecuteObjectCallsPerLevel.set(1, 1); } + public void addDataLoaderDFE(int level, DataFetchingEnvironment dfe) { + levelToDFEWithDataLoaderCF.computeIfAbsent(level, k -> new LinkedHashSet<>()).add(dfe); + } + + void increaseExpectedFetchCount(int level, int count) { expectedFetchCountPerLevel.increment(level, count); } @@ -234,9 +265,13 @@ private int getObjectCountForList(List fieldValueInfos) { public void fieldFetched(ExecutionContext executionContext, ExecutionStrategyParameters executionStrategyParameters, DataFetcher dataFetcher, - Object fetchedValue) { + Object fetchedValue, + Supplier dataFetchingEnvironment) { int level = executionStrategyParameters.getPath().getLevel(); boolean dispatchNeeded = callStack.lock.callLocked(() -> { + if (DataLoaderCF.isDataLoaderCF(fetchedValue)) { + callStack.addDataLoaderDFE(level, dataFetchingEnvironment.get()); + } callStack.increaseFetchCount(level); return dispatchIfNeeded(level); }); @@ -275,9 +310,89 @@ private boolean levelReady(int level) { } void dispatch(int level) { - DataLoaderRegistry dataLoaderRegistry = executionContext.getDataLoaderRegistry(); - dataLoaderRegistry.dispatchAll(); + if (callStack.levelToDFEWithDataLoaderCF.size() > 0) { + dispatchDLCFImpl(callStack.levelToDFEWithDataLoaderCF.get(level)); + } else { + DataLoaderRegistry dataLoaderRegistry = executionContext.getDataLoaderRegistry(); + dataLoaderRegistry.dispatchAll(); + } } + + public void dispatchDLCFImpl(Set dfeToDispatchSet) { + + // filter out all DataLoaderCFS that are matching the fields we want to dispatch + List> relevantDataLoaderCFs = new ArrayList<>(); + for (DataLoaderCF dataLoaderCF : callStack.allDataLoaderCF) { + if (dfeToDispatchSet.contains(dataLoaderCF.dfe)) { + relevantDataLoaderCFs.add(dataLoaderCF); + } + } + // we are cleaning up the list of all DataLoadersCFs + callStack.allDataLoaderCF.removeAll(relevantDataLoaderCFs); + + // means we are all done dispatching the fields + if (relevantDataLoaderCFs.size() == 0) { + callStack.fieldsFinishedDispatching.addAll(dfeToDispatchSet); + return; + } + // we are dispatching all data loaders and waiting for all dataLoaderCFs to complete + // and to finish their sync actions + CountDownLatch countDownLatch = new CountDownLatch(relevantDataLoaderCFs.size()); + for (DataLoaderCF dlCF : relevantDataLoaderCFs) { + dlCF.latch = countDownLatch; + } + // TODO: this should be done async or in a more regulated way with a configurable thread pool or so + new Thread(() -> { + try { + // waiting until all sync codes for all DL CFs are run + countDownLatch.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + // now we handle all new DataLoaders + dispatchDLCFImpl(dfeToDispatchSet); + }).start(); + // Only dispatching relevant data loaders + for (DataLoaderCF dlCF : relevantDataLoaderCFs) { + dlCF.dfe.getDataLoader(dlCF.dataLoaderName).dispatch(); + } +// executionContext.getDataLoaderRegistry().dispatchAll(); + } + + + public void newDataLoaderCF(DataLoaderCF dataLoaderCF) { + System.out.println("newDataLoaderCF"); + callStack.lock.runLocked(() -> { + callStack.allDataLoaderCF.add(dataLoaderCF); + }); + if (callStack.fieldsFinishedDispatching.contains(dataLoaderCF.dfe)) { + System.out.println("isolated dispatch"); + dispatchIsolatedDataLoader(dataLoaderCF); + } + + } + + private void dispatchIsolatedDataLoader(DataLoaderCF dlCF) { + callStack.lock.runLocked(() -> { + callStack.batchWindowOfIsolatedDfeToDispatch.add(dlCF.dfe); + if (!callStack.batchWindowOpen) { + callStack.batchWindowOpen = true; + AtomicReference> dfesToDispatch = new AtomicReference<>(); + Runnable runnable = () -> { + callStack.lock.runLocked(() -> { + dfesToDispatch.set(new LinkedHashSet<>(callStack.batchWindowOfIsolatedDfeToDispatch)); + callStack.batchWindowOfIsolatedDfeToDispatch.clear(); + callStack.batchWindowOpen = false; + }); + dispatchDLCFImpl(dfesToDispatch.get()); + }; + isolatedDLCFBatchWindowScheduler.schedule(runnable, BATCH_WINDOW_NANO_SECONDS, TimeUnit.NANOSECONDS); + } + + }); + } + + } diff --git a/src/main/java/graphql/execution/instrumentation/dataloader/PerLevelDataLoaderDispatchStrategyWithDeferAlwaysDispatch.java b/src/main/java/graphql/execution/instrumentation/dataloader/PerLevelDataLoaderDispatchStrategyWithDeferAlwaysDispatch.java index 26c847b754..115236ebc0 100644 --- a/src/main/java/graphql/execution/instrumentation/dataloader/PerLevelDataLoaderDispatchStrategyWithDeferAlwaysDispatch.java +++ b/src/main/java/graphql/execution/instrumentation/dataloader/PerLevelDataLoaderDispatchStrategyWithDeferAlwaysDispatch.java @@ -7,6 +7,7 @@ import graphql.execution.ExecutionStrategyParameters; import graphql.execution.FieldValueInfo; import graphql.schema.DataFetcher; +import graphql.schema.DataFetchingEnvironment; import graphql.util.LockKit; import org.dataloader.DataLoaderRegistry; @@ -14,6 +15,7 @@ import java.util.List; import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; /** * The execution of a query can be divided into 2 phases: first, the non-deferred fields are executed and only once @@ -173,7 +175,7 @@ public void executeObjectOnFieldValuesException(Throwable t, ExecutionStrategyPa public void fieldFetched(ExecutionContext executionContext, ExecutionStrategyParameters parameters, DataFetcher dataFetcher, - Object fetchedValue) { + Object fetchedValue, Supplier dataFetchingEnvironment) { final boolean dispatchNeeded; diff --git a/src/test/groovy/graphql/DataLoaderCFTest.groovy b/src/test/groovy/graphql/DataLoaderCFTest.groovy new file mode 100644 index 0000000000..f799202a5b --- /dev/null +++ b/src/test/groovy/graphql/DataLoaderCFTest.groovy @@ -0,0 +1,287 @@ +package graphql + +import graphql.execution.instrumentation.dataloader.DataLoaderCF +import graphql.schema.DataFetcher +import org.dataloader.BatchLoader +import org.dataloader.DataLoader +import org.dataloader.DataLoaderFactory +import org.dataloader.DataLoaderRegistry +import spock.lang.Specification + +import java.util.concurrent.CompletableFuture + +import static graphql.ExecutionInput.newExecutionInput + +class DataLoaderCFTest extends Specification { + + + def "chained data loaders"() { + given: + def sdl = ''' + + type Query { + dogName: String + catName: String + } + ''' + int batchLoadCalls = 0 + BatchLoader batchLoader = { keys -> + return CompletableFuture.supplyAsync { + batchLoadCalls++ + Thread.sleep(250) + println "BatchLoader called with keys: $keys" + assert keys.size() == 2 + return ["Luna", "Tiger"] + } + } + + DataLoader nameDataLoader = DataLoaderFactory.newDataLoader(batchLoader); + + DataLoaderRegistry dataLoaderRegistry = new DataLoaderRegistry(); + dataLoaderRegistry.register("name", nameDataLoader); + + def df1 = { env -> + return DataLoaderCF.newDataLoaderCF(env, "name", "Key1").thenCompose { + result -> + { + return DataLoaderCF.newDataLoaderCF(env, "name", result) + } + } + } as DataFetcher + + def df2 = { env -> + return DataLoaderCF.newDataLoaderCF(env, "name", "Key2").thenCompose { + result -> + { + return DataLoaderCF.newDataLoaderCF(env, "name", result) + } + } + } as DataFetcher + + + def fetchers = ["Query": ["dogName": df1, "catName": df2]] + def schema = TestUtil.schema(sdl, fetchers) + def graphQL = GraphQL.newGraphQL(schema).build() + + def query = "{ dogName catName } " + def ei = newExecutionInput(query).dataLoaderRegistry(dataLoaderRegistry).build() + + when: + def er = graphQL.execute(ei) + then: + er.data == [dogName: "Luna", catName: "Tiger"] + batchLoadCalls == 2 + } + + def "more complicated chained data loader for one DF"() { + given: + def sdl = ''' + + type Query { + foo: String + } + ''' + int batchLoadCalls1 = 0 + BatchLoader batchLoader1 = { keys -> + return CompletableFuture.supplyAsync { + batchLoadCalls1++ + Thread.sleep(250) + println "BatchLoader1 called with keys: $keys" + return keys.collect { String key -> + key + "-batchloader1" + } + } + } + int batchLoadCalls2 = 0 + BatchLoader batchLoader2 = { keys -> + return CompletableFuture.supplyAsync { + batchLoadCalls2++ + Thread.sleep(250) + println "BatchLoader2 called with keys: $keys" + return keys.collect { String key -> + key + "-batchloader2" + } + } + } + + + DataLoader dl1 = DataLoaderFactory.newDataLoader(batchLoader1); + DataLoader dl2 = DataLoaderFactory.newDataLoader(batchLoader2); + + DataLoaderRegistry dataLoaderRegistry = new DataLoaderRegistry(); + dataLoaderRegistry.register("dl1", dl1); + dataLoaderRegistry.register("dl2", dl2); + + def df = { env -> + return DataLoaderCF.newDataLoaderCF(env, "dl1", "start").thenCompose { + firstDLResult -> + + def otherCF1 = DataLoaderCF.supplyAsyncDataLoaderCF(env, { + Thread.sleep(1000) + return "otherCF1" + }) + def otherCF2 = DataLoaderCF.supplyAsyncDataLoaderCF(env, { + Thread.sleep(1000) + return "otherCF2" + }) + + def secondDL = DataLoaderCF.newDataLoaderCF(env, "dl2", firstDLResult).thenApply { + secondDLResult -> + return secondDLResult + "-apply" + } + return otherCF1.thenCompose { + otherCF1Result -> + otherCF2.thenCompose { + otherCF2Result -> + secondDL.thenApply { + secondDLResult -> + return firstDLResult + "-" + otherCF1Result + "-" + otherCF2Result + "-" + secondDLResult + } + } + } + + } + } as DataFetcher + + + def fetchers = ["Query": ["foo": df]] + def schema = TestUtil.schema(sdl, fetchers) + def graphQL = GraphQL.newGraphQL(schema).build() + + def query = "{ foo } " + def ei = newExecutionInput(query).dataLoaderRegistry(dataLoaderRegistry).build() + + when: + def er = graphQL.execute(ei) + then: + er.data == [foo: "start-batchloader1-otherCF1-otherCF2-start-batchloader1-batchloader2-apply"] + batchLoadCalls1 == 1 + batchLoadCalls2 == 1 + } + + + def "chained data loaders with an isolated data loader"() { + given: + def sdl = ''' + + type Query { + dogName: String + catName: String + } + ''' + int batchLoadCalls = 0 + BatchLoader batchLoader = { keys -> + return CompletableFuture.supplyAsync { + batchLoadCalls++ + Thread.sleep(250) + println "BatchLoader called with keys: $keys" + return keys.collect { String key -> + key.substring(0, key.length() - 1) + (Integer.parseInt(key.substring(key.length() - 1, key.length())) + 1) + } + } + } + + DataLoader nameDataLoader = DataLoaderFactory.newDataLoader(batchLoader); + + DataLoaderRegistry dataLoaderRegistry = new DataLoaderRegistry(); + dataLoaderRegistry.register("name", nameDataLoader); + + def df1 = { env -> + return DataLoaderCF.newDataLoaderCF(env, "name", "Luna0").thenCompose { + result -> + { + return DataLoaderCF.supplyAsyncDataLoaderCF(env, { + Thread.sleep(1000) + return "foo" + }).thenCompose { + return DataLoaderCF.newDataLoaderCF(env, "name", result) + } + } + } + } as DataFetcher + + def df2 = { env -> + return DataLoaderCF.newDataLoaderCF(env, "name", "Tiger0").thenCompose { + result -> + { + return DataLoaderCF.newDataLoaderCF(env, "name", result) + } + } + } as DataFetcher + + + def fetchers = ["Query": ["dogName": df1, "catName": df2]] + def schema = TestUtil.schema(sdl, fetchers) + def graphQL = GraphQL.newGraphQL(schema).build() + + def query = "{ dogName catName } " + def ei = newExecutionInput(query).dataLoaderRegistry(dataLoaderRegistry).build() + + when: + def er = graphQL.execute(ei) + then: + er.data == [dogName: "Luna2", catName: "Tiger2"] + batchLoadCalls == 3 + } + + def "chained data loaders with two isolated data loaders"() { + // TODO: this test is naturally flaky, because there is no guarantee that the Thread.sleep(1000) finish close + // enough time wise to be batched together + given: + def sdl = ''' + + type Query { + foo: String + bar: String + } + ''' + int batchLoadCalls = 0 + BatchLoader batchLoader = { keys -> + return CompletableFuture.supplyAsync { + batchLoadCalls++ + Thread.sleep(250) + println "BatchLoader called with keys: $keys" + return keys; + } + } + + DataLoader nameDataLoader = DataLoaderFactory.newDataLoader(batchLoader); + + DataLoaderRegistry dataLoaderRegistry = new DataLoaderRegistry(); + dataLoaderRegistry.register("dl", nameDataLoader); + + def fooDF = { env -> + return DataLoaderCF.supplyAsyncDataLoaderCF(env, { + Thread.sleep(1000) + return "fooFirstValue" + }).thenCompose { + return DataLoaderCF.newDataLoaderCF(env, "dl", it) + } + } as DataFetcher + + def barDF = { env -> + return DataLoaderCF.supplyAsyncDataLoaderCF(env, { + Thread.sleep(1000) + return "barFirstValue" + }).thenCompose { + return DataLoaderCF.newDataLoaderCF(env, "dl", it) + } + } as DataFetcher + + + def fetchers = ["Query": ["foo": fooDF, "bar": barDF]] + def schema = TestUtil.schema(sdl, fetchers) + def graphQL = GraphQL.newGraphQL(schema).build() + + def query = "{ foo bar } " + def ei = newExecutionInput(query).dataLoaderRegistry(dataLoaderRegistry).build() + + when: + def er = graphQL.execute(ei) + then: + er.data == [foo: "fooFirstValue", bar: "barFirstValue"] + batchLoadCalls == 1 + } + + +} From f6558932d44f5b9fd2f852178093657202dcfdb1 Mon Sep 17 00:00:00 2001 From: Andreas Marek Date: Sun, 30 Mar 2025 21:33:52 +1000 Subject: [PATCH 2/3] replace countdown latch with an async solution --- .../dataloader/DataLoaderCF.java | 7 ++---- .../PerLevelDataLoaderDispatchStrategy.java | 25 +++++++------------ 2 files changed, 11 insertions(+), 21 deletions(-) diff --git a/src/main/java/graphql/execution/instrumentation/dataloader/DataLoaderCF.java b/src/main/java/graphql/execution/instrumentation/dataloader/DataLoaderCF.java index e72090c7e0..674168e2fc 100644 --- a/src/main/java/graphql/execution/instrumentation/dataloader/DataLoaderCF.java +++ b/src/main/java/graphql/execution/instrumentation/dataloader/DataLoaderCF.java @@ -7,7 +7,6 @@ import graphql.schema.DataFetchingEnvironment; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CountDownLatch; import java.util.function.Supplier; import static graphql.execution.Execution.EXECUTION_CONTEXT_KEY; @@ -19,7 +18,7 @@ public class DataLoaderCF extends CompletableFuture { final Object key; final CompletableFuture dataLoaderCF; - volatile CountDownLatch latch; + final CompletableFuture finishedSyncDependents = new CompletableFuture(); public DataLoaderCF(DataFetchingEnvironment dfe, String dataLoaderName, Object key) { this.dfe = dfe; @@ -35,9 +34,7 @@ public DataLoaderCF(DataFetchingEnvironment dfe, String dataLoaderName, Object k complete((T) value); } // post completion hook - if (latch != null) { - latch.countDown(); - } + finishedSyncDependents.complete(null); }); } else { dataLoaderCF = null; diff --git a/src/main/java/graphql/execution/instrumentation/dataloader/PerLevelDataLoaderDispatchStrategy.java b/src/main/java/graphql/execution/instrumentation/dataloader/PerLevelDataLoaderDispatchStrategy.java index deb00a0b79..44a8ae4896 100644 --- a/src/main/java/graphql/execution/instrumentation/dataloader/PerLevelDataLoaderDispatchStrategy.java +++ b/src/main/java/graphql/execution/instrumentation/dataloader/PerLevelDataLoaderDispatchStrategy.java @@ -16,9 +16,9 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArrayList; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -323,9 +323,11 @@ public void dispatchDLCFImpl(Set dfeToDispatchSet) { // filter out all DataLoaderCFS that are matching the fields we want to dispatch List> relevantDataLoaderCFs = new ArrayList<>(); + List> finishedSyncDependentsCFs = new ArrayList<>(); for (DataLoaderCF dataLoaderCF : callStack.allDataLoaderCF) { if (dfeToDispatchSet.contains(dataLoaderCF.dfe)) { relevantDataLoaderCFs.add(dataLoaderCF); + finishedSyncDependentsCFs.add(dataLoaderCF.finishedSyncDependents); } } // we are cleaning up the list of all DataLoadersCFs @@ -338,21 +340,12 @@ public void dispatchDLCFImpl(Set dfeToDispatchSet) { } // we are dispatching all data loaders and waiting for all dataLoaderCFs to complete // and to finish their sync actions - CountDownLatch countDownLatch = new CountDownLatch(relevantDataLoaderCFs.size()); - for (DataLoaderCF dlCF : relevantDataLoaderCFs) { - dlCF.latch = countDownLatch; - } - // TODO: this should be done async or in a more regulated way with a configurable thread pool or so - new Thread(() -> { - try { - // waiting until all sync codes for all DL CFs are run - countDownLatch.await(); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - // now we handle all new DataLoaders - dispatchDLCFImpl(dfeToDispatchSet); - }).start(); + + CompletableFuture + .allOf(finishedSyncDependentsCFs.toArray(new CompletableFuture[0])) + .whenComplete((unused, throwable) -> + dispatchDLCFImpl(dfeToDispatchSet) + ); // Only dispatching relevant data loaders for (DataLoaderCF dlCF : relevantDataLoaderCFs) { dlCF.dfe.getDataLoader(dlCF.dataLoaderName).dispatch(); From dc195320ca5fb1021a9388def3cd8bf4df61ea95 Mon Sep 17 00:00:00 2001 From: Andreas Marek Date: Sun, 30 Mar 2025 22:22:41 +1000 Subject: [PATCH 3/3] only use window batching for DataLoaderCFs --- .../AbstractAsyncExecutionStrategy.java | 4 +- .../execution/AsyncExecutionStrategy.java | 8 +++ .../graphql/execution/ExecutionContext.java | 15 ++++++ .../graphql/execution/ExecutionStrategy.java | 25 +++++++++- .../PerLevelDataLoaderDispatchStrategy.java | 50 ++++++++++++------- .../groovy/graphql/DataLoaderCFTest.groovy | 1 + 6 files changed, 83 insertions(+), 20 deletions(-) diff --git a/src/main/java/graphql/execution/AbstractAsyncExecutionStrategy.java b/src/main/java/graphql/execution/AbstractAsyncExecutionStrategy.java index 863e0d6fad..3b1490d3c2 100644 --- a/src/main/java/graphql/execution/AbstractAsyncExecutionStrategy.java +++ b/src/main/java/graphql/execution/AbstractAsyncExecutionStrategy.java @@ -5,7 +5,6 @@ import graphql.ExecutionResultImpl; import graphql.PublicSpi; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; @@ -24,8 +23,10 @@ public AbstractAsyncExecutionStrategy(DataFetcherExceptionHandler dataFetcherExc protected BiConsumer, Throwable> handleResults(ExecutionContext executionContext, List fieldNames, CompletableFuture overallResult) { return (List results, Throwable exception) -> { + executionContext.running(); if (exception != null) { handleNonNullException(executionContext, overallResult, exception); + executionContext.finished(); return; } Map resolvedValuesByField = Maps.newLinkedHashMapWithExpectedSize(fieldNames.size()); @@ -35,6 +36,7 @@ protected BiConsumer, Throwable> handleResults(ExecutionContext exe resolvedValuesByField.put(fieldName, result); } overallResult.complete(new ExecutionResultImpl(resolvedValuesByField, executionContext.getErrors())); + executionContext.finished(); }; } } diff --git a/src/main/java/graphql/execution/AsyncExecutionStrategy.java b/src/main/java/graphql/execution/AsyncExecutionStrategy.java index bbd4a9cf68..f665938404 100644 --- a/src/main/java/graphql/execution/AsyncExecutionStrategy.java +++ b/src/main/java/graphql/execution/AsyncExecutionStrategy.java @@ -38,6 +38,7 @@ public AsyncExecutionStrategy(DataFetcherExceptionHandler exceptionHandler) { @Override @SuppressWarnings("FutureReturnValueIgnored") public CompletableFuture execute(ExecutionContext executionContext, ExecutionStrategyParameters parameters) throws NonNullableFieldWasNullException { + executionContext.running(); DataLoaderDispatchStrategy dataLoaderDispatcherStrategy = executionContext.getDataLoaderDispatcherStrategy(); dataLoaderDispatcherStrategy.executionStrategy(executionContext, parameters); Instrumentation instrumentation = executionContext.getInstrumentation(); @@ -50,6 +51,7 @@ public CompletableFuture execute(ExecutionContext executionCont Optional isNotSensible = Introspection.isIntrospectionSensible(fields, executionContext); if (isNotSensible.isPresent()) { + executionContext.finished(); return CompletableFuture.completedFuture(isNotSensible.get()); } @@ -60,11 +62,13 @@ public CompletableFuture execute(ExecutionContext executionCont executionStrategyCtx.onDispatched(); futures.await().whenComplete((completeValueInfos, throwable) -> { + executionContext.running(); List fieldsExecutedOnInitialResult = deferredExecutionSupport.getNonDeferredFieldNames(fieldNames); BiConsumer, Throwable> handleResultsConsumer = handleResults(executionContext, fieldsExecutedOnInitialResult, overallResult); if (throwable != null) { handleResultsConsumer.accept(null, throwable.getCause()); + executionContext.finished(); return; } @@ -75,17 +79,21 @@ public CompletableFuture execute(ExecutionContext executionCont dataLoaderDispatcherStrategy.executionStrategyOnFieldValuesInfo(completeValueInfos); executionStrategyCtx.onFieldValuesInfo(completeValueInfos); fieldValuesFutures.await().whenComplete(handleResultsConsumer); + executionContext.finished(); }).exceptionally((ex) -> { + executionContext.running(); // if there are any issues with combining/handling the field results, // complete the future at all costs and bubble up any thrown exception so // the execution does not hang. dataLoaderDispatcherStrategy.executionStrategyOnFieldValuesException(ex); executionStrategyCtx.onFieldValuesException(); overallResult.completeExceptionally(ex); + executionContext.finished(); return null; }); overallResult.whenComplete(executionStrategyCtx::onCompleted); + executionContext.finished(); return overallResult; } } diff --git a/src/main/java/graphql/execution/ExecutionContext.java b/src/main/java/graphql/execution/ExecutionContext.java index b166b19025..f90b03114a 100644 --- a/src/main/java/graphql/execution/ExecutionContext.java +++ b/src/main/java/graphql/execution/ExecutionContext.java @@ -28,6 +28,7 @@ import java.util.Locale; import java.util.Map; import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Supplier; @@ -63,6 +64,8 @@ public class ExecutionContext { private final Supplier queryTree; private final boolean propagateErrorsOnNonNullContractFailure; + private final AtomicInteger isRunning = new AtomicInteger(0); + // this is modified after creation so it needs to be volatile to ensure visibility across Threads private volatile DataLoaderDispatchStrategy dataLoaderDispatcherStrategy = DataLoaderDispatchStrategy.NO_OP; @@ -349,4 +352,16 @@ public ExecutionContext transform(Consumer builderConsu public ResultNodesInfo getResultNodesInfo() { return resultNodesInfo; } + + public boolean isRunning() { + return isRunning.get() > 0; + } + + public void running() { + isRunning.incrementAndGet(); + } + + public void finished() { + isRunning.decrementAndGet(); + } } diff --git a/src/main/java/graphql/execution/ExecutionStrategy.java b/src/main/java/graphql/execution/ExecutionStrategy.java index d647f650d1..5e94cf3c9e 100644 --- a/src/main/java/graphql/execution/ExecutionStrategy.java +++ b/src/main/java/graphql/execution/ExecutionStrategy.java @@ -201,6 +201,7 @@ public static String mkNameForPath(List currentField) { @SuppressWarnings("unchecked") @DuckTyped(shape = "CompletableFuture> | Map") protected Object executeObject(ExecutionContext executionContext, ExecutionStrategyParameters parameters) throws NonNullableFieldWasNullException { + executionContext.running(); DataLoaderDispatchStrategy dataLoaderDispatcherStrategy = executionContext.getDataLoaderDispatcherStrategy(); dataLoaderDispatcherStrategy.executeObject(executionContext, parameters); Instrumentation instrumentation = executionContext.getInstrumentation(); @@ -225,8 +226,10 @@ protected Object executeObject(ExecutionContext executionContext, ExecutionStrat if (fieldValueInfosResult instanceof CompletableFuture) { CompletableFuture> fieldValueInfos = (CompletableFuture>) fieldValueInfosResult; fieldValueInfos.whenComplete((completeValueInfos, throwable) -> { + executionContext.running(); if (throwable != null) { handleResultsConsumer.accept(null, throwable); + executionContext.finished(); return; } @@ -234,6 +237,7 @@ protected Object executeObject(ExecutionContext executionContext, ExecutionStrat dataLoaderDispatcherStrategy.executeObjectOnFieldValuesInfo(completeValueInfos, parameters); resolveObjectCtx.onFieldValuesInfo(completeValueInfos); resultFutures.await().whenComplete(handleResultsConsumer); + executionContext.finished(); }).exceptionally((ex) -> { // if there are any issues with combining/handling the field results, // complete the future at all costs and bubble up any thrown exception so @@ -244,6 +248,7 @@ protected Object executeObject(ExecutionContext executionContext, ExecutionStrat return null; }); overallResult.whenComplete(resolveObjectCtx::onCompleted); + executionContext.finished(); return overallResult; } else { List completeValueInfos = (List) fieldValueInfosResult; @@ -257,10 +262,12 @@ protected Object executeObject(ExecutionContext executionContext, ExecutionStrat CompletableFuture> completedValues = (CompletableFuture>) completedValuesObject; completedValues.whenComplete(handleResultsConsumer); overallResult.whenComplete(resolveObjectCtx::onCompleted); + executionContext.finished(); return overallResult; } else { Map fieldValueMap = buildFieldValueMap(fieldsExecutedOnInitialResult, (List) completedValuesObject); resolveObjectCtx.onCompleted(fieldValueMap, null); + executionContext.finished(); return fieldValueMap; } } @@ -276,12 +283,15 @@ protected Object executeObject(ExecutionContext executionContext, ExecutionStrat private BiConsumer, Throwable> buildFieldValueMap(List fieldNames, CompletableFuture> overallResult, ExecutionContext executionContext) { return (List results, Throwable exception) -> { + executionContext.running(); if (exception != null) { handleValueException(overallResult, exception, executionContext); + executionContext.finished(); return; } Map resolvedValuesByField = buildFieldValueMap(fieldNames, results); overallResult.complete(resolvedValuesByField); + executionContext.finished(); }; } @@ -509,11 +519,15 @@ private Object fetchField(GraphQLFieldDefinition fieldDef, ExecutionContext exec CompletableFuture fetchedValue = (CompletableFuture) fetchedObject; return fetchedValue .handle((result, exception) -> { + executionContext.running(); fetchCtx.onCompleted(result, exception); if (exception != null) { - return handleFetchingException(dataFetchingEnvironment.get(), parameters, exception); + CompletableFuture handleFetchingExceptionResult = handleFetchingException(dataFetchingEnvironment.get(), parameters, exception); + executionContext.finished(); + return handleFetchingExceptionResult; } else { // we can simply return the fetched value CF and avoid a allocation + executionContext.finished(); return fetchedValue; } }) @@ -553,7 +567,7 @@ protected Supplier getNormalizedField(ExecutionContex protected FetchedValue unboxPossibleDataFetcherResult(ExecutionContext executionContext, ExecutionStrategyParameters parameters, Object result) { - + executionContext.running(); if (result instanceof DataFetcherResult) { DataFetcherResult dataFetcherResult = (DataFetcherResult) result; @@ -567,9 +581,11 @@ protected FetchedValue unboxPossibleDataFetcherResult(ExecutionContext execution localContext = parameters.getLocalContext(); } Object unBoxedValue = executionContext.getValueUnboxer().unbox(dataFetcherResult.getData()); + executionContext.finished(); return new FetchedValue(unBoxedValue, dataFetcherResult.getErrors(), localContext); } else { Object unBoxedValue = executionContext.getValueUnboxer().unbox(result); + executionContext.finished(); return new FetchedValue(unBoxedValue, ImmutableList.of(), parameters.getLocalContext()); } } @@ -638,6 +654,7 @@ protected FieldValueInfo completeField(ExecutionContext executionContext, Execut } private FieldValueInfo completeField(GraphQLFieldDefinition fieldDef, ExecutionContext executionContext, ExecutionStrategyParameters parameters, FetchedValue fetchedValue) { + executionContext.running(); GraphQLObjectType parentType = (GraphQLObjectType) parameters.getExecutionStepInfo().getUnwrappedNonNullType(); ExecutionStepInfo executionStepInfo = createExecutionStepInfo(executionContext, parameters, fieldDef, parentType); @@ -661,6 +678,7 @@ private FieldValueInfo completeField(GraphQLFieldDefinition fieldDef, ExecutionC CompletableFuture executionResultFuture = fieldValueInfo.getFieldValueFuture(); ctxCompleteField.onDispatched(); executionResultFuture.whenComplete(ctxCompleteField::onCompleted); + executionContext.finished(); return fieldValueInfo; } @@ -833,13 +851,16 @@ protected FieldValueInfo completeValueForList(ExecutionContext executionContext, overallResult.whenComplete(completeListCtx::onCompleted); resultsFuture.whenComplete((results, exception) -> { + executionContext.running(); if (exception != null) { + executionContext.finished(); handleValueException(overallResult, exception, executionContext); return; } List completedResults = new ArrayList<>(results.size()); completedResults.addAll(results); overallResult.complete(completedResults); + executionContext.finished(); }); listOrPromiseToList = overallResult; } else { diff --git a/src/main/java/graphql/execution/instrumentation/dataloader/PerLevelDataLoaderDispatchStrategy.java b/src/main/java/graphql/execution/instrumentation/dataloader/PerLevelDataLoaderDispatchStrategy.java index 44a8ae4896..a1c85a3301 100644 --- a/src/main/java/graphql/execution/instrumentation/dataloader/PerLevelDataLoaderDispatchStrategy.java +++ b/src/main/java/graphql/execution/instrumentation/dataloader/PerLevelDataLoaderDispatchStrategy.java @@ -3,6 +3,7 @@ import graphql.Assert; import graphql.Internal; import graphql.execution.DataLoaderDispatchStrategy; +import graphql.execution.Execution; import graphql.execution.ExecutionContext; import graphql.execution.ExecutionStrategyParameters; import graphql.execution.FieldValueInfo; @@ -32,7 +33,7 @@ public class PerLevelDataLoaderDispatchStrategy implements DataLoaderDispatchStr private final ExecutionContext executionContext; static final ScheduledExecutorService isolatedDLCFBatchWindowScheduler = Executors.newSingleThreadScheduledExecutor(); - static final int BATCH_WINDOW_NANO_SECONDS = 500_000; + static final int BATCH_WINDOW_NANO_SECONDS = 100_000; private static class CallStack { @@ -310,9 +311,8 @@ private boolean levelReady(int level) { } void dispatch(int level) { - if (callStack.levelToDFEWithDataLoaderCF.size() > 0) { - dispatchDLCFImpl(callStack.levelToDFEWithDataLoaderCF.get(level)); - } else { + // only dispatch if we don't use any DataLoaderCFs + if (callStack.levelToDFEWithDataLoaderCF.size() == 0) { DataLoaderRegistry dataLoaderRegistry = executionContext.getDataLoaderRegistry(); dataLoaderRegistry.dispatchAll(); } @@ -359,28 +359,44 @@ public void newDataLoaderCF(DataLoaderCF dataLoaderCF) { callStack.lock.runLocked(() -> { callStack.allDataLoaderCF.add(dataLoaderCF); }); - if (callStack.fieldsFinishedDispatching.contains(dataLoaderCF.dfe)) { - System.out.println("isolated dispatch"); - dispatchIsolatedDataLoader(dataLoaderCF); +// if (callStack.fieldsFinishedDispatching.contains(dataLoaderCF.dfe)) { + System.out.println("isolated dispatch"); + dispatchIsolatedDataLoader(dataLoaderCF); +// } + + } + + class TriggerDispatch implements Runnable { + + final ExecutionContext executionContext; + + TriggerDispatch(ExecutionContext executionContext) { + this.executionContext = executionContext; } + @Override + public void run() { + if (executionContext.isRunning()) { + isolatedDLCFBatchWindowScheduler.schedule(this, BATCH_WINDOW_NANO_SECONDS, TimeUnit.NANOSECONDS); + return; + } + AtomicReference> dfesToDispatch = new AtomicReference<>(); + callStack.lock.runLocked(() -> { + dfesToDispatch.set(new LinkedHashSet<>(callStack.batchWindowOfIsolatedDfeToDispatch)); + callStack.batchWindowOfIsolatedDfeToDispatch.clear(); + callStack.batchWindowOpen = false; + }); + dispatchDLCFImpl(dfesToDispatch.get()); + } } private void dispatchIsolatedDataLoader(DataLoaderCF dlCF) { callStack.lock.runLocked(() -> { callStack.batchWindowOfIsolatedDfeToDispatch.add(dlCF.dfe); + ExecutionContext executionContext = dlCF.dfe.getGraphQlContext().get(Execution.EXECUTION_CONTEXT_KEY); if (!callStack.batchWindowOpen) { callStack.batchWindowOpen = true; - AtomicReference> dfesToDispatch = new AtomicReference<>(); - Runnable runnable = () -> { - callStack.lock.runLocked(() -> { - dfesToDispatch.set(new LinkedHashSet<>(callStack.batchWindowOfIsolatedDfeToDispatch)); - callStack.batchWindowOfIsolatedDfeToDispatch.clear(); - callStack.batchWindowOpen = false; - }); - dispatchDLCFImpl(dfesToDispatch.get()); - }; - isolatedDLCFBatchWindowScheduler.schedule(runnable, BATCH_WINDOW_NANO_SECONDS, TimeUnit.NANOSECONDS); + isolatedDLCFBatchWindowScheduler.schedule(new TriggerDispatch(executionContext), BATCH_WINDOW_NANO_SECONDS, TimeUnit.NANOSECONDS); } }); diff --git a/src/test/groovy/graphql/DataLoaderCFTest.groovy b/src/test/groovy/graphql/DataLoaderCFTest.groovy index f799202a5b..a5e39d701b 100644 --- a/src/test/groovy/graphql/DataLoaderCFTest.groovy +++ b/src/test/groovy/graphql/DataLoaderCFTest.groovy @@ -1,5 +1,6 @@ package graphql + import graphql.execution.instrumentation.dataloader.DataLoaderCF import graphql.schema.DataFetcher import org.dataloader.BatchLoader