-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Expand file tree
/
Copy pathsplitter.py
More file actions
executable file
·384 lines (335 loc) · 12.7 KB
/
splitter.py
File metadata and controls
executable file
·384 lines (335 loc) · 12.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
'''
Split the flow-graph to allow tests to dominate all parts of the code that depends on them.
We split on `if`s and `try`s. Either because of several tests on the same condition or
subsequent tests on a constant determined by the first condition.
E.g.
if a:
A
B
if a:
C
becomes
if a:
A
B
C
else:
B
ensuring that A dominates C.
or...
try:
import foo
except:
foo = None
X
if foo:
Y
becomes
try:
import foo
X
Y
except:
foo = None
X
To split on CFG node N we require that there exists nodes H1..Hn and N2 such that:
N and N2 are tests or conditional assignments to the same variable.
N dominates H1 .. Hn and N2
There is no assignment to the variable between N and N2
H1..Hn are the "split heads" of N, that is:
if N is a test, H1 and H2 are its true and false successors (there is no H3).
if N is a `try` then H1 .. Hn-1 are exists from the try body and Hn is the CFG node for the first (and only) `except` statement.
Within the region strictly dominated by N, N2 must reachable from all of H1..Hn
For simplicity we limit n (as in Hn) to 2, but that is not required for correctness.
'''
from collections import defaultdict
from semmle.python import ast
from semmle.python.passes.ast_pass import iter_fields
from operator import itemgetter
from semmle.graph import FlowGraph
MAX_SPLITS = 2
def do_split(ast_root, graph: FlowGraph):
'''Split the flow graph, using the AST to determine split points.'''
ast_labels = label_ast(ast_root)
cfg_labels = label_cfg(graph, ast_labels)
split_points = choose_split_points(graph, cfg_labels)
graph.split(split_points)
class ScopedAstLabellingVisitor(object):
'''Visitor for labelling AST nodes in scope.
Does not visit nodes belonging to inner scopes (methods, etc)
'''
def __init__(self, labels):
self.labels = labels
self.priority = 0
def visit(self, node):
"""Visit a node."""
method = 'visit_' + node.__class__.__name__
getattr(self, method, self.generic_visit)(node)
def generic_visit(self, node):
if isinstance(node, ast.AstBase):
for _, _, value in iter_fields(node):
self.visit(value)
def visit_Class(self, node):
#Do not visit sub-scopes
return
visit_Function = visit_Class
def visit_list(self, the_list):
for item in the_list:
method = 'visit_' + item.__class__.__name__
getattr(self, method, self.generic_visit)(item)
#Helper methods
@staticmethod
def get_variable(expr):
'''Returns the variable of this expr. Returns None if no variable.'''
if hasattr(expr, "variable"):
return expr.variable
else:
return None
@staticmethod
def is_const(expr):
if isinstance(expr, ast.Name):
return expr.variable.id in ("None", "True", "False")
elif isinstance(expr, ast.UnaryOp):
return ScopedAstLabellingVisitor.is_const(expr.operand)
return isinstance(expr, (ast.Num, ast.Str))
class AstLabeller(ScopedAstLabellingVisitor):
'''Visitor to label tests and assignments
for later scanning to determine split points.
'''
def __init__(self, *args):
ScopedAstLabellingVisitor.__init__(self, *args)
self.in_test = 0
def _label_for_compare(self, cmp):
if len(cmp.ops) != 1:
return None
var = self.get_variable(cmp.left)
if var is None:
var = self.get_variable(cmp.comparators[0])
k = cmp.left
else:
k = cmp.comparators[0]
if var is not None and self.is_const(k):
self.priority += 1
return (var, k, self.priority)
return None
def visit_Compare(self, cmp):
label = self._label_for_compare(cmp)
if label:
self.labels[cmp].append(label)
def visit_Name(self, name):
self.priority += 1
if isinstance(name.ctx, ast.Store):
self.labels[name].append((name.variable, "assign", self.priority))
elif self.in_test:
self.labels[name].append((name.variable, None, self.priority))
def _label_for_unary_operand(self, op):
if not isinstance(op.op, ast.Not):
return None
if isinstance(op.operand, ast.UnaryOp):
return self._label_for_unary_operand(op.operand)
elif isinstance(op.operand, ast.Name):
self.priority += 1
return (op.operand.variable, None, self.priority)
elif isinstance(op.operand, ast.Compare):
return self._label_for_compare(op.operand)
return None
def visit_UnaryOp(self, op):
if not self.in_test:
return
label = self._label_for_unary_operand(op)
if label:
self.labels[op].append(label)
else:
self.visit(op.operand)
def visit_If(self, ifstmt):
# Looking for the pattern:
# if x: k = K0 else: k = K1
# the test is the split point, but the variable is `k`
self.in_test += 1
self.visit(ifstmt.test)
self.in_test -= 1
self.visit(ifstmt.body)
self.visit(ifstmt.orelse)
k1 = {}
ConstantAssignmentVisitor(k1).visit(ifstmt.body)
k2 = {}
ConstantAssignmentVisitor(k2).visit(ifstmt.orelse)
k = set(k1.keys()).union(k2.keys())
self.priority += 1
for var in k:
val = k1[var] if var in k1 else k2[var]
self.labels[ifstmt.test].append((var, val, self.priority))
def visit_Try(self, stmt):
# Looking for the pattern:
# if try: k = K0 except: k = K1
# the try is the split point, and the variable is `k`
self.generic_visit(stmt)
if not stmt.handlers or len(stmt.handlers) > 1:
return
k1 = {}
ConstantAssignmentVisitor(k1).visit(stmt.body)
k2 = {}
ConstantAssignmentVisitor(k2).visit(stmt.handlers[0])
k = set(k1.keys()).union(k2.keys())
self.priority += 1
for var in k:
val = k1[var] if var in k1 else k2[var]
self.labels[stmt].append((var, val, self.priority))
def visit_ClassExpr(self, node):
# Don't split over class definitions,
# as the presence of multiple ClassObjects for a
# single class can be confusing.
# The same applies to function definitions.
self.priority += 1
self.labels[node].append((None, "define", self.priority))
visit_FunctionExpr = visit_ClassExpr
class TryBodyAndHandlerVisitor(ScopedAstLabellingVisitor):
'''Visitor to gather all AST nodes under visited node
including, but not under `ExceptStmt`s.'''
def generic_visit(self, node):
if isinstance(node, ast.AstBase):
self.labels.add(node)
for _, _, value in iter_fields(node):
self.visit(value)
def visit_ExceptStmt(self, node):
#Do not visit node below this.
self.labels.add(node)
return
class ConstantAssignmentVisitor(ScopedAstLabellingVisitor):
'''Visitor to label assignments where RHS is a constant'''
def visit_Assign(self, asgn):
if not self.is_const(asgn.value):
return
for target in asgn.targets:
if hasattr(target, "variable"):
self.labels[target.variable] = asgn.value
def label_ast(ast_root):
'''Visits the AST, returning the labels'''
labels = defaultdict(list)
labeller = AstLabeller(labels)
labeller.generic_visit(ast_root)
return labels
def _is_branch(node, graph: FlowGraph):
'''Holds if `node` (in `graph`) is a branch point.'''
if len(graph.succ[node]) == 2 or isinstance(node.node, ast.Try):
return True
if len(graph.succ[node]) != 1:
return False
succ = graph.succ[node][0]
if not isinstance(succ.node, ast.UnaryOp):
return False
return _is_branch(succ, graph)
def label_cfg(graph: FlowGraph, ast_labels):
'''Copies labels from AST to CFG for branches and assignments.'''
cfg_labels = {}
for node, _ in graph.nodes():
if node.node not in ast_labels:
continue
labels = ast_labels[node.node]
if not labels:
continue
if _is_branch(node, graph) or labels[0][1] in ("assign", "define", "loop"):
cfg_labels[node] = labels
return cfg_labels
def usefully_comparable_types(o1, o2):
'''Holds if a test against object o1 can provide any
meaningful information w.r.t. to a test against o2.
'''
if o1 is None or o2 is None:
return True
return type(o1) is type(o2)
def exits_from_subtree(head, subtree, graph: FlowGraph):
'''Returns all nodes in `subtree`, that exit
the subtree and are reachable from `head`
'''
exits = set()
seen = set()
todo = set([head])
while todo:
node = todo.pop()
if node in seen:
continue
seen.add(node)
if not graph.succ[node]:
continue
is_exit = True
for succ in graph.succ[node]:
if succ.node in subtree:
todo.add(succ)
is_exit = False
if is_exit:
exits.add(node)
return exits
def get_split_heads(head, graph: FlowGraph):
'''Compute the split tails for the node `head`
That is, the set of nodes from which splitting should commence.
'''
if isinstance(head.node, ast.Try):
try_body = set()
TryBodyAndHandlerVisitor(try_body).visit(head.node)
if head.node.handlers:
try_body.add(head.node.handlers[0])
try_split_tails = exits_from_subtree(head, try_body, graph)
return try_split_tails
else:
return graph.succ[head]
def choose_split_points(graph: FlowGraph, cfg_labels):
'''Select the set of nodes to be the split heads for the graph,
from the given labels. A maximum of two points are chosen to avoid
excessive blow up.
'''
candidates = []
#Find pairs -- N1, N2 where N1 and N2 are tests on the same variable and the tests are similar.
labels = []
for node, label_list in cfg_labels.items():
for label in label_list:
labels.append((node, label[0], label[1], label[2]))
labels.sort(key=itemgetter(3))
for first_node, first_var, first_type, first_priority in labels:
if first_type in ("assign", "define"):
continue
#Avoid splitting if any class or function is defined later in scope.
if 'define' in [type for (_, _, type, priority) in labels if priority > first_priority]:
break
for second_node, second_var, second_type, second_priority in labels:
if second_var != first_var:
continue
# First node must dominate second node to be a viable splitting candidate.
# Quick check to avoid doing pointless dominance checks.
if first_priority >= second_priority:
continue
#Avoid splitting if variable is reassigned
if second_type == "assign":
break
if not graph.strictly_dominates(first_node, second_node):
continue
if not usefully_comparable_types(first_type, second_type):
continue
split_heads = get_split_heads(first_node, graph)
if len(split_heads) != 2:
continue
# Unless both of the split heads reach the second node,
# then there is no benefit to splitting.
for head in split_heads:
if not graph.strictly_dominates(first_node, head):
break
if not graph.reaches_while_dominated(head, second_node, first_node):
break
else:
candidates.append((first_node, split_heads, first_var, first_priority))
#Candidates is a list of (node, split-heads, variable, priority) tuples.
#Remove any duplicate nodes
candidates = deduplicate(candidates, 0, 3)
#Remove repeated splits on the same variable if more than MAX_SPLITS split and more than one variable.
if len(candidates) > MAX_SPLITS and len({c[2] for c in candidates}) > 1:
candidates = deduplicate(candidates, 2, 3)
# Return best two results, but must return in reverse priority order,
# so that splitting on one node does not remove a later one.
return [c[:2] for c in candidates[MAX_SPLITS-1::-1]]
def deduplicate(lst, col, sort_col):
'''De-duplicate list `lst` of tuples removing all but the first tuple containing
duplicates of `col`. Sort the result on `sort_col'''
dedupped = {}
for t in reversed(lst):
dedupped[t[col]] = t
return sorted(dedupped.values(), key=itemgetter(sort_col))