diff --git a/src/transformation/visitors/function.ts b/src/transformation/visitors/function.ts index 1b1735324..0061ea749 100644 --- a/src/transformation/visitors/function.ts +++ b/src/transformation/visitors/function.ts @@ -115,7 +115,10 @@ export function transformFunctionBody( ): [lua.Statement[], Scope] { const scope = pushScope(context, ScopeType.Function); scope.node = node; - const bodyStatements = transformFunctionBodyContent(context, body); + let bodyStatements = transformFunctionBodyContent(context, body); + if (node && isAsyncFunction(node)) { + bodyStatements = wrapInAsyncAwaiter(context, bodyStatements); + } const headerStatements = transformFunctionBodyHeader(context, scope, parameters, spreadIdentifier); popScope(context); return [[...headerStatements, ...bodyStatements], scope]; @@ -197,10 +200,8 @@ export function transformFunctionToExpression( node ); - const possiblyAsyncBody = isAsyncFunction(node) ? wrapInAsyncAwaiter(context, transformedBody) : transformedBody; - const functionExpression = lua.createFunctionExpression( - lua.createBlock(possiblyAsyncBody), + lua.createBlock(transformedBody), paramNames, dotsLiteral, flags, diff --git a/src/transformation/visitors/spread.ts b/src/transformation/visitors/spread.ts index 6e6ec391a..f8ebdc79d 100644 --- a/src/transformation/visitors/spread.ts +++ b/src/transformation/visitors/spread.ts @@ -37,6 +37,11 @@ export function isOptimizedVarArgSpread(context: TransformationContext, symbol: return false; } + // Scope cannot be an async function + if (scope.node.modifiers?.some(m => m.kind === ts.SyntaxKind.AsyncKeyword)) { + return false; + } + // Identifier must be a vararg in the local function scope's parameters const isSpreadParameter = (p: ts.ParameterDeclaration) => p.dotDotDotToken && ts.isIdentifier(p.name) && context.checker.getSymbolAtLocation(p.name) === symbol; diff --git a/test/unit/builtins/async-await.spec.ts b/test/unit/builtins/async-await.spec.ts index a396c4aed..0a9c3513e 100644 --- a/test/unit/builtins/async-await.spec.ts +++ b/test/unit/builtins/async-await.spec.ts @@ -333,3 +333,39 @@ test.each([ .setOptions({ module: ModuleKind.ESNext, target: ScriptTarget.ES2017 }) .expectToHaveDiagnostics([awaitMustBeInAsyncFunction.code]); }); + +test("async function can access varargs", () => { + util.testFunction` + const { promise, resolve } = defer(); + + async function a(...args: string[]) { + log(await promise); + log(args[1]); + } + + const awaitingPromise = a("A", "B", "C"); + resolve("resolved"); + + return allLogs; + ` + .setTsHeader(promiseTestLib) + .expectToEqual(["resolved", "B"]); +}); + +test("async function can forward varargs", () => { + util.testFunction` + const { promise, resolve } = defer(); + + async function a(...args: string[]) { + log(await promise); + log(...args); + } + + const awaitingPromise = a("A", "B", "C"); + resolve("resolved"); + + return allLogs; + ` + .setTsHeader(promiseTestLib) + .expectToEqual(["resolved", "A", "B", "C"]); +});