diff options
-rw-r--r-- | regex/__init__.py | 102 |
1 files changed, 56 insertions, 46 deletions
diff --git a/regex/__init__.py b/regex/__init__.py index ccc3c05..2a123a6 100644 --- a/regex/__init__.py +++ b/regex/__init__.py @@ -37,52 +37,50 @@ def make_tree(code): pytree = asttokens.ASTTokens(code, parse=True).tree return make(pytree) -def stringify(tree): - return re.sub('\n *', ' ', tree.__repr__()) - -# induce a pattern from the specified nodes in the tree -def make_pattern(tree, nodes): +# induce a pattern from the selected nodes in the tree +def make_pattern(tree, selected): label = tree.label - if not tree: - if label == 'Return': - return Tree(label, [Tree('$')]) - return None - - # found a node, do not recurse (much) further - if any(n is tree for n in nodes): - # omit variable names + # found a node, do not recurse further + if any(n is tree for n in selected): if label == 'Name': - return Tree(label, [Tree('id'), tree[1]]) + # replace all variable names with 'VAR' + return Tree(label, [Tree('id', [Tree('VAR')]), tree[1]]) elif label == 'arg': - return Tree(label, [Tree('arg')]) - # omit string value + # replace all variable names with 'VAR' (in function arguments) + return Tree(label, [Tree('arg', [Tree('VAR')])]) elif label == 'Str': + # omit string value return Tree(label) - # omit everything but function name elif label == 'Call': - return Tree(label, [Tree('func', [Tree('Name', [Tree('id', [Tree(tree[0][0][0][0].label)])])])]) + # omit arguments, keep the original function name + return Tree(label, [tree[0]]) else: return tree - subpats = [make_pattern(child, nodes) for child in tree] - if not any(subpats): + subpats = [make_pattern(child, selected) for child in tree] + if all(s is None for s in subpats): + # no selected nodes found in this branch, move along return None - # extra nodes to include in the pattern - if label in ('Call', 'FunctionDef'): - # store function names + # at this point, one or more children of tree are ancestors to + # some selected nodes + + # extra nodes to include in patterns, depending on node label + if label in {'Call', 'FunctionDef'}: + # store function name pat = Tree(label, [tree[0]] + [s for s in subpats if s is not None]) elif label == 'args': # (args arg1 … argN) - # store position of matching argument, by including (only) the - # root node of all nonmatching arguments - pat = Tree(label, [s for i, s in enumerate(subpats) if s is not None]) + # include at least an 'arg' node in the pattern for each + # argument, allowing us to distinguish between subpatterns + # found at different argument positions + pat = Tree(label, [s if s is not None else Tree('arg') for s in subpats]) elif label == 'UnaryOp': - # store operator + # store the operator pat = Tree(label, [tree[0], subpats[1]]) elif label == 'BinOp' or label == 'Compare': - # store operator and left and/or right side + # store the operator and left and/or right side children = [] if subpats[0]: children += [subpats[0]] @@ -98,20 +96,30 @@ def make_pattern(tree, nodes): def get_variables(tree): variables = collections.defaultdict(list) for node, parent in tree.subtrees_with_parents(): - if not node: - continue if node.label == 'Name' and parent.label != 'func': name = node[0][0].label variables[name].append(node) elif node.label == 'FunctionDef': - # (FunctionDef name (args (arguments (args arg0 arg1 … argN)))) + # AST form: (FunctionDef name (args (arguments (args arg0 arg1 … argN)))) for arg in node[1][0][0]: if arg: name = arg[0][0].label variables[name].append(arg) + elif node.label == 'Lambda': + # AST form: (Lambda (args (arguments (args (arg0 arg1 … argN))))) + for arg in node[0][0][0]: + if arg: + name = arg[0][0].label + variables[name].append(arg) return variables -literals = {'Num', 'Str', 'NameConstant'} +def get_literals(tree): + return [n for n in tree.subtrees() if n.label in {'Num', 'Str', 'NameConstant'}] + +def get_calls(tree): + return [n for n in tree.subtrees() if n.label == 'Call'] + +# non-statement expressions short_range = { 'Return', 'Delete', 'Assign', 'AugAssign', 'Raise', 'BoolOp', 'BinOp', 'UnaryOp', 'Lambda', 'Call', 'Dict', 'Set', 'List' @@ -121,12 +129,18 @@ def get_patterns(code): tree = make_tree(code) variables = get_variables(tree) + def subsets(iterable): + s = list(iterable) + return chain.from_iterable(combinations(s, r) for r in range(1, len(s)+1)) + def patterns_with_variables(): for var, nodes in variables.items(): + # singletons if len(nodes) == 1: pat = make_pattern(tree, nodes) if pat: yield pat, nodes + # pairs of nodes with the same variable for selected in chain( combinations(nodes, 2)): pat = make_pattern(tree, selected) @@ -141,11 +155,10 @@ def get_patterns(code): return if node.label in short_range: vars = [n for n in node.subtrees() if any(n is vn for var_nodes in variables.values() for vn in var_nodes)] - lits = [n for n in node.subtrees() if n.label in literals] + literals = get_literals(node) for selected in chain( - combinations(lits, 1), - combinations(lits, 2), - product(lits, vars)): + subsets(literals), + product(literals, vars)): pat = make_pattern(tree, selected) if pat: yield pat, selected @@ -155,24 +168,21 @@ def get_patterns(code): # yield patterns describing function calls def patterns_with_calls(tree): - for node in tree.subtrees(): - if node.label == 'Call': - pat = make_pattern(tree, [node]) - if pat: - yield pat, [node] + for node in get_calls(tree): + pat = make_pattern(tree, [node]) + if pat: + yield pat, [node] # yield patterns describing control structures def patterns_with_control_structures(node, path=None): if path is None: path = [] if node.label in ['FunctionDef', 'If', 'Else', 'For', 'While', 'Break', 'Continue', 'Pass', 'With', 'Raise', 'Try', 'Return']: - path += [node] - if not node and path: - children = [] + pat = Tree(node.label) for parent in path[::-1]: - pat = Tree(parent.label, children) - children = [pat] + pat = Tree(parent.label, [pat]) yield pat, [node] + path += [node] for child in node: yield from patterns_with_control_structures(child, path) |