diff options
-rw-r--r-- | regex/__init__.py | 63 |
1 files changed, 47 insertions, 16 deletions
diff --git a/regex/__init__.py b/regex/__init__.py index 14de8a9..ccc3c05 100644 --- a/regex/__init__.py +++ b/regex/__init__.py @@ -49,12 +49,19 @@ def make_pattern(tree, nodes): return Tree(label, [Tree('$')]) return None + # found a node, do not recurse (much) further if any(n is tree for n in nodes): - # remove variable names + # omit variable names if label == 'Name': return Tree(label, [Tree('id'), tree[1]]) elif label == 'arg': return Tree(label, [Tree('arg')]) + # omit string value + elif label == 'Str': + return Tree(label) + # omit everything but function name + elif label == 'Call': + return Tree(label, [Tree('func', [Tree('Name', [Tree('id', [Tree(tree[0][0][0][0].label)])])])]) else: return tree @@ -70,7 +77,7 @@ def make_pattern(tree, nodes): # (args arg1 … argN) # store position of matching argument, by including (only) the # root node of all nonmatching arguments - pat = Tree(label, [s if s is not None else Tree(tree[i].label) for i, s in enumerate(subpats)]) + pat = Tree(label, [s for i, s in enumerate(subpats) if s is not None]) elif label == 'UnaryOp': # store operator pat = Tree(label, [tree[0], subpats[1]]) @@ -114,12 +121,17 @@ def get_patterns(code): tree = make_tree(code) variables = get_variables(tree) - for var, nodes in variables.items(): - for selected in chain( - combinations(nodes, 2)): - pat = make_pattern(tree, selected) - if pat: - yield pat, selected + def patterns_with_variables(): + for var, nodes in variables.items(): + if len(nodes) == 1: + pat = make_pattern(tree, nodes) + if pat: + yield pat, nodes + for selected in chain( + combinations(nodes, 2)): + pat = make_pattern(tree, selected) + if pat: + yield pat, selected # yield patterns for variable-literal / literal-literal pairs # yield patterns for singleton literals @@ -131,6 +143,7 @@ def get_patterns(code): vars = [n for n in node.subtrees() if any(n is vn for var_nodes in variables.values() for vn in var_nodes)] lits = [n for n in node.subtrees() if n.label in literals] for selected in chain( + combinations(lits, 1), combinations(lits, 2), product(lits, vars)): pat = make_pattern(tree, selected) @@ -140,15 +153,33 @@ def get_patterns(code): for child in node: yield from patterns_with_literals(child) - # extra nodes for Python - for n in tree.subtrees(): - # empty return - if n.label == 'Return': - pat = make_pattern(tree, [n]) - if pat: - yield pat, [n] - + # yield patterns describing function calls + def patterns_with_calls(tree): + for node in tree.subtrees(): + if node.label == 'Call': + pat = make_pattern(tree, [node]) + if pat: + yield pat, [node] + + # yield patterns describing control structures + def patterns_with_control_structures(node, path=None): + if path is None: + path = [] + if node.label in ['FunctionDef', 'If', 'Else', 'For', 'While', 'Break', 'Continue', 'Pass', 'With', 'Raise', 'Try', 'Return']: + path += [node] + if not node and path: + children = [] + for parent in path[::-1]: + pat = Tree(parent.label, children) + children = [pat] + yield pat, [node] + for child in node: + yield from patterns_with_control_structures(child, path) + + yield from patterns_with_variables() + yield from patterns_with_calls(tree) yield from patterns_with_literals(tree) + yield from patterns_with_control_structures(tree) def get_attributes(programs): # extract attributes from training data |