diff --git a/src/main/java/graphql/execution/ExecutionStrategy.java b/src/main/java/graphql/execution/ExecutionStrategy.java index 06d3b644b..78ad4280b 100644 --- a/src/main/java/graphql/execution/ExecutionStrategy.java +++ b/src/main/java/graphql/execution/ExecutionStrategy.java @@ -489,12 +489,17 @@ private Object fetchField(GraphQLFieldDefinition fieldDef, ExecutionContext exec CompletableFuture> handleCF = engineRunningState.handle(fetchedValue, (result, exception) -> { // because we added an artificial CF, we need to unwrap the exception - fetchCtx.onCompleted(result, exception); - exception = engineRunningState.possibleCancellation(exception); - - if (exception != null) { - return handleFetchingException(dataFetchingEnvironment.get(), parameters, exception); + Throwable possibleWrappedException = engineRunningState.possibleCancellation(exception); + + if (possibleWrappedException != null) { + CompletableFuture> handledExceptionResult = handleFetchingException(dataFetchingEnvironment.get(), parameters, possibleWrappedException); + return handledExceptionResult.thenApply( handledResult -> { + fetchCtx.onExceptionHandled(handledResult); + fetchCtx.onCompleted(result, exception); + return handledResult; + }); } else { + fetchCtx.onCompleted(result, exception); // we can simply return the fetched value CF and avoid a allocation return fetchedValue; } @@ -578,7 +583,7 @@ private void addExtensionsIfPresent(ExecutionContext executionContext, DataFetch } } - protected CompletableFuture handleFetchingException( + protected CompletableFuture> handleFetchingException( DataFetchingEnvironment environment, ExecutionStrategyParameters parameters, Throwable e @@ -599,10 +604,10 @@ protected CompletableFuture handleFetchingException( } } - private CompletableFuture asyncHandleException(DataFetcherExceptionHandler handler, DataFetcherExceptionHandlerParameters handlerParameters) { + private CompletableFuture> asyncHandleException(DataFetcherExceptionHandler handler, DataFetcherExceptionHandlerParameters handlerParameters) { //noinspection unchecked return handler.handleException(handlerParameters).thenApply( - handlerResult -> (T) DataFetcherResult.newResult().errors(handlerResult.getErrors()).build() + handlerResult -> (DataFetcherResult) DataFetcherResult.newResult().errors(handlerResult.getErrors()).build() ); } diff --git a/src/main/java/graphql/execution/instrumentation/ChainedInstrumentation.java b/src/main/java/graphql/execution/instrumentation/ChainedInstrumentation.java index 2779a10a0..4ae65742b 100644 --- a/src/main/java/graphql/execution/instrumentation/ChainedInstrumentation.java +++ b/src/main/java/graphql/execution/instrumentation/ChainedInstrumentation.java @@ -6,6 +6,7 @@ import graphql.ExperimentalApi; import graphql.PublicApi; import graphql.execution.Async; +import graphql.execution.DataFetcherResult; import graphql.execution.ExecutionContext; import graphql.execution.FieldValueInfo; import graphql.execution.instrumentation.parameters.InstrumentationCreateStateParameters; @@ -380,6 +381,11 @@ public void onFetchedValue(Object fetchedValue) { contexts.forEach(context -> context.onFetchedValue(fetchedValue)); } + @Override + public void onExceptionHandled(DataFetcherResult dataFetcherResult) { + contexts.forEach(context -> context.onExceptionHandled(dataFetcherResult)); + } + @Override public void onCompleted(Object result, Throwable t) { contexts.forEach(context -> context.onCompleted(result, t)); diff --git a/src/main/java/graphql/execution/instrumentation/FieldFetchingInstrumentationContext.java b/src/main/java/graphql/execution/instrumentation/FieldFetchingInstrumentationContext.java index 38984c6f9..f6ff09bec 100644 --- a/src/main/java/graphql/execution/instrumentation/FieldFetchingInstrumentationContext.java +++ b/src/main/java/graphql/execution/instrumentation/FieldFetchingInstrumentationContext.java @@ -2,6 +2,7 @@ import graphql.Internal; import graphql.PublicSpi; +import graphql.execution.DataFetcherResult; import graphql.execution.instrumentation.parameters.InstrumentationFieldFetchParameters; import org.jspecify.annotations.NonNull; import org.jspecify.annotations.Nullable; @@ -26,6 +27,15 @@ public interface FieldFetchingInstrumentationContext extends InstrumentationCont default void onFetchedValue(Object fetchedValue) { } + /** + * This is called back after any {@link graphql.execution.DataFetcherExceptionHandler}) has run on any exception raised + * during a {@link graphql.schema.DataFetcher} invocation. This allows to see the final {@link DataFetcherResult} + * that will be used when performing the complete step. + * @param dataFetcherResult the final {@link DataFetcherResult} after the exception handler has run + */ + default void onExceptionHandled(DataFetcherResult dataFetcherResult) { + } + @Internal FieldFetchingInstrumentationContext NOOP = new FieldFetchingInstrumentationContext() { @Override diff --git a/src/test/groovy/graphql/execution/instrumentation/ChainedInstrumentationStateTest.groovy b/src/test/groovy/graphql/execution/instrumentation/ChainedInstrumentationStateTest.groovy index 88d7c8653..a0c67d684 100644 --- a/src/test/groovy/graphql/execution/instrumentation/ChainedInstrumentationStateTest.groovy +++ b/src/test/groovy/graphql/execution/instrumentation/ChainedInstrumentationStateTest.groovy @@ -8,6 +8,8 @@ import graphql.execution.AsyncExecutionStrategy import graphql.execution.instrumentation.parameters.InstrumentationCreateStateParameters import graphql.execution.instrumentation.parameters.InstrumentationExecutionParameters import graphql.execution.instrumentation.parameters.InstrumentationValidationParameters +import graphql.schema.DataFetcher +import graphql.schema.DataFetchingEnvironment import graphql.validation.ValidationError import spock.lang.Specification @@ -227,6 +229,78 @@ class ChainedInstrumentationStateTest extends Specification { assertCalls(c) } + def "basic chaining and state management when exception raised in data fetching"() { + + def a = new NamedInstrumentation("A") + def b = new NamedInstrumentation("B") + def c = new NamedInstrumentation("C") + def nullState = new SimplePerformantInstrumentation() + + def chainedInstrumentation = new ChainedInstrumentation([ + a, + b, + nullState, + c, + ]) + + def query = """ + query HeroNameAndFriendsQuery { + hero { + id + } + } + """ + + def expected = "onExceptionHandled:fetch-id" + + + when: + def strategy = new AsyncExecutionStrategy() + def schema = StarWarsSchema.starWarsSchema + def graphQL = GraphQL + .newGraphQL(schema.transform { schemaBuilder -> + // throw exception when fetching the hero id + def exceptionDataFetcher = new DataFetcher() { + @Override + Object get(DataFetchingEnvironment environment) { + throw new RuntimeException("Data fetcher exception") + } + } + schemaBuilder.codeRegistry(schema.codeRegistry.transform { + it.dataFetcher( + schema.getObjectType("Human"), + schema.getObjectType("Human").getFieldDefinition("id"), + exceptionDataFetcher + ) + it.dataFetcher( + schema.getObjectType("Droid"), + schema.getObjectType("Droid").getFieldDefinition("id"), + exceptionDataFetcher + ) + return it + } + ) + }) + .queryExecutionStrategy(strategy) + .instrumentation(chainedInstrumentation) + .build() + + graphQL.execute(query) + + then: + + chainedInstrumentation.getInstrumentations().size() == 4 + + a.executionList.any { it == expected } + b.executionList.any { it == expected } + c.executionList.any { it == expected } + + assertCalls(a) + assertCalls(b) + assertCalls(c) + + } + def "empty chain"() { def chainedInstrumentation = new ChainedInstrumentation(Arrays.asList()) diff --git a/src/test/groovy/graphql/execution/instrumentation/InstrumentationTest.groovy b/src/test/groovy/graphql/execution/instrumentation/InstrumentationTest.groovy index 43767d934..ca52f8cf8 100644 --- a/src/test/groovy/graphql/execution/instrumentation/InstrumentationTest.groovy +++ b/src/test/groovy/graphql/execution/instrumentation/InstrumentationTest.groovy @@ -1,11 +1,16 @@ package graphql.execution.instrumentation +import graphql.ErrorType import graphql.ExecutionInput import graphql.ExecutionResult import graphql.GraphQL +import graphql.GraphqlErrorBuilder +import graphql.GraphqlErrorBuilderTest import graphql.StarWarsSchema import graphql.TestUtil import graphql.execution.AsyncExecutionStrategy +import graphql.execution.DataFetcherExceptionHandlerResult +import graphql.execution.DataFetcherResult import graphql.execution.instrumentation.parameters.InstrumentationCreateStateParameters import graphql.execution.instrumentation.parameters.InstrumentationExecutionParameters import graphql.execution.instrumentation.parameters.InstrumentationExecutionStrategyParameters @@ -23,6 +28,8 @@ import spock.lang.Specification import java.util.concurrent.CompletableFuture import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.AtomicLong class InstrumentationTest extends Specification { @@ -152,6 +159,165 @@ class InstrumentationTest extends Specification { instrumentation.throwableList[0].getMessage() == "DF BANG!" } + def "field fetch will instrument exceptions correctly - includes exception handling with onExceptionHandled"() { + + given: + + def query = """ + { + hero { + id + } + } + """ + + def instrumentation = new LegacyTestingInstrumentation() { + def onHandledCalled = false + def onCompletedCalled = false + def onDispatchedCalled = false + + @Override + DataFetcher instrumentDataFetcher(DataFetcher dataFetcher, InstrumentationFieldFetchParameters parameters, InstrumentationState state) { + return new DataFetcher() { + @Override + Object get(DataFetchingEnvironment environment) { + throw new RuntimeException("DF BANG!") + } + } + } + + @Override + FieldFetchingInstrumentationContext beginFieldFetching(InstrumentationFieldFetchParameters parameters, InstrumentationState state) { + return new FieldFetchingInstrumentationContext() { + @Override + void onDispatched() { + onDispatchedCalled = true + } + + @Override + void onCompleted(Object result, Throwable t) { + onCompletedCalled = true + } + + @Override + void onExceptionHandled(DataFetcherResult dataFetcherResult) { + onHandledCalled = true + } + } + } + } + + def graphQL = GraphQL + .newGraphQL(StarWarsSchema.starWarsSchema) + .defaultDataFetcherExceptionHandler { it -> + // catch all exceptions and transform to graphql error with a prefixed message + return CompletableFuture.completedFuture( + DataFetcherExceptionHandlerResult.newResult(GraphqlErrorBuilder.newError() + .errorType(ErrorType.DataFetchingException) + .message("Handled " + it.exception.message) + .path(it.path) + .build()) + .build()) + } + .instrumentation(instrumentation) + .build() + + when: + def resp = graphQL.execute(query) + + then: "exception handler turned the exception into a graphql error and message prefixed with Handled" + resp.errors.size() == 1 + resp.errors[0].message == "Handled DF BANG!" + + and: "all instrumentation methods were called" + instrumentation.onDispatchedCalled == true + instrumentation.onCompletedCalled == true + instrumentation.onHandledCalled == true + } + + + def "field fetch verify order and call of all methods"() { + + given: + + def query = """ + { + hero { + id + } + } + """ + + def metric = [] + def instrumentation = new SimplePerformantInstrumentation() { + def timeElapsed = new AtomicInteger() + + @Override + DataFetcher instrumentDataFetcher(DataFetcher dataFetcher, InstrumentationFieldFetchParameters parameters, InstrumentationState state) { + return new DataFetcher() { + @Override + Object get(DataFetchingEnvironment environment) { + // simulate latency + timeElapsed.addAndGet(50) + throw new RuntimeException("DF BANG!") + } + } + } + + @Override + FieldFetchingInstrumentationContext beginFieldFetching(InstrumentationFieldFetchParameters parameters, InstrumentationState state) { + return new FieldFetchingInstrumentationContext() { + def start = 0 + def duration = 0 + def hasError = false + + + @Override + void onDispatched() { + start = 1 + } + + @Override + void onCompleted(Object result, Throwable t) { + duration = timeElapsed.get() - start + metric = [duration, hasError] + } + + @Override + void onExceptionHandled(DataFetcherResult dataFetcherResult) { + hasError = dataFetcherResult.errors != null && !dataFetcherResult.errors.isEmpty() + && dataFetcherResult.errors.any { it.message.contains("Handled") } + } + } + } + } + + def graphQL = GraphQL + .newGraphQL(StarWarsSchema.starWarsSchema) + .defaultDataFetcherExceptionHandler { it -> + // catch all exceptions and transform to graphql error with a prefixed message + return CompletableFuture.completedFuture( + DataFetcherExceptionHandlerResult.newResult(GraphqlErrorBuilder.newError() + .errorType(ErrorType.DataFetchingException) + .message("Handled " + it.exception.message) + .path(it.path) + .build()) + .build()) + } + .instrumentation(instrumentation) + .build() + + when: + def resp = graphQL.execute(query) + + then: "exception handler turned the exception into a graphql error and prefixed its message with 'Handled'" + resp.errors.size() == 1 + resp.errors[0].message == "Handled DF BANG!" + + and: "metric was captured i.e all instrumentation methods were called in the right order" + metric == [49, true] + } + /** * This uses a stop and go pattern and multiple threads. Each time * the execution strategy is invoked, the data fetchers are held diff --git a/src/test/groovy/graphql/execution/instrumentation/TestingFieldFetchingInstrumentationContext.groovy b/src/test/groovy/graphql/execution/instrumentation/TestingFieldFetchingInstrumentationContext.groovy index 50fcaccd2..c9e19ba63 100644 --- a/src/test/groovy/graphql/execution/instrumentation/TestingFieldFetchingInstrumentationContext.groovy +++ b/src/test/groovy/graphql/execution/instrumentation/TestingFieldFetchingInstrumentationContext.groovy @@ -1,9 +1,16 @@ package graphql.execution.instrumentation +import graphql.execution.DataFetcherResult + class TestingFieldFetchingInstrumentationContext extends TestingInstrumentContext implements FieldFetchingInstrumentationContext { TestingFieldFetchingInstrumentationContext(Object op, Object executionList, Object throwableList, Boolean useOnDispatch) { super(op, executionList, throwableList, useOnDispatch) } + + @Override + void onExceptionHandled(DataFetcherResult dataFetcherResult) { + executionList << "onExceptionHandled:$op" + } } diff --git a/src/test/groovy/graphql/execution/instrumentation/TestingInstrumentContext.groovy b/src/test/groovy/graphql/execution/instrumentation/TestingInstrumentContext.groovy index 402fd2aee..3e627b368 100644 --- a/src/test/groovy/graphql/execution/instrumentation/TestingInstrumentContext.groovy +++ b/src/test/groovy/graphql/execution/instrumentation/TestingInstrumentContext.groovy @@ -1,7 +1,5 @@ package graphql.execution.instrumentation -import java.util.concurrent.CompletableFuture - class TestingInstrumentContext implements InstrumentationContext { def op def start = System.currentTimeMillis()