summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTimotej Lazar <timotej.lazar@fri.uni-lj.si>2017-12-29 18:23:42 +0100
committerTimotej Lazar <timotej.lazar@fri.uni-lj.si>2017-12-29 18:23:42 +0100
commit1ee2b0252602b6176ee02db5c4a77fb2daac0a0f (patch)
tree8d67f7bb7595dd36d30fa85c37e72f40202c924b
parentb10165d13cb7b05305e4246931a364c89ac7c258 (diff)
Import AST pattern stuff
-rw-r--r--regex/__init__.py208
-rw-r--r--regex/tree.py95
2 files changed, 303 insertions, 0 deletions
diff --git a/regex/__init__.py b/regex/__init__.py
new file mode 100644
index 0000000..7232349
--- /dev/null
+++ b/regex/__init__.py
@@ -0,0 +1,208 @@
+import ast
+import collections
+from itertools import chain, combinations, product
+import re
+import sys
+
+from .asttokens import asttokens
+
+from .tree import Tree
+
+def make_tree(code):
+ def make(node, parent=None):
+ if hasattr(node, 'first_token'):
+ data = {'start': node.first_token.startpos, 'end': node.last_token.endpos, 'ast': node}
+ else:
+ data = {'start': parent.first_token.startpos, 'end': parent.last_token.endpos, 'ast': parent}
+
+ if not hasattr(node, '_fields'):
+ return Tree(str(node), data=data)
+
+ children = []
+ for field, value in ast.iter_fields(node):
+ if value is None:
+ continue
+ if isinstance(value, ast.AST):
+ if not value._fields:
+ child = Tree(field, [Tree(value.__class__.__name__)], data=data)
+ else:
+ child = Tree(field, [make(value, node)], data=data)
+ elif isinstance(value, list):
+ child = Tree(field, [make(v, node) for v in value], data=data)
+ else:
+ child = Tree(field, [make(value, node)], data=data)
+ children.append(child)
+ return Tree(node.__class__.__name__, children, data=data)
+
+ pytree = asttokens.ASTTokens(code, parse=True).tree
+ return make(pytree)
+
+def stringify(tree):
+ return re.sub('\n *', ' ', tree.__repr__())
+
+# induce a pattern from the specified nodes in the tree
+def make_pattern(tree, nodes):
+ label = tree.label
+
+ if not tree:
+ if label == 'Return':
+ return Tree(label, [Tree('$')])
+ return None
+
+ if any(n is tree for n in nodes):
+ # remove variable names
+ if label == 'Name':
+ return Tree(label, [Tree('id'), tree[1]])
+ elif label == 'arg':
+ return Tree(label, [Tree('arg')])
+ else:
+ return tree
+
+ subpats = [make_pattern(child, nodes) for child in tree]
+ if not any(subpats):
+ return None
+
+ # extra nodes to include in the pattern
+ if label in ('Call', 'FunctionDef'):
+ # store function names
+ pat = Tree(label, [tree[0]] + [s for s in subpats if s is not None])
+ elif label == 'args':
+ # (args arg1 … argN)
+ # store position of matching argument, by including (only) the
+ # root node of all nonmatching arguments
+ pat = Tree(label, [s if s is not None else Tree(tree[i].label) for i, s in enumerate(subpats)])
+ elif label == 'UnaryOp':
+ # store operator
+ pat = Tree(label, [tree[0], subpats[1]])
+ elif label == 'BinOp' or label == 'Compare':
+ # store operator and left and/or right side
+ children = []
+ if subpats[0]:
+ children += [subpats[0]]
+ children += [tree[1]]
+ if subpats[2]:
+ children += [subpats[2]]
+ pat = Tree(label, children)
+ else:
+ # store children if they match some part of the pattern
+ pat = Tree(label, [s for s in subpats if s is not None])
+ return pat
+
+def get_variables(tree):
+ variables = collections.defaultdict(list)
+ for node, parent in tree.subtrees_with_parents():
+ if not node:
+ continue
+ if node.label == 'Name' and parent.label != 'func':
+ name = node[0][0].label
+ variables[name].append(node)
+ elif node.label == 'FunctionDef':
+ # (FunctionDef name (args (arguments (args arg0 arg1 … argN))))
+ for arg in node[1][0][0]:
+ if arg:
+ name = arg[0][0].label
+ variables[name].append(arg)
+ return variables
+
+literals = {'Num', 'Str', 'NameConstant'}
+short_range = {
+ 'Return', 'Delete', 'Assign', 'AugAssign', 'Raise', 'BoolOp',
+ 'BinOp', 'UnaryOp', 'Lambda', 'Call', 'Dict', 'Set', 'List'
+}
+
+def get_patterns(code):
+ tree = make_tree(code)
+ variables = get_variables(tree)
+
+ for var, nodes in variables.items():
+ for selected in chain(
+ combinations(nodes, 2)):
+ pat = make_pattern(tree, selected)
+ if pat:
+ yield pat, selected
+
+ # yield patterns for variable-literal / literal-literal pairs
+ # yield patterns for singleton literals
+ # (only within certain statements and expressions)
+ def patterns_with_literals(node):
+ if not node:
+ return
+ if node.label in short_range:
+ vars = [n for n in node.subtrees() if any(n is vn for var_nodes in variables.values() for vn in var_nodes)]
+ lits = [n for n in node.subtrees() if n.label in literals]
+ for selected in chain(
+ combinations(lits, 2),
+ product(lits, vars)):
+ pat = make_pattern(tree, selected)
+ if pat:
+ yield pat, selected
+ else:
+ for child in node:
+ yield from patterns_with_literals(child)
+
+ # extra nodes for Python
+ for n in tree.subtrees():
+ # empty return
+ if n.label == 'Return':
+ pat = make_pattern(tree, [n])
+ if pat:
+ yield pat, [n]
+
+ yield from patterns_with_literals(tree)
+
+def find_pattern(tree, pattern):
+ def find_pattern_aux(node, pattern, varname=None):
+ nodes = []
+ if pattern.label == node.label:
+ nodes = [node]
+
+ if node.label in ('id', 'arg'):
+ if not pattern:
+ while node:
+ node = node[0]
+ if varname is None:
+ varname = node.label
+ return [node]
+ elif node.label == varname:
+ return [node]
+ else:
+ return []
+
+ # subpatterns found so far among node’s children
+ index = 0
+ for subtree in node.subtrees():
+ if index >= len(pattern):
+ break
+ new_nodes = find_pattern_aux(subtree, pattern[index], varname)
+ if new_nodes:
+ nodes += new_nodes
+ index += 1
+ if index < len(pattern):
+ nodes = []
+
+ else:
+ for child in node:
+ nodes = find_pattern_aux(child, pattern, varname)
+ if nodes:
+ break
+ return nodes
+
+ for variable in get_variables(tree):
+ result = find_pattern_aux(tree, pattern, variable)
+ if result:
+ yield result
+
+def get_attributes(programs):
+ # extract attributes from training data
+ patterns = collections.defaultdict(list)
+ for code in programs:
+ for pat, nodes in get_patterns(code):
+ patterns[pat] += [code]
+
+ attrs = collections.OrderedDict()
+ for pat, progs in sorted(patterns.items(), key=lambda x: len(x[1]), reverse=True):
+ if len(progs) < 5:
+ break
+ attrs['regex-{}'.format(len(attrs))] = {'desc': pat, 'programs': progs}
+
+ return attrs
diff --git a/regex/tree.py b/regex/tree.py
new file mode 100644
index 0000000..1ff1b90
--- /dev/null
+++ b/regex/tree.py
@@ -0,0 +1,95 @@
+#!/usr/bin/python3
+
+class Tree(tuple):
+ def __new__ (cls, label, children=None, data=None):
+ return super(Tree, cls).__new__(cls, tuple(children) if children else tuple())
+
+ def __init__(self, label, children=None, data=None):
+ self.label = label
+ self.data = data
+
+ def __eq__(self, other):
+ return (self.__class__ is other.__class__ and
+ self.label == other.label and tuple.__eq__(self, other))
+ def __lt__(self, other):
+ if isinstance(other, Tree) and self.__class__ is other.__class__:
+ return (self.label, tuple(self)) < (other.label, tuple(other))
+ return self.__class__.__name__ < other.__class__.__name__
+ __ne__ = lambda self, other: not self == other
+ __gt__ = lambda self, other: not (self < other or self == other)
+ __le__ = lambda self, other: self < other or self == other
+ __ge__ = lambda self, other: not self < other
+
+ def __hash__(self):
+ return hash((self.label, tuple(self)))
+
+ def __repr__(self):
+ return self.to_string()
+
+ # adapted from https://en.wikipedia.org/wiki/S-expression#Parsing
+ @staticmethod
+ def from_string(string):
+ sexp = [[]]
+ word = ''
+ in_str = False
+ for char in string:
+ if char == '(' and not in_str:
+ sexp.append([])
+ elif char == ')' and not in_str:
+ if word:
+ sexp[-1].append(word if not sexp[-1] else Tree(word))
+ word = ''
+ temp = sexp.pop()
+ sexp[-1].append(Tree(temp[0], temp[1:]))
+ elif char in (' ', '\n', '\t') and not in_str:
+ if word:
+ sexp[-1].append(word if not sexp[-1] else Tree(word))
+ word = ''
+ elif char == '\"':
+ in_str = not in_str
+ else:
+ word += char
+ return sexp[0][0]
+
+ def to_string(self, inline=False, depth=0):
+ if not self:
+ return str(self.label)
+ string = '(' + self.label
+ if not inline and len(self) > 1 and len(list(self.subtrees())) > 8:
+ depth += 1
+ prefix = '\n' + ' '*depth
+ else:
+ prefix = ' '
+ for i, child in enumerate(self):
+ string += prefix
+ string += child.to_string(inline, depth)
+ string += ')'
+ return string
+
+ def to_graphviz(self):
+ string = 'digraph {\n'
+ string += 'ordering=out\n'
+ string += 'ranksep=0.3\n'
+
+ nodes = {}
+ edges = []
+ for i, (node, parent) in enumerate(self.subtrees_with_parents()):
+ nodes[id(node)] = i
+ if parent:
+ edges += [(id(parent), id(node))]
+ string += 'n{} [label="{}"];\n'.format(i, node.label)
+
+ for a, b in edges:
+ string += 'n{} -> n{};\n'.format(nodes[a], nodes[b])
+ string += '}\n'
+ return string
+
+ def subtrees(self):
+ yield self
+ for child in self:
+ yield from child.subtrees()
+
+ def subtrees_with_parents(self, parent=None):
+ yield (self, parent)
+ for child in self:
+ yield from child.subtrees_with_parents(self)