Skip to content

Commit c955a63

Browse files
authored
implemented support for properties on functions (#1180)
Co-authored-by: Tom <tomblind@users.noreply.github.com>
1 parent 1f3a505 commit c955a63

File tree

9 files changed

+264
-18
lines changed

9 files changed

+264
-18
lines changed

src/LuaLib.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ const luaLibDependencies: Partial<Record<LuaLibFeature, LuaLibFeature[]>> = {
103103
ArrayConcat: [LuaLibFeature.ArrayIsArray],
104104
ArrayFlat: [LuaLibFeature.ArrayConcat, LuaLibFeature.ArrayIsArray],
105105
ArrayFlatMap: [LuaLibFeature.ArrayConcat, LuaLibFeature.ArrayIsArray],
106-
Await: [LuaLibFeature.InstanceOf, LuaLibFeature.New],
106+
Await: [LuaLibFeature.InstanceOf, LuaLibFeature.New, LuaLibFeature.Promise],
107107
Decorate: [LuaLibFeature.ObjectGetOwnPropertyDescriptor, LuaLibFeature.SetDescriptor, LuaLibFeature.ObjectAssign],
108108
DelegatedYield: [LuaLibFeature.StringAccess],
109109
Delete: [LuaLibFeature.ObjectGetOwnPropertyDescriptors, LuaLibFeature.Error, LuaLibFeature.New],

src/transformation/builtins/function.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ export function transformFunctionPrototypeCall(
3232
return transformLuaLibFunction(context, LuaLibFeature.FunctionBind, node, caller, ...params);
3333
case "call":
3434
return lua.createCallExpression(caller, params, node);
35-
default:
35+
case "toString":
3636
context.diagnostics.push(unsupportedProperty(expression.name, "function", expressionName));
3737
}
3838
}
@@ -60,7 +60,10 @@ export function transformFunctionProperty(
6060
? lua.createBinaryExpression(nparams, lua.createNumericLiteral(1), lua.SyntaxKind.SubtractionOperator)
6161
: nparams;
6262

63-
default:
63+
case "arguments":
64+
case "caller":
65+
case "displayName":
66+
case "name":
6467
context.diagnostics.push(unsupportedProperty(node.name, "function", node.name.text));
6568
}
6669
}

src/transformation/builtins/index.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ export function transformBuiltinPropertyAccessExpression(
4141
return transformArrayProperty(context, node);
4242
}
4343

44-
if (isFunctionType(context, ownerType)) {
44+
if (isFunctionType(ownerType)) {
4545
return transformFunctionProperty(context, node);
4646
}
4747

@@ -131,7 +131,7 @@ export function transformBuiltinCallExpression(
131131
return transformArrayPrototypeCall(context, node);
132132
}
133133

134-
if (isFunctionType(context, ownerType) && hasStandardLibrarySignature(context, node)) {
134+
if (isFunctionType(ownerType) && hasStandardLibrarySignature(context, node)) {
135135
if (isOptionalCall) return unsupportedOptionalCall();
136136
return transformFunctionPrototypeCall(context, node);
137137
}

src/transformation/utils/language-extensions.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,12 @@ export function isExtensionType(type: ts.Type, extensionKind: ExtensionKind): bo
118118
return typeBrand !== undefined && type.getProperty(typeBrand) !== undefined;
119119
}
120120

121+
export function getExtensionKinds(type: ts.Type): ExtensionKind[] {
122+
return (Object.keys(extensionKindToTypeBrand) as Array<keyof typeof extensionKindToTypeBrand>).filter(
123+
e => type.getProperty(extensionKindToTypeBrand[e]) !== undefined
124+
);
125+
}
126+
121127
export function isExtensionValue(
122128
context: TransformationContext,
123129
symbol: ts.Symbol,

src/transformation/utils/lua-ast.ts

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,9 @@ export function createLocalOrExportedOrGlobalDeclaration(
157157
const isTopLevelVariable = scope.type === ScopeType.File;
158158

159159
if (context.isModule || !isTopLevelVariable) {
160-
if (!isFunctionDeclaration && hasMultipleReferences(scope, lhs)) {
160+
const isLuaFunctionExpression = rhs && !Array.isArray(rhs) && lua.isFunctionExpression(rhs);
161+
const isSafeRecursiveFunctionDeclaration = isFunctionDeclaration && isLuaFunctionExpression;
162+
if (!isSafeRecursiveFunctionDeclaration && hasMultipleReferences(scope, lhs)) {
161163
// Split declaration and assignment of identifiers that reference themselves in their declaration.
162164
// Put declaration above preceding statements in case the identifier is referenced in those.
163165
const precedingDeclaration = lua.createVariableDeclarationStatement(lhs, undefined, tsOriginal);
@@ -166,10 +168,8 @@ export function createLocalOrExportedOrGlobalDeclaration(
166168
assignment = lua.createAssignmentStatement(lhs, rhs, tsOriginal);
167169
}
168170

169-
if (!isFunctionDeclaration) {
170-
// Remember local variable declarations for hoisting later
171-
addScopeVariableDeclaration(scope, precedingDeclaration);
172-
}
171+
// Remember local variable declarations for hoisting later
172+
addScopeVariableDeclaration(scope, precedingDeclaration);
173173
} else {
174174
declaration = lua.createVariableDeclarationStatement(lhs, rhs, tsOriginal);
175175

src/transformation/utils/typescript/types.ts

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ export function isArrayType(context: TransformationContext, type: ts.Type): bool
107107
return forTypeOrAnySupertype(context, type, t => isExplicitArrayType(context, t));
108108
}
109109

110-
export function isFunctionType(context: TransformationContext, type: ts.Type): boolean {
111-
const typeNode = context.checker.typeToTypeNode(type, undefined, ts.NodeBuilderFlags.InTypeAlias);
112-
return typeNode !== undefined && ts.isFunctionTypeNode(typeNode);
110+
export function isFunctionType(type: ts.Type): boolean {
111+
return type.getCallSignatures().length > 0;
113112
}

src/transformation/visitors/function.ts

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import { AnnotationKind, getNodeAnnotations } from "../utils/annotations";
66
import { annotationRemoved } from "../utils/diagnostics";
77
import { createDefaultExportStringLiteral, hasDefaultExportModifier } from "../utils/export";
88
import { ContextType, getFunctionContextType } from "../utils/function-context";
9+
import { getExtensionKinds } from "../utils/language-extensions";
910
import {
1011
createExportsIdentifier,
1112
createLocalOrExportedOrGlobalDeclaration,
@@ -15,6 +16,7 @@ import {
1516
import { LuaLibFeature, transformLuaLibFunction } from "../utils/lualib";
1617
import { transformInPrecedingStatementScope } from "../utils/preceding-statements";
1718
import { peekScope, performHoisting, popScope, pushScope, Scope, ScopeType } from "../utils/scope";
19+
import { isFunctionType } from "../utils/typescript";
1820
import { isAsyncFunction, wrapInAsyncAwaiter } from "./async-await";
1921
import { transformIdentifier } from "./identifier";
2022
import { transformExpressionBodyToReturnStatement } from "./return";
@@ -51,6 +53,42 @@ function isRestParameterReferenced(identifier: lua.Identifier, scope: Scope): bo
5153
return references !== undefined && references.length > 0;
5254
}
5355

56+
export function createCallableTable(functionExpression: lua.Expression): lua.Expression {
57+
// __call metamethod receives the table as the first argument, so we need to add a dummy parameter
58+
if (lua.isFunctionExpression(functionExpression)) {
59+
functionExpression.params?.unshift(lua.createAnonymousIdentifier());
60+
} else {
61+
// functionExpression may have been replaced (lib functions, etc...),
62+
// so we create a forwarding function to eat the extra argument
63+
functionExpression = lua.createFunctionExpression(
64+
lua.createBlock([
65+
lua.createReturnStatement([lua.createCallExpression(functionExpression, [lua.createDotsLiteral()])]),
66+
]),
67+
[lua.createAnonymousIdentifier()],
68+
lua.createDotsLiteral(),
69+
lua.FunctionExpressionFlags.Inline
70+
);
71+
}
72+
return lua.createCallExpression(lua.createIdentifier("setmetatable"), [
73+
lua.createTableExpression(),
74+
lua.createTableExpression([
75+
lua.createTableFieldExpression(functionExpression, lua.createStringLiteral("__call")),
76+
]),
77+
]);
78+
}
79+
80+
export function isFunctionTypeWithProperties(functionType: ts.Type) {
81+
if (functionType.isUnion()) {
82+
return functionType.types.some(isFunctionTypeWithProperties);
83+
} else {
84+
return (
85+
isFunctionType(functionType) &&
86+
functionType.getProperties().length > 0 &&
87+
getExtensionKinds(functionType).length === 0 // ignore TSTL extension functions like $range
88+
);
89+
}
90+
}
91+
5492
export function transformFunctionBodyContent(context: TransformationContext, body: ts.ConciseBody): lua.Statement[] {
5593
if (!ts.isBlock(body)) {
5694
const [precedingStatements, returnStatement] = transformInPrecedingStatementScope(context, () =>
@@ -251,9 +289,16 @@ export function transformFunctionLikeDeclaration(
251289
// Only handle if the name is actually referenced inside the function
252290
if (isReferenced) {
253291
const nameIdentifier = transformIdentifier(context, node.name);
254-
context.addPrecedingStatements(
255-
lua.createVariableDeclarationStatement(nameIdentifier, functionExpression)
256-
);
292+
if (isFunctionTypeWithProperties(context.checker.getTypeAtLocation(node))) {
293+
context.addPrecedingStatements([
294+
lua.createVariableDeclarationStatement(nameIdentifier),
295+
lua.createAssignmentStatement(nameIdentifier, createCallableTable(functionExpression)),
296+
]);
297+
} else {
298+
context.addPrecedingStatements(
299+
lua.createVariableDeclarationStatement(nameIdentifier, functionExpression)
300+
);
301+
}
257302
return lua.cloneIdentifier(nameIdentifier);
258303
}
259304
}
@@ -295,7 +340,13 @@ export const transformFunctionDeclaration: FunctionVisitor<ts.FunctionDeclaratio
295340
scope.functionDefinitions.set(name.symbolId, functionInfo);
296341
}
297342

298-
return createLocalOrExportedOrGlobalDeclaration(context, name, functionExpression, node);
343+
// Wrap functions with properties into a callable table
344+
const wrappedFunction =
345+
node.name && isFunctionTypeWithProperties(context.checker.getTypeAtLocation(node.name))
346+
? createCallableTable(functionExpression)
347+
: functionExpression;
348+
349+
return createLocalOrExportedOrGlobalDeclaration(context, name, wrappedFunction, node);
299350
};
300351

301352
export const transformYieldExpression: FunctionVisitor<ts.YieldExpression> = (expression, context) => {

src/transformation/visitors/variable-declaration.ts

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import { addExportToIdentifier } from "../utils/export";
88
import { createLocalOrExportedOrGlobalDeclaration, createUnpackCall, wrapInTable } from "../utils/lua-ast";
99
import { LuaLibFeature, transformLuaLibFunction } from "../utils/lualib";
1010
import { transformInPrecedingStatementScope } from "../utils/preceding-statements";
11+
import { createCallableTable, isFunctionTypeWithProperties } from "./function";
1112
import { transformIdentifier } from "./identifier";
1213
import { isMultiReturnCall } from "./language-extensions/multi";
1314
import { transformPropertyName } from "./literal";
@@ -257,7 +258,18 @@ export function transformVariableDeclaration(
257258
// Find variable identifier
258259
const identifierName = transformIdentifier(context, statement.name);
259260
const value = statement.initializer && context.transformExpression(statement.initializer);
260-
return createLocalOrExportedOrGlobalDeclaration(context, identifierName, value, statement);
261+
262+
// Wrap functions being assigned to a type that contains additional properties in a callable table
263+
// This catches 'const foo = function() {}; foo.bar = "FOOBAR";'
264+
const wrappedValue =
265+
value &&
266+
// Skip named function expressions because they will have been wrapped already
267+
!(statement.initializer && ts.isFunctionExpression(statement.initializer) && statement.initializer.name) &&
268+
isFunctionTypeWithProperties(context.checker.getTypeAtLocation(statement.name))
269+
? createCallableTable(value)
270+
: value;
271+
272+
return createLocalOrExportedOrGlobalDeclaration(context, identifierName, wrappedValue, statement);
261273
} else if (ts.isArrayBindingPattern(statement.name) || ts.isObjectBindingPattern(statement.name)) {
262274
return transformBindingVariableDeclaration(context, statement.name, statement.initializer);
263275
} else {
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
import * as util from "../../util";
2+
3+
test("property on function", () => {
4+
util.testFunction`
5+
function foo(s: string) { return s; }
6+
foo.bar = "bar";
7+
return foo("foo") + foo.bar;
8+
`.expectToMatchJsResult();
9+
});
10+
11+
test("property on void function", () => {
12+
util.testFunction`
13+
function foo(this: void, s: string) { return s; }
14+
foo.bar = "bar";
15+
return foo("foo") + foo.bar;
16+
`.expectToMatchJsResult();
17+
});
18+
19+
test("property on recursively referenced function", () => {
20+
util.testFunction`
21+
function foo(s: string) { return s + foo.bar; }
22+
foo.bar = "bar";
23+
return foo("foo") + foo.bar;
24+
`.expectToMatchJsResult();
25+
});
26+
27+
test("property on hoisted function", () => {
28+
util.testFunction`
29+
foo.bar = "bar";
30+
function foo(s: string) { return s; }
31+
return foo("foo") + foo.bar;
32+
`.expectToMatchJsResult();
33+
});
34+
35+
test("function merged with namespace", () => {
36+
util.testModule`
37+
function foo(s: string) { return s; }
38+
namespace foo {
39+
export let bar = "bar";
40+
}
41+
export const result = foo("foo") + foo.bar;
42+
`
43+
.setReturnExport("result")
44+
.expectToEqual("foobar");
45+
});
46+
47+
test("function with property assigned to variable", () => {
48+
util.testFunction`
49+
const foo = function(s: string) { return s; };
50+
foo.bar = "bar";
51+
return foo("foo") + foo.bar;
52+
`.expectToMatchJsResult();
53+
});
54+
55+
test("void function with property assigned to variable", () => {
56+
util.testFunction`
57+
const foo = function(this: void, s: string) { return s; };
58+
foo.bar = "bar";
59+
return foo("foo") + foo.bar;
60+
`.expectToMatchJsResult();
61+
});
62+
63+
test("recursively referenced function with property assigned to variable", () => {
64+
util.testFunction`
65+
const foo = function(s: string) { return s + foo.bar; };
66+
foo.bar = "bar";
67+
return foo("foo") + foo.bar;
68+
`.expectToMatchJsResult();
69+
});
70+
71+
test("named recursively referenced function with property assigned to variable", () => {
72+
util.testFunction`
73+
const foo = function baz(s: string) { return s + foo.bar + baz.bar; };
74+
foo.bar = "bar";
75+
return foo("foo") + foo.bar;
76+
`.expectToMatchJsResult();
77+
});
78+
79+
test("arrow function with property assigned to variable", () => {
80+
util.testFunction`
81+
const foo: { (s: string): string; bar: string; } = s => s;
82+
foo.bar = "bar";
83+
return foo("foo") + foo.bar;
84+
`.expectToMatchJsResult();
85+
});
86+
87+
test("void arrow function with property assigned to variable", () => {
88+
util.testFunction`
89+
const foo: { (this: void, s: string): string; bar: string; } = s => s;
90+
foo.bar = "bar";
91+
return foo("foo") + foo.bar;
92+
`.expectToMatchJsResult();
93+
});
94+
95+
test("recursively referenced arrow function with property assigned to variable", () => {
96+
util.testFunction`
97+
const foo: { (s: string): string; bar: string; } = s => s + foo.bar;
98+
foo.bar = "bar";
99+
return foo("foo") + foo.bar;
100+
`.expectToMatchJsResult();
101+
});
102+
103+
test("property on generator function", () => {
104+
util.testFunction`
105+
function *foo(s: string) { yield s; }
106+
foo.bar = "bar";
107+
for (const s of foo("foo")) {
108+
return s + foo.bar;
109+
}
110+
`.expectToMatchJsResult();
111+
});
112+
113+
test("generator function assigned to variable", () => {
114+
util.testFunction`
115+
const foo = function *(s: string) { yield s; }
116+
foo.bar = "bar";
117+
for (const s of foo("foo")) {
118+
return s + foo.bar;
119+
}
120+
`.expectToMatchJsResult();
121+
});
122+
123+
test("property on async function", () => {
124+
util.testFunction`
125+
let result = "";
126+
async function foo(s: string) { result = s + foo.bar; }
127+
foo.bar = "bar";
128+
void foo("foo");
129+
return result;
130+
`.expectToMatchJsResult();
131+
});
132+
133+
test("async function with property assigned to variable", () => {
134+
util.testFunction`
135+
let result = "";
136+
const foo = async function(s: string) { result = s + foo.bar; }
137+
foo.bar = "bar";
138+
void foo("foo");
139+
return result;
140+
`.expectToMatchJsResult();
141+
});
142+
143+
test("async arrow function with property assigned to variable", () => {
144+
util.testFunction`
145+
let result = "";
146+
const foo: { (s: string): Promise<void>; bar: string; } = async s => { result = s + foo.bar; };
147+
foo.bar = "bar";
148+
void foo("foo");
149+
return result;
150+
`.expectToMatchJsResult();
151+
});
152+
153+
test("call function with property using call method", () => {
154+
util.testFunction`
155+
function foo(s: string) { return this + s; }
156+
foo.baz = "baz";
157+
return foo.call("foo", "bar") + foo.baz;
158+
`.expectToMatchJsResult();
159+
});
160+
161+
test("call function with property using apply method", () => {
162+
util.testFunction`
163+
function foo(s: string) { return this + s; }
164+
foo.baz = "baz";
165+
return foo.apply("foo", ["bar"]) + foo.baz;
166+
`.expectToMatchJsResult();
167+
});
168+
169+
test("call function with property using bind method", () => {
170+
util.testFunction`
171+
function foo(s: string) { return this + s; }
172+
foo.baz = "baz";
173+
return foo.bind("foo", "bar")() + foo.baz;
174+
`.expectToMatchJsResult();
175+
});

0 commit comments

Comments
 (0)