Skip to content

Commit e90fc31

Browse files
authored
Lua table extensions (#991)
* Lua table extensions * addressing feedback Co-authored-by: Tom <tomblind@users.noreply.github.com>
1 parent 0a5bd20 commit e90fc31

File tree

12 files changed

+429
-31
lines changed

12 files changed

+429
-31
lines changed

language-extensions/index.d.ts

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,3 +430,81 @@ declare type LuaLength<TOperand, TReturn> = ((operand: TOperand) => TReturn) & L
430430
* @param TReturn The resulting (return) type of the operation.
431431
*/
432432
declare type LuaLengthMethod<TReturn> = (() => TReturn) & LuaExtension<"__luaLengthMethodBrand">;
433+
434+
/**
435+
* Calls to functions with this type are translated to `table[key]`.
436+
* For more information see: https://typescripttolua.github.io/docs/advanced/language-extensions
437+
*
438+
* @param TTable The type to access as a Lua table.
439+
* @param TKey The type of the key to use to access the table.
440+
* @param TValue The type of the value stored in the table.
441+
*/
442+
declare type LuaTableGet<TTable extends object, TKey extends {}, TValue> = ((table: TTable, key: TKey) => TValue) &
443+
LuaExtension<"__luaTableGetBrand">;
444+
445+
/**
446+
* Calls to methods with this type are translated to `table[key]`, where `table` is the object with the method.
447+
* For more information see: https://typescripttolua.github.io/docs/advanced/language-extensions
448+
*
449+
* @param TKey The type of the key to use to access the table.
450+
* @param TValue The type of the value stored in the table.
451+
*/
452+
declare type LuaTableGetMethod<TKey extends {}, TValue> = ((key: TKey) => TValue) &
453+
LuaExtension<"__luaTableGetMethodBrand">;
454+
455+
/**
456+
* Calls to functions with this type are translated to `table[key] = value`.
457+
* For more information see: https://typescripttolua.github.io/docs/advanced/language-extensions
458+
*
459+
* @param TTable The type to access as a Lua table.
460+
* @param TKey The type of the key to use to access the table.
461+
* @param TValue The type of the value to assign to the table.
462+
*/
463+
declare type LuaTableSet<TTable extends object, TKey extends {}, TValue> = ((
464+
table: TTable,
465+
key: TKey,
466+
value: TValue
467+
) => void) &
468+
LuaExtension<"__luaTableSetBrand">;
469+
470+
/**
471+
* Calls to methods with this type are translated to `table[key] = value`, where `table` is the object with the method.
472+
* For more information see: https://typescripttolua.github.io/docs/advanced/language-extensions
473+
*
474+
* @param TKey The type of the key to use to access the table.
475+
* @param TValue The type of the value to assign to the table.
476+
*/
477+
declare type LuaTableSetMethod<TKey extends {}, TValue> = ((key: TKey, value: TValue) => void) &
478+
LuaExtension<"__luaTableSetMethodBrand">;
479+
480+
/**
481+
* A convenience type for working directly with a Lua table.
482+
* For more information see: https://typescripttolua.github.io/docs/advanced/language-extensions
483+
*
484+
* @param TKey The type of the keys used to access the table.
485+
* @param TValue The type of the values stored in the table.
486+
*/
487+
declare interface LuaTable<TKey extends {} = {}, TValue = any> {
488+
length: LuaLengthMethod<number>;
489+
get: LuaTableGetMethod<TKey, TValue>;
490+
set: LuaTableSetMethod<TKey, TValue>;
491+
}
492+
493+
/**
494+
* A convenience type for working directly with a Lua table.
495+
* For more information see: https://typescripttolua.github.io/docs/advanced/language-extensions
496+
*
497+
* @param TKey The type of the keys used to access the table.
498+
* @param TValue The type of the values stored in the table.
499+
*/
500+
declare type LuaTableConstructor = (new <TKey extends {} = {}, TValue = any>() => LuaTable<TKey, TValue>) &
501+
LuaExtension<"__luaTableNewBrand">;
502+
503+
/**
504+
* A convenience type for working directly with a Lua table.
505+
* For more information see: https://typescripttolua.github.io/docs/advanced/language-extensions
506+
*
507+
* @param TKey The type of the keys used to access the table.
508+
* @param TValue The type of the values stored in the table.
509+
*/
510+
declare const LuaTable: LuaTableConstructor;

src/transformation/utils/diagnostics.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,14 @@ export const invalidOperatorMappingUse = createErrorDiagnosticFactory(
174174
"This function must always be directly called and cannot be referred to."
175175
);
176176

177+
export const invalidTableExtensionUse = createErrorDiagnosticFactory(
178+
"This function must be called directly and cannot be referred to."
179+
);
180+
181+
export const invalidTableSetExpression = createErrorDiagnosticFactory(
182+
"Table set extension can only be called as a stand-alone statement. It cannot be used as an expression in another statement."
183+
);
184+
177185
export const annotationDeprecated = createWarningDiagnosticFactory(
178186
(kind: AnnotationKind) =>
179187
`'@${kind}' is deprecated and will be removed in a future update. Please update your code before upgrading to the next release, otherwise your project will no longer compile. ` +

src/transformation/utils/language-extensions.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ export enum ExtensionKind {
4242
BitwiseNotOperatorMethodType = "BitwiseNotOperatorMethodType",
4343
LengthOperatorType = "LengthOperatorType",
4444
LengthOperatorMethodType = "LengthOperatorMethodType",
45+
TableNewType = "TableNewType",
46+
TableGetType = "TableGetType",
47+
TableGetMethodType = "TableGetMethodType",
48+
TableSetType = "TableSetType",
49+
TableSetMethodType = "TableSetMethodType",
4550
}
4651

4752
const extensionKindToFunctionName: { [T in ExtensionKind]?: string } = {
@@ -90,6 +95,11 @@ const extensionKindToTypeBrand: { [T in ExtensionKind]: string } = {
9095
[ExtensionKind.BitwiseNotOperatorMethodType]: "__luaBitwiseNotMethodBrand",
9196
[ExtensionKind.LengthOperatorType]: "__luaLengthBrand",
9297
[ExtensionKind.LengthOperatorMethodType]: "__luaLengthMethodBrand",
98+
[ExtensionKind.TableNewType]: "__luaTableNewBrand",
99+
[ExtensionKind.TableGetType]: "__luaTableGetBrand",
100+
[ExtensionKind.TableGetMethodType]: "__luaTableGetMethodBrand",
101+
[ExtensionKind.TableSetType]: "__luaTableSetBrand",
102+
[ExtensionKind.TableSetMethodType]: "__luaTableSetMethodBrand",
93103
};
94104

95105
export function isExtensionType(type: ts.Type, extensionKind: ExtensionKind): boolean {

src/transformation/utils/typescript/index.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,15 @@ export function getAllCallSignatures(type: ts.Type): readonly ts.Signature[] {
7979
export function isExpressionWithEvaluationEffect(node: ts.Expression): boolean {
8080
return !(ts.isLiteralExpression(node) || ts.isIdentifier(node) || node.kind === ts.SyntaxKind.ThisKeyword);
8181
}
82+
83+
export function getFunctionTypeForCall(context: TransformationContext, node: ts.CallExpression) {
84+
const signature = context.checker.getResolvedSignature(node);
85+
if (!signature || !signature.declaration) {
86+
return;
87+
}
88+
const typeDeclaration = findFirstNodeAbove(signature.declaration, ts.isTypeAliasDeclaration);
89+
if (!typeDeclaration) {
90+
return;
91+
}
92+
return context.checker.getTypeFromTypeNode(typeDeclaration.type);
93+
}

src/transformation/visitors/call.ts

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,13 @@ import { transformElementAccessArgument } from "./access";
1313
import { transformLuaTableCallExpression } from "./lua-table";
1414
import { shouldMultiReturnCallBeWrapped, returnsMultiType } from "./language-extensions/multi";
1515
import { isOperatorMapping, transformOperatorMappingExpression } from "./language-extensions/operators";
16+
import {
17+
isTableGetCall,
18+
isTableSetCall,
19+
transformTableGetExpression,
20+
transformTableSetExpression,
21+
} from "./language-extensions/table";
22+
import { invalidTableSetExpression } from "../utils/diagnostics";
1623

1724
export type PropertyCallExpression = ts.CallExpression & { expression: ts.PropertyAccessExpression };
1825

@@ -214,6 +221,18 @@ export const transformCallExpression: FunctionVisitor<ts.CallExpression> = (node
214221
return transformOperatorMappingExpression(context, node);
215222
}
216223

224+
if (isTableGetCall(context, node)) {
225+
return transformTableGetExpression(context, node);
226+
}
227+
228+
if (isTableSetCall(context, node)) {
229+
context.diagnostics.push(invalidTableSetExpression(node));
230+
return createImmediatelyInvokedFunctionExpression(
231+
[transformTableSetExpression(context, node)],
232+
lua.createNilLiteral()
233+
);
234+
}
235+
217236
if (ts.isPropertyAccessExpression(node.expression)) {
218237
const result = transformPropertyCall(context, node as PropertyCallExpression);
219238
return wrapResult ? wrapInTable(result) : result;

src/transformation/visitors/class/new.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import { AnnotationKind, getTypeAnnotations } from "../../utils/annotations";
55
import { annotationInvalidArgumentCount, extensionCannotConstruct } from "../../utils/diagnostics";
66
import { importLuaLibFeature, LuaLibFeature, transformLuaLibFunction } from "../../utils/lualib";
77
import { transformArguments } from "../call";
8+
import { isTableNewCall } from "../language-extensions/table";
89
import { transformLuaTableNewExpression } from "../lua-table";
910

1011
const builtinErrorTypeNames = new Set([
@@ -53,6 +54,10 @@ export const transformNewExpression: FunctionVisitor<ts.NewExpression> = (node,
5354
return luaTableResult;
5455
}
5556

57+
if (isTableNewCall(context, node)) {
58+
return lua.createTableExpression(undefined, node);
59+
}
60+
5661
const name = context.transformExpression(node.expression);
5762
const signature = context.checker.getResolvedSignature(node);
5863
const params = node.arguments

src/transformation/visitors/expression-statement.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import * as ts from "typescript";
22
import * as lua from "../../LuaAST";
33
import { FunctionVisitor } from "../context";
44
import { transformBinaryExpressionStatement } from "./binary-expression";
5+
import { isTableSetCall, transformTableSetExpression } from "./language-extensions/table";
56
import { transformLuaTableExpressionStatement } from "./lua-table";
67
import { transformUnaryExpressionStatement } from "./unary-expression";
78

@@ -11,6 +12,10 @@ export const transformExpressionStatement: FunctionVisitor<ts.ExpressionStatemen
1112
return luaTableResult;
1213
}
1314

15+
if (ts.isCallExpression(node.expression) && isTableSetCall(context, node.expression)) {
16+
return transformTableSetExpression(context, node.expression);
17+
}
18+
1419
const unaryExpressionResult = transformUnaryExpressionStatement(context, node);
1520
if (unaryExpressionResult) {
1621
return unaryExpressionResult;

src/transformation/visitors/identifier.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import {
88
invalidMultiFunctionUse,
99
invalidOperatorMappingUse,
1010
invalidRangeUse,
11+
invalidTableExtensionUse,
1112
} from "../utils/diagnostics";
1213
import { createExportedIdentifier, getSymbolExportScope } from "../utils/export";
1314
import { createSafeName, hasUnsafeIdentifierName } from "../utils/safe-names";
@@ -16,6 +17,7 @@ import { findFirstNodeAbove } from "../utils/typescript";
1617
import { isMultiFunctionNode } from "./language-extensions/multi";
1718
import { isOperatorMapping } from "./language-extensions/operators";
1819
import { isRangeFunctionNode } from "./language-extensions/range";
20+
import { isTableExtensionIdentifier } from "./language-extensions/table";
1921

2022
export function transformIdentifier(context: TransformationContext, identifier: ts.Identifier): lua.Identifier {
2123
if (isMultiFunctionNode(context, identifier)) {
@@ -27,6 +29,10 @@ export function transformIdentifier(context: TransformationContext, identifier:
2729
context.diagnostics.push(invalidOperatorMappingUse(identifier));
2830
}
2931

32+
if (isTableExtensionIdentifier(context, identifier)) {
33+
context.diagnostics.push(invalidTableExtensionUse(identifier));
34+
}
35+
3036
if (isRangeFunctionNode(context, identifier)) {
3137
context.diagnostics.push(invalidRangeUse(identifier));
3238
return lua.createAnonymousIdentifier(identifier);

src/transformation/visitors/language-extensions/operators.ts

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import * as lua from "../../../LuaAST";
33
import { TransformationContext } from "../../context";
44
import * as extensions from "../../utils/language-extensions";
55
import { assert } from "../../../utils";
6-
import { findFirstNodeAbove } from "../../utils/typescript";
6+
import { getFunctionTypeForCall } from "../../utils/typescript";
77
import { LuaTarget } from "../../../CompilerOptions";
88
import { unsupportedForTarget } from "../../utils/diagnostics";
99

@@ -66,43 +66,17 @@ const bitwiseOperatorMapExtensions = new Set<extensions.ExtensionKind>([
6666
extensions.ExtensionKind.BitwiseNotOperatorMethodType,
6767
]);
6868

69-
function getTypeDeclaration(declaration: ts.Declaration) {
70-
return ts.isTypeAliasDeclaration(declaration)
71-
? declaration
72-
: findFirstNodeAbove(declaration, ts.isTypeAliasDeclaration);
73-
}
74-
7569
function getOperatorMapExtensionKindForCall(context: TransformationContext, node: ts.CallExpression) {
76-
const signature = context.checker.getResolvedSignature(node);
77-
if (!signature || !signature.declaration) {
78-
return;
79-
}
80-
const typeDeclaration = getTypeDeclaration(signature.declaration);
81-
if (!typeDeclaration) {
82-
return;
83-
}
84-
const type = context.checker.getTypeFromTypeNode(typeDeclaration.type);
85-
return operatorMapExtensions.find(extensionKind => extensions.isExtensionType(type, extensionKind));
86-
}
87-
88-
function isOperatorMapType(context: TransformationContext, type: ts.Type): boolean {
89-
if (type.isUnionOrIntersection()) {
90-
return type.types.some(t => isOperatorMapType(context, t));
91-
} else {
92-
return operatorMapExtensions.some(extensionKind => extensions.isExtensionType(type, extensionKind));
93-
}
94-
}
95-
96-
function isOperatorMapIdentifier(context: TransformationContext, node: ts.Identifier) {
97-
const type = context.checker.getTypeAtLocation(node);
98-
return isOperatorMapType(context, type);
70+
const type = getFunctionTypeForCall(context, node);
71+
return type && operatorMapExtensions.find(extensionKind => extensions.isExtensionType(type, extensionKind));
9972
}
10073

10174
export function isOperatorMapping(context: TransformationContext, node: ts.CallExpression | ts.Identifier) {
10275
if (ts.isCallExpression(node)) {
10376
return getOperatorMapExtensionKindForCall(context, node) !== undefined;
10477
} else {
105-
return isOperatorMapIdentifier(context, node);
78+
const type = context.checker.getTypeAtLocation(node);
79+
return operatorMapExtensions.some(extensionKind => extensions.isExtensionType(type, extensionKind));
10680
}
10781
}
10882

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import * as ts from "typescript";
2+
import * as lua from "../../../LuaAST";
3+
import { TransformationContext } from "../../context";
4+
import * as extensions from "../../utils/language-extensions";
5+
import { getFunctionTypeForCall } from "../../utils/typescript";
6+
import { assert } from "../../../utils";
7+
8+
const tableGetExtensions = [extensions.ExtensionKind.TableGetType, extensions.ExtensionKind.TableGetMethodType];
9+
10+
const tableSetExtensions = [extensions.ExtensionKind.TableSetType, extensions.ExtensionKind.TableSetMethodType];
11+
12+
const tableExtensions = [extensions.ExtensionKind.TableNewType, ...tableGetExtensions, ...tableSetExtensions];
13+
14+
function getTableExtensionKindForCall(
15+
context: TransformationContext,
16+
node: ts.CallExpression,
17+
validExtensions: extensions.ExtensionKind[]
18+
) {
19+
const type = getFunctionTypeForCall(context, node);
20+
return type && validExtensions.find(extensionKind => extensions.isExtensionType(type, extensionKind));
21+
}
22+
23+
export function isTableExtensionIdentifier(context: TransformationContext, node: ts.Identifier) {
24+
const type = context.checker.getTypeAtLocation(node);
25+
return tableExtensions.some(extensionKind => extensions.isExtensionType(type, extensionKind));
26+
}
27+
28+
export function isTableGetCall(context: TransformationContext, node: ts.CallExpression) {
29+
return getTableExtensionKindForCall(context, node, tableGetExtensions) !== undefined;
30+
}
31+
32+
export function isTableSetCall(context: TransformationContext, node: ts.CallExpression) {
33+
return getTableExtensionKindForCall(context, node, tableSetExtensions) !== undefined;
34+
}
35+
36+
export function isTableNewCall(context: TransformationContext, node: ts.NewExpression) {
37+
const type = context.checker.getTypeAtLocation(node.expression);
38+
return extensions.isExtensionType(type, extensions.ExtensionKind.TableNewType);
39+
}
40+
41+
export function transformTableGetExpression(context: TransformationContext, node: ts.CallExpression): lua.Expression {
42+
const extensionKind = getTableExtensionKindForCall(context, node, tableGetExtensions);
43+
assert(extensionKind);
44+
45+
const args = node.arguments.slice();
46+
if (
47+
args.length === 1 &&
48+
(ts.isPropertyAccessExpression(node.expression) || ts.isElementAccessExpression(node.expression))
49+
) {
50+
args.unshift(node.expression.expression);
51+
}
52+
53+
return lua.createTableIndexExpression(
54+
context.transformExpression(args[0]),
55+
context.transformExpression(args[1]),
56+
node
57+
);
58+
}
59+
60+
export function transformTableSetExpression(context: TransformationContext, node: ts.CallExpression): lua.Statement {
61+
const extensionKind = getTableExtensionKindForCall(context, node, tableSetExtensions);
62+
assert(extensionKind);
63+
64+
const args = node.arguments.slice();
65+
if (
66+
args.length === 2 &&
67+
(ts.isPropertyAccessExpression(node.expression) || ts.isElementAccessExpression(node.expression))
68+
) {
69+
args.unshift(node.expression.expression);
70+
}
71+
72+
return lua.createAssignmentStatement(
73+
lua.createTableIndexExpression(
74+
context.transformExpression(args[0]),
75+
context.transformExpression(args[1]),
76+
node
77+
),
78+
context.transformExpression(args[2])
79+
);
80+
}

0 commit comments

Comments
 (0)