summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTimotej Lazar <timotej.lazar@fri.uni-lj.si>2018-01-25 17:28:15 +0100
committerTimotej Lazar <timotej.lazar@fri.uni-lj.si>2018-01-25 17:28:15 +0100
commit628cda80e941b23f0ebb70419e787998de9f91eb (patch)
tree2fdce056eeb085963e6b1867a5441fe19e058614
parent14541040e8cf92f66971a234a1c5e7e60cb11c85 (diff)
regex: tweak make_pattern function
Fix a bug where literals were not included in patterns (»not x« is not the same as »x is None«).
-rw-r--r--regex/__init__.py102
1 files changed, 56 insertions, 46 deletions
diff --git a/regex/__init__.py b/regex/__init__.py
index ccc3c05..2a123a6 100644
--- a/regex/__init__.py
+++ b/regex/__init__.py
@@ -37,52 +37,50 @@ def make_tree(code):
pytree = asttokens.ASTTokens(code, parse=True).tree
return make(pytree)
-def stringify(tree):
- return re.sub('\n *', ' ', tree.__repr__())
-
-# induce a pattern from the specified nodes in the tree
-def make_pattern(tree, nodes):
+# induce a pattern from the selected nodes in the tree
+def make_pattern(tree, selected):
label = tree.label
- if not tree:
- if label == 'Return':
- return Tree(label, [Tree('$')])
- return None
-
- # found a node, do not recurse (much) further
- if any(n is tree for n in nodes):
- # omit variable names
+ # found a node, do not recurse further
+ if any(n is tree for n in selected):
if label == 'Name':
- return Tree(label, [Tree('id'), tree[1]])
+ # replace all variable names with 'VAR'
+ return Tree(label, [Tree('id', [Tree('VAR')]), tree[1]])
elif label == 'arg':
- return Tree(label, [Tree('arg')])
- # omit string value
+ # replace all variable names with 'VAR' (in function arguments)
+ return Tree(label, [Tree('arg', [Tree('VAR')])])
elif label == 'Str':
+ # omit string value
return Tree(label)
- # omit everything but function name
elif label == 'Call':
- return Tree(label, [Tree('func', [Tree('Name', [Tree('id', [Tree(tree[0][0][0][0].label)])])])])
+ # omit arguments, keep the original function name
+ return Tree(label, [tree[0]])
else:
return tree
- subpats = [make_pattern(child, nodes) for child in tree]
- if not any(subpats):
+ subpats = [make_pattern(child, selected) for child in tree]
+ if all(s is None for s in subpats):
+ # no selected nodes found in this branch, move along
return None
- # extra nodes to include in the pattern
- if label in ('Call', 'FunctionDef'):
- # store function names
+ # at this point, one or more children of tree are ancestors to
+ # some selected nodes
+
+ # extra nodes to include in patterns, depending on node label
+ if label in {'Call', 'FunctionDef'}:
+ # store function name
pat = Tree(label, [tree[0]] + [s for s in subpats if s is not None])
elif label == 'args':
# (args arg1 … argN)
- # store position of matching argument, by including (only) the
- # root node of all nonmatching arguments
- pat = Tree(label, [s for i, s in enumerate(subpats) if s is not None])
+ # include at least an 'arg' node in the pattern for each
+ # argument, allowing us to distinguish between subpatterns
+ # found at different argument positions
+ pat = Tree(label, [s if s is not None else Tree('arg') for s in subpats])
elif label == 'UnaryOp':
- # store operator
+ # store the operator
pat = Tree(label, [tree[0], subpats[1]])
elif label == 'BinOp' or label == 'Compare':
- # store operator and left and/or right side
+ # store the operator and left and/or right side
children = []
if subpats[0]:
children += [subpats[0]]
@@ -98,20 +96,30 @@ def make_pattern(tree, nodes):
def get_variables(tree):
variables = collections.defaultdict(list)
for node, parent in tree.subtrees_with_parents():
- if not node:
- continue
if node.label == 'Name' and parent.label != 'func':
name = node[0][0].label
variables[name].append(node)
elif node.label == 'FunctionDef':
- # (FunctionDef name (args (arguments (args arg0 arg1 … argN))))
+ # AST form: (FunctionDef name (args (arguments (args arg0 arg1 … argN))))
for arg in node[1][0][0]:
if arg:
name = arg[0][0].label
variables[name].append(arg)
+ elif node.label == 'Lambda':
+ # AST form: (Lambda (args (arguments (args (arg0 arg1 … argN)))))
+ for arg in node[0][0][0]:
+ if arg:
+ name = arg[0][0].label
+ variables[name].append(arg)
return variables
-literals = {'Num', 'Str', 'NameConstant'}
+def get_literals(tree):
+ return [n for n in tree.subtrees() if n.label in {'Num', 'Str', 'NameConstant'}]
+
+def get_calls(tree):
+ return [n for n in tree.subtrees() if n.label == 'Call']
+
+# non-statement expressions
short_range = {
'Return', 'Delete', 'Assign', 'AugAssign', 'Raise', 'BoolOp',
'BinOp', 'UnaryOp', 'Lambda', 'Call', 'Dict', 'Set', 'List'
@@ -121,12 +129,18 @@ def get_patterns(code):
tree = make_tree(code)
variables = get_variables(tree)
+ def subsets(iterable):
+ s = list(iterable)
+ return chain.from_iterable(combinations(s, r) for r in range(1, len(s)+1))
+
def patterns_with_variables():
for var, nodes in variables.items():
+ # singletons
if len(nodes) == 1:
pat = make_pattern(tree, nodes)
if pat:
yield pat, nodes
+ # pairs of nodes with the same variable
for selected in chain(
combinations(nodes, 2)):
pat = make_pattern(tree, selected)
@@ -141,11 +155,10 @@ def get_patterns(code):
return
if node.label in short_range:
vars = [n for n in node.subtrees() if any(n is vn for var_nodes in variables.values() for vn in var_nodes)]
- lits = [n for n in node.subtrees() if n.label in literals]
+ literals = get_literals(node)
for selected in chain(
- combinations(lits, 1),
- combinations(lits, 2),
- product(lits, vars)):
+ subsets(literals),
+ product(literals, vars)):
pat = make_pattern(tree, selected)
if pat:
yield pat, selected
@@ -155,24 +168,21 @@ def get_patterns(code):
# yield patterns describing function calls
def patterns_with_calls(tree):
- for node in tree.subtrees():
- if node.label == 'Call':
- pat = make_pattern(tree, [node])
- if pat:
- yield pat, [node]
+ for node in get_calls(tree):
+ pat = make_pattern(tree, [node])
+ if pat:
+ yield pat, [node]
# yield patterns describing control structures
def patterns_with_control_structures(node, path=None):
if path is None:
path = []
if node.label in ['FunctionDef', 'If', 'Else', 'For', 'While', 'Break', 'Continue', 'Pass', 'With', 'Raise', 'Try', 'Return']:
- path += [node]
- if not node and path:
- children = []
+ pat = Tree(node.label)
for parent in path[::-1]:
- pat = Tree(parent.label, children)
- children = [pat]
+ pat = Tree(parent.label, [pat])
yield pat, [node]
+ path += [node]
for child in node:
yield from patterns_with_control_structures(child, path)