summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--regex/__init__.py63
1 files changed, 47 insertions, 16 deletions
diff --git a/regex/__init__.py b/regex/__init__.py
index 14de8a9..ccc3c05 100644
--- a/regex/__init__.py
+++ b/regex/__init__.py
@@ -49,12 +49,19 @@ def make_pattern(tree, nodes):
return Tree(label, [Tree('$')])
return None
+ # found a node, do not recurse (much) further
if any(n is tree for n in nodes):
- # remove variable names
+ # omit variable names
if label == 'Name':
return Tree(label, [Tree('id'), tree[1]])
elif label == 'arg':
return Tree(label, [Tree('arg')])
+ # omit string value
+ elif label == 'Str':
+ return Tree(label)
+ # omit everything but function name
+ elif label == 'Call':
+ return Tree(label, [Tree('func', [Tree('Name', [Tree('id', [Tree(tree[0][0][0][0].label)])])])])
else:
return tree
@@ -70,7 +77,7 @@ def make_pattern(tree, nodes):
# (args arg1 … argN)
# store position of matching argument, by including (only) the
# root node of all nonmatching arguments
- pat = Tree(label, [s if s is not None else Tree(tree[i].label) for i, s in enumerate(subpats)])
+ pat = Tree(label, [s for i, s in enumerate(subpats) if s is not None])
elif label == 'UnaryOp':
# store operator
pat = Tree(label, [tree[0], subpats[1]])
@@ -114,12 +121,17 @@ def get_patterns(code):
tree = make_tree(code)
variables = get_variables(tree)
- for var, nodes in variables.items():
- for selected in chain(
- combinations(nodes, 2)):
- pat = make_pattern(tree, selected)
- if pat:
- yield pat, selected
+ def patterns_with_variables():
+ for var, nodes in variables.items():
+ if len(nodes) == 1:
+ pat = make_pattern(tree, nodes)
+ if pat:
+ yield pat, nodes
+ for selected in chain(
+ combinations(nodes, 2)):
+ pat = make_pattern(tree, selected)
+ if pat:
+ yield pat, selected
# yield patterns for variable-literal / literal-literal pairs
# yield patterns for singleton literals
@@ -131,6 +143,7 @@ def get_patterns(code):
vars = [n for n in node.subtrees() if any(n is vn for var_nodes in variables.values() for vn in var_nodes)]
lits = [n for n in node.subtrees() if n.label in literals]
for selected in chain(
+ combinations(lits, 1),
combinations(lits, 2),
product(lits, vars)):
pat = make_pattern(tree, selected)
@@ -140,15 +153,33 @@ def get_patterns(code):
for child in node:
yield from patterns_with_literals(child)
- # extra nodes for Python
- for n in tree.subtrees():
- # empty return
- if n.label == 'Return':
- pat = make_pattern(tree, [n])
- if pat:
- yield pat, [n]
-
+ # yield patterns describing function calls
+ def patterns_with_calls(tree):
+ for node in tree.subtrees():
+ if node.label == 'Call':
+ pat = make_pattern(tree, [node])
+ if pat:
+ yield pat, [node]
+
+ # yield patterns describing control structures
+ def patterns_with_control_structures(node, path=None):
+ if path is None:
+ path = []
+ if node.label in ['FunctionDef', 'If', 'Else', 'For', 'While', 'Break', 'Continue', 'Pass', 'With', 'Raise', 'Try', 'Return']:
+ path += [node]
+ if not node and path:
+ children = []
+ for parent in path[::-1]:
+ pat = Tree(parent.label, children)
+ children = [pat]
+ yield pat, [node]
+ for child in node:
+ yield from patterns_with_control_structures(child, path)
+
+ yield from patterns_with_variables()
+ yield from patterns_with_calls(tree)
yield from patterns_with_literals(tree)
+ yield from patterns_with_control_structures(tree)
def get_attributes(programs):
# extract attributes from training data