diff --git a/src/LuaLib.ts b/src/LuaLib.ts index 669c68941..9c50019f5 100644 --- a/src/LuaLib.ts +++ b/src/LuaLib.ts @@ -28,6 +28,7 @@ export enum LuaLibFeature { ArrayFlat = "ArrayFlat", ArrayFlatMap = "ArrayFlatMap", ArraySetLength = "ArraySetLength", + Await = "Await", Class = "Class", ClassExtends = "ClassExtends", CloneDescriptor = "CloneDescriptor", @@ -101,6 +102,7 @@ const luaLibDependencies: Partial> = { ArrayConcat: [LuaLibFeature.ArrayIsArray], ArrayFlat: [LuaLibFeature.ArrayConcat, LuaLibFeature.ArrayIsArray], ArrayFlatMap: [LuaLibFeature.ArrayConcat, LuaLibFeature.ArrayIsArray], + Await: [LuaLibFeature.InstanceOf, LuaLibFeature.New], Decorate: [LuaLibFeature.ObjectGetOwnPropertyDescriptor, LuaLibFeature.SetDescriptor, LuaLibFeature.ObjectAssign], DelegatedYield: [LuaLibFeature.StringAccess], Delete: [LuaLibFeature.ObjectGetOwnPropertyDescriptors], diff --git a/src/lualib/Await.ts b/src/lualib/Await.ts new file mode 100644 index 000000000..61f383e96 --- /dev/null +++ b/src/lualib/Await.ts @@ -0,0 +1,52 @@ +// The following is a translation of the TypeScript async awaiter which uses generators and yields. +// For Lua we use coroutines instead. +// +// Source: +// +// var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) { +// function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); } +// return new (P || (P = Promise))(function (resolve, reject) { +// function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } } +// function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } } +// function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); } +// step((generator = generator.apply(thisArg, _arguments || [])).next()); +// }); +// }; +// + +// eslint-disable-next-line @typescript-eslint/promise-function-async +function __TS__AsyncAwaiter(this: void, generator: (this: void) => void) { + return new Promise((resolve, reject) => { + const asyncCoroutine = coroutine.create(generator); + + // eslint-disable-next-line @typescript-eslint/promise-function-async + function adopt(value: unknown) { + return value instanceof __TS__Promise ? value : Promise.resolve(value); + } + function fulfilled(value) { + const [success, resultOrError] = coroutine.resume(asyncCoroutine, value); + if (success) { + step(resultOrError); + } else { + reject(resultOrError); + } + } + function step(result: unknown) { + if (coroutine.status(asyncCoroutine) === "dead") { + resolve(result); + } else { + adopt(result).then(fulfilled, reason => reject(reason)); + } + } + const [success, resultOrError] = coroutine.resume(asyncCoroutine); + if (success) { + step(resultOrError); + } else { + reject(resultOrError); + } + }); +} + +function __TS__Await(this: void, thing: unknown) { + return coroutine.yield(thing); +} diff --git a/src/transformation/utils/diagnostics.ts b/src/transformation/utils/diagnostics.ts index dbbe9376a..a7a63a717 100644 --- a/src/transformation/utils/diagnostics.ts +++ b/src/transformation/utils/diagnostics.ts @@ -147,3 +147,7 @@ export const annotationDeprecated = createWarningDiagnosticFactory( export const notAllowedOptionalAssignment = createErrorDiagnosticFactory( "The left-hand side of an assignment expression may not be an optional property access." ); + +export const awaitMustBeInAsyncFunction = createErrorDiagnosticFactory( + "Await can only be used inside async functions." +); diff --git a/src/transformation/visitors/async-await.ts b/src/transformation/visitors/async-await.ts new file mode 100644 index 000000000..c6dcc86dc --- /dev/null +++ b/src/transformation/visitors/async-await.ts @@ -0,0 +1,36 @@ +import * as ts from "typescript"; +import * as lua from "../../LuaAST"; +import { FunctionVisitor, TransformationContext } from "../context"; +import { awaitMustBeInAsyncFunction } from "../utils/diagnostics"; +import { importLuaLibFeature, LuaLibFeature, transformLuaLibFunction } from "../utils/lualib"; +import { findFirstNodeAbove } from "../utils/typescript"; + +export const transformAwaitExpression: FunctionVisitor = (node, context) => { + // Check if await is inside an async function, it is not allowed at top level or in non-async functions + const containingFunction = findFirstNodeAbove(node, ts.isFunctionLike); + if ( + containingFunction === undefined || + !containingFunction.modifiers?.some(m => m.kind === ts.SyntaxKind.AsyncKeyword) + ) { + context.diagnostics.push(awaitMustBeInAsyncFunction(node)); + } + + const expression = context.transformExpression(node.expression); + return transformLuaLibFunction(context, LuaLibFeature.Await, node, expression); +}; + +export function isAsyncFunction(declaration: ts.FunctionLikeDeclaration): boolean { + return declaration.modifiers?.some(m => m.kind === ts.SyntaxKind.AsyncKeyword) ?? false; +} + +export function wrapInAsyncAwaiter(context: TransformationContext, statements: lua.Statement[]): lua.Statement[] { + importLuaLibFeature(context, LuaLibFeature.Await); + + return [ + lua.createReturnStatement([ + lua.createCallExpression(lua.createIdentifier("__TS__AsyncAwaiter"), [ + lua.createFunctionExpression(lua.createBlock(statements)), + ]), + ]), + ]; +} diff --git a/src/transformation/visitors/function.ts b/src/transformation/visitors/function.ts index cb1d29bfd..0061ea749 100644 --- a/src/transformation/visitors/function.ts +++ b/src/transformation/visitors/function.ts @@ -13,6 +13,7 @@ import { } from "../utils/lua-ast"; import { LuaLibFeature, transformLuaLibFunction } from "../utils/lualib"; import { peekScope, performHoisting, popScope, pushScope, Scope, ScopeType } from "../utils/scope"; +import { isAsyncFunction, wrapInAsyncAwaiter } from "./async-await"; import { transformIdentifier } from "./identifier"; import { transformExpressionBodyToReturnStatement } from "./return"; import { transformBindingPattern } from "./variable-declaration"; @@ -114,7 +115,10 @@ export function transformFunctionBody( ): [lua.Statement[], Scope] { const scope = pushScope(context, ScopeType.Function); scope.node = node; - const bodyStatements = transformFunctionBodyContent(context, body); + let bodyStatements = transformFunctionBodyContent(context, body); + if (node && isAsyncFunction(node)) { + bodyStatements = wrapInAsyncAwaiter(context, bodyStatements); + } const headerStatements = transformFunctionBodyHeader(context, scope, parameters, spreadIdentifier); popScope(context); return [[...headerStatements, ...bodyStatements], scope]; @@ -195,6 +199,7 @@ export function transformFunctionToExpression( spreadIdentifier, node ); + const functionExpression = lua.createFunctionExpression( lua.createBlock(transformedBody), paramNames, diff --git a/src/transformation/visitors/index.ts b/src/transformation/visitors/index.ts index b05527364..28f965932 100644 --- a/src/transformation/visitors/index.ts +++ b/src/transformation/visitors/index.ts @@ -41,6 +41,7 @@ import { typescriptVisitors } from "./typescript"; import { transformPostfixUnaryExpression, transformPrefixUnaryExpression } from "./unary-expression"; import { transformVariableStatement } from "./variable-declaration"; import { jsxVisitors } from "./jsx/jsx"; +import { transformAwaitExpression } from "./async-await"; const transformEmptyStatement: FunctionVisitor = () => undefined; const transformParenthesizedExpression: FunctionVisitor = (node, context) => @@ -51,6 +52,7 @@ export const standardVisitors: Visitors = { ...typescriptVisitors, ...jsxVisitors, [ts.SyntaxKind.ArrowFunction]: transformFunctionLikeDeclaration, + [ts.SyntaxKind.AwaitExpression]: transformAwaitExpression, [ts.SyntaxKind.BinaryExpression]: transformBinaryExpression, [ts.SyntaxKind.Block]: transformBlock, [ts.SyntaxKind.BreakStatement]: transformBreakStatement, diff --git a/src/transformation/visitors/sourceFile.ts b/src/transformation/visitors/sourceFile.ts index 604aa0c7e..41561c67f 100644 --- a/src/transformation/visitors/sourceFile.ts +++ b/src/transformation/visitors/sourceFile.ts @@ -23,6 +23,7 @@ export const transformSourceFileNode: FunctionVisitor = (node, co } } else { pushScope(context, ScopeType.File); + statements = performHoisting(context, context.transformStatements(node.statements)); popScope(context); diff --git a/src/transformation/visitors/spread.ts b/src/transformation/visitors/spread.ts index 6e6ec391a..f8ebdc79d 100644 --- a/src/transformation/visitors/spread.ts +++ b/src/transformation/visitors/spread.ts @@ -37,6 +37,11 @@ export function isOptimizedVarArgSpread(context: TransformationContext, symbol: return false; } + // Scope cannot be an async function + if (scope.node.modifiers?.some(m => m.kind === ts.SyntaxKind.AsyncKeyword)) { + return false; + } + // Identifier must be a vararg in the local function scope's parameters const isSpreadParameter = (p: ts.ParameterDeclaration) => p.dotDotDotToken && ts.isIdentifier(p.name) && context.checker.getSymbolAtLocation(p.name) === symbol; diff --git a/test/unit/builtins/async-await.spec.ts b/test/unit/builtins/async-await.spec.ts new file mode 100644 index 000000000..0a9c3513e --- /dev/null +++ b/test/unit/builtins/async-await.spec.ts @@ -0,0 +1,371 @@ +import { ModuleKind, ScriptTarget } from "typescript"; +import { awaitMustBeInAsyncFunction } from "../../../src/transformation/utils/diagnostics"; +import * as util from "../../util"; + +const promiseTestLib = ` +// Some logging utility, useful for asserting orders of operations + +const allLogs: any[] = []; +function log(...values: any[]) { + allLogs.push(...values); +} + +// Create a promise and store its resolve and reject functions, useful for creating pending promises + +function defer() { + let resolve: (data: any) => void = () => {}; + let reject: (reason: string) => void = () => {}; + const promise = new Promise((res, rej) => { + resolve = res; + reject = rej; + }); + return { promise, resolve, reject }; +}`; + +test("can await already resolved promise", () => { + util.testFunction` + const result = []; + async function abc() { + return await Promise.resolve(30); + } + abc().then(value => result.push(value)); + + return result; + `.expectToEqual([30]); +}); + +test("can await already rejected promise", () => { + util.testFunction` + const result = []; + async function abc() { + return await Promise.reject("test rejection"); + } + abc().catch(reason => result.push(reason)); + + return result; + `.expectToEqual(["test rejection"]); +}); + +test("can await pending promise", () => { + util.testFunction` + const { promise, resolve } = defer(); + promise.then(data => log("resolving original promise", data)); + + async function abc() { + return await promise; + } + + const awaitingPromise = abc(); + awaitingPromise.then(data => log("resolving awaiting promise", data)); + + resolve("resolved data"); + + return allLogs; + + ` + .setTsHeader(promiseTestLib) + .expectToEqual(["resolving original promise", "resolved data", "resolving awaiting promise", "resolved data"]); +}); + +test("can await non-promise values", () => { + util.testFunction` + async function foo() { + return await "foo"; + } + + async function bar() { + return await { foo: await foo(), bar: "bar" }; + } + + async function baz() { + return (await bar()).foo + (await bar()).bar; + } + + const { state, value } = baz() as any; + return { state, value }; + `.expectToEqual({ + state: 1, // __TS__PromiseState.Fulfilled + value: "foobar", + }); +}); + +test.each(["async function abc() {", "const abc = async () => {"])( + "can return non-promise from async function (%p)", + functionHeader => { + util.testFunction` + const { promise, resolve } = defer(); + promise.then(data => log("resolving original promise", data)); + + ${functionHeader} + await promise; + return "abc return data" + } + + const awaitingPromise = abc(); + awaitingPromise.then(data => log("resolving awaiting promise", data)); + + resolve("resolved data"); + + return allLogs; + + ` + .setTsHeader(promiseTestLib) + .expectToEqual([ + "resolving original promise", + "resolved data", + "resolving awaiting promise", + "abc return data", + ]); + } +); + +test.each(["async function abc() {", "const abc = async () => {"])( + "can have multiple awaits in async function (%p)", + functionHeader => { + util.testFunction` + const { promise: promise1, resolve: resolve1 } = defer(); + const { promise: promise2, resolve: resolve2 } = defer(); + const { promise: promise3, resolve: resolve3 } = defer(); + promise1.then(data => log("resolving promise1", data)); + promise2.then(data => log("resolving promise2", data)); + promise3.then(data => log("resolving promise3", data)); + + ${functionHeader} + const result1 = await promise1; + const result2 = await promise2; + const result3 = await promise3; + return [result1, result2, result3]; + } + + const awaitingPromise = abc(); + awaitingPromise.then(data => log("resolving awaiting promise", data)); + + resolve1("data1"); + resolve2("data2"); + resolve3("data3"); + + return allLogs; + + ` + .setTsHeader(promiseTestLib) + .expectToEqual([ + "resolving promise1", + "data1", + "resolving promise2", + "data2", + "resolving promise3", + "data3", + "resolving awaiting promise", + ["data1", "data2", "data3"], + ]); + } +); + +test("can make async lambdas with expression body", () => { + util.testFunction` + const foo = async () => "foo"; + const bar = async () => await foo(); + + const { state, value } = bar() as any; + return { state, value }; + `.expectToEqual({ + state: 1, // __TS__PromiseState.Fulfilled + value: "foo", + }); +}); + +test("can await async function from async function", () => { + util.testFunction` + const { promise, resolve } = defer(); + promise.then(data => log("resolving original promise", data)); + + async function abc() { + return await promise; + } + + async function def() { + return await abc(); + } + + const awaitingPromise = def(); + awaitingPromise.then(data => log("resolving awaiting promise", data)); + + resolve("resolved data"); + + return allLogs; + + ` + .setTsHeader(promiseTestLib) + .expectToEqual(["resolving original promise", "resolved data", "resolving awaiting promise", "resolved data"]); +}); + +test("async function returning value is same as non-async function returning promise", () => { + util.testFunction` + function f(): Promise { + return Promise.resolve(42); + } + + async function fAsync(): Promise { + return 42; + } + + const { state: state1, value: value1 } = f() as any; + const { state: state2, value: value2 } = fAsync() as any; + + return { + state1, value1, + state2, value2 + }; + `.expectToEqual({ + state1: 1, // __TS__PromiseState.Fulfilled + value1: 42, + state2: 1, // __TS__PromiseState.Fulfilled + value2: 42, + }); +}); + +test("correctly handles awaited functions rejecting", () => { + util.testFunction` + const { promise: promise1, reject } = defer(); + const { promise: promise2 } = defer(); + + promise1.then(data => log("resolving promise1", data), reason => log("rejecting promise1", reason)); + promise2.then(data => log("resolving promise2", data)); + + async function abc() { + const result1 = await promise1; + const result2 = await promise2; + return [result1, result2]; + } + + const awaitingPromise = abc(); + awaitingPromise.catch(reason => log("awaiting promise was rejected because:", reason)); + + reject("test reject"); + + return allLogs; + + ` + .setTsHeader(promiseTestLib) + .expectToEqual(["rejecting promise1", "test reject", "awaiting promise was rejected because:", "test reject"]); +}); + +test("can call async function at top-level", () => { + util.testModule` + export let aStarted = false; + async function a() { + aStarted = true; + return 42; + } + + a(); // Call async function (but cannot await) + ` + .setOptions({ module: ModuleKind.ESNext, target: ScriptTarget.ES2017 }) + .expectToEqual({ + aStarted: true, + }); +}); + +test("async function throws error", () => { + util.testFunction` + async function a() { + throw "test throw"; + } + + const { state, rejectionReason } = a() as any; + return { state, rejectionReason }; + `.expectToEqual({ + state: 2, // __TS__PromiseState.Rejected + rejectionReason: "test throw", + }); +}); + +test("async lambda throws error", () => { + util.testFunction` + const a = async () => { + throw "test throw"; + } + + const { state, rejectionReason } = a() as any; + return { state, rejectionReason }; + `.expectToEqual({ + state: 2, // __TS__PromiseState.Rejected + rejectionReason: "test throw", + }); +}); + +test("async function throws object", () => { + util.testFunction` + async function a() { + throw new Error("test throw"); + } + + const { state, rejectionReason } = a() as any; + return { state, rejectionReason }; + `.expectToEqual({ + state: 2, // __TS__PromiseState.Rejected + rejectionReason: { + message: "test throw", + name: "Error", + stack: expect.stringContaining("stack traceback"), + }, + }); +}); + +test.each([ + "await a();", + "const b = await a();", + "export const b = await a();", + "declare function foo(n: number): number; foo(await a());", + "declare function foo(n: number): number; const b = foo(await a());", + "const b = [await a()];", + "const b = [4, await a()];", + "const b = true ? 4 : await a();", +])("cannot await at top-level (%p)", awaitUsage => { + util.testModule` + async function a() { + return 42; + } + + ${awaitUsage} + export {} // Required to make TS happy, cannot await without import/exports + ` + .setOptions({ module: ModuleKind.ESNext, target: ScriptTarget.ES2017 }) + .expectToHaveDiagnostics([awaitMustBeInAsyncFunction.code]); +}); + +test("async function can access varargs", () => { + util.testFunction` + const { promise, resolve } = defer(); + + async function a(...args: string[]) { + log(await promise); + log(args[1]); + } + + const awaitingPromise = a("A", "B", "C"); + resolve("resolved"); + + return allLogs; + ` + .setTsHeader(promiseTestLib) + .expectToEqual(["resolved", "B"]); +}); + +test("async function can forward varargs", () => { + util.testFunction` + const { promise, resolve } = defer(); + + async function a(...args: string[]) { + log(await promise); + log(...args); + } + + const awaitingPromise = a("A", "B", "C"); + resolve("resolved"); + + return allLogs; + ` + .setTsHeader(promiseTestLib) + .expectToEqual(["resolved", "A", "B", "C"]); +});