diff options
author | Timotej Lazar <timotej.lazar@fri.uni-lj.si> | 2017-12-29 18:23:42 +0100 |
---|---|---|
committer | Timotej Lazar <timotej.lazar@fri.uni-lj.si> | 2017-12-29 18:23:42 +0100 |
commit | 1ee2b0252602b6176ee02db5c4a77fb2daac0a0f (patch) | |
tree | 8d67f7bb7595dd36d30fa85c37e72f40202c924b | |
parent | b10165d13cb7b05305e4246931a364c89ac7c258 (diff) |
Import AST pattern stuff
-rw-r--r-- | regex/__init__.py | 208 | ||||
-rw-r--r-- | regex/tree.py | 95 |
2 files changed, 303 insertions, 0 deletions
diff --git a/regex/__init__.py b/regex/__init__.py new file mode 100644 index 0000000..7232349 --- /dev/null +++ b/regex/__init__.py @@ -0,0 +1,208 @@ +import ast +import collections +from itertools import chain, combinations, product +import re +import sys + +from .asttokens import asttokens + +from .tree import Tree + +def make_tree(code): + def make(node, parent=None): + if hasattr(node, 'first_token'): + data = {'start': node.first_token.startpos, 'end': node.last_token.endpos, 'ast': node} + else: + data = {'start': parent.first_token.startpos, 'end': parent.last_token.endpos, 'ast': parent} + + if not hasattr(node, '_fields'): + return Tree(str(node), data=data) + + children = [] + for field, value in ast.iter_fields(node): + if value is None: + continue + if isinstance(value, ast.AST): + if not value._fields: + child = Tree(field, [Tree(value.__class__.__name__)], data=data) + else: + child = Tree(field, [make(value, node)], data=data) + elif isinstance(value, list): + child = Tree(field, [make(v, node) for v in value], data=data) + else: + child = Tree(field, [make(value, node)], data=data) + children.append(child) + return Tree(node.__class__.__name__, children, data=data) + + pytree = asttokens.ASTTokens(code, parse=True).tree + return make(pytree) + +def stringify(tree): + return re.sub('\n *', ' ', tree.__repr__()) + +# induce a pattern from the specified nodes in the tree +def make_pattern(tree, nodes): + label = tree.label + + if not tree: + if label == 'Return': + return Tree(label, [Tree('$')]) + return None + + if any(n is tree for n in nodes): + # remove variable names + if label == 'Name': + return Tree(label, [Tree('id'), tree[1]]) + elif label == 'arg': + return Tree(label, [Tree('arg')]) + else: + return tree + + subpats = [make_pattern(child, nodes) for child in tree] + if not any(subpats): + return None + + # extra nodes to include in the pattern + if label in ('Call', 'FunctionDef'): + # store function names + pat = Tree(label, [tree[0]] + [s for s in subpats if s is not None]) + elif label == 'args': + # (args arg1 … argN) + # store position of matching argument, by including (only) the + # root node of all nonmatching arguments + pat = Tree(label, [s if s is not None else Tree(tree[i].label) for i, s in enumerate(subpats)]) + elif label == 'UnaryOp': + # store operator + pat = Tree(label, [tree[0], subpats[1]]) + elif label == 'BinOp' or label == 'Compare': + # store operator and left and/or right side + children = [] + if subpats[0]: + children += [subpats[0]] + children += [tree[1]] + if subpats[2]: + children += [subpats[2]] + pat = Tree(label, children) + else: + # store children if they match some part of the pattern + pat = Tree(label, [s for s in subpats if s is not None]) + return pat + +def get_variables(tree): + variables = collections.defaultdict(list) + for node, parent in tree.subtrees_with_parents(): + if not node: + continue + if node.label == 'Name' and parent.label != 'func': + name = node[0][0].label + variables[name].append(node) + elif node.label == 'FunctionDef': + # (FunctionDef name (args (arguments (args arg0 arg1 … argN)))) + for arg in node[1][0][0]: + if arg: + name = arg[0][0].label + variables[name].append(arg) + return variables + +literals = {'Num', 'Str', 'NameConstant'} +short_range = { + 'Return', 'Delete', 'Assign', 'AugAssign', 'Raise', 'BoolOp', + 'BinOp', 'UnaryOp', 'Lambda', 'Call', 'Dict', 'Set', 'List' +} + +def get_patterns(code): + tree = make_tree(code) + variables = get_variables(tree) + + for var, nodes in variables.items(): + for selected in chain( + combinations(nodes, 2)): + pat = make_pattern(tree, selected) + if pat: + yield pat, selected + + # yield patterns for variable-literal / literal-literal pairs + # yield patterns for singleton literals + # (only within certain statements and expressions) + def patterns_with_literals(node): + if not node: + return + if node.label in short_range: + vars = [n for n in node.subtrees() if any(n is vn for var_nodes in variables.values() for vn in var_nodes)] + lits = [n for n in node.subtrees() if n.label in literals] + for selected in chain( + combinations(lits, 2), + product(lits, vars)): + pat = make_pattern(tree, selected) + if pat: + yield pat, selected + else: + for child in node: + yield from patterns_with_literals(child) + + # extra nodes for Python + for n in tree.subtrees(): + # empty return + if n.label == 'Return': + pat = make_pattern(tree, [n]) + if pat: + yield pat, [n] + + yield from patterns_with_literals(tree) + +def find_pattern(tree, pattern): + def find_pattern_aux(node, pattern, varname=None): + nodes = [] + if pattern.label == node.label: + nodes = [node] + + if node.label in ('id', 'arg'): + if not pattern: + while node: + node = node[0] + if varname is None: + varname = node.label + return [node] + elif node.label == varname: + return [node] + else: + return [] + + # subpatterns found so far among node’s children + index = 0 + for subtree in node.subtrees(): + if index >= len(pattern): + break + new_nodes = find_pattern_aux(subtree, pattern[index], varname) + if new_nodes: + nodes += new_nodes + index += 1 + if index < len(pattern): + nodes = [] + + else: + for child in node: + nodes = find_pattern_aux(child, pattern, varname) + if nodes: + break + return nodes + + for variable in get_variables(tree): + result = find_pattern_aux(tree, pattern, variable) + if result: + yield result + +def get_attributes(programs): + # extract attributes from training data + patterns = collections.defaultdict(list) + for code in programs: + for pat, nodes in get_patterns(code): + patterns[pat] += [code] + + attrs = collections.OrderedDict() + for pat, progs in sorted(patterns.items(), key=lambda x: len(x[1]), reverse=True): + if len(progs) < 5: + break + attrs['regex-{}'.format(len(attrs))] = {'desc': pat, 'programs': progs} + + return attrs diff --git a/regex/tree.py b/regex/tree.py new file mode 100644 index 0000000..1ff1b90 --- /dev/null +++ b/regex/tree.py @@ -0,0 +1,95 @@ +#!/usr/bin/python3 + +class Tree(tuple): + def __new__ (cls, label, children=None, data=None): + return super(Tree, cls).__new__(cls, tuple(children) if children else tuple()) + + def __init__(self, label, children=None, data=None): + self.label = label + self.data = data + + def __eq__(self, other): + return (self.__class__ is other.__class__ and + self.label == other.label and tuple.__eq__(self, other)) + def __lt__(self, other): + if isinstance(other, Tree) and self.__class__ is other.__class__: + return (self.label, tuple(self)) < (other.label, tuple(other)) + return self.__class__.__name__ < other.__class__.__name__ + __ne__ = lambda self, other: not self == other + __gt__ = lambda self, other: not (self < other or self == other) + __le__ = lambda self, other: self < other or self == other + __ge__ = lambda self, other: not self < other + + def __hash__(self): + return hash((self.label, tuple(self))) + + def __repr__(self): + return self.to_string() + + # adapted from https://en.wikipedia.org/wiki/S-expression#Parsing + @staticmethod + def from_string(string): + sexp = [[]] + word = '' + in_str = False + for char in string: + if char == '(' and not in_str: + sexp.append([]) + elif char == ')' and not in_str: + if word: + sexp[-1].append(word if not sexp[-1] else Tree(word)) + word = '' + temp = sexp.pop() + sexp[-1].append(Tree(temp[0], temp[1:])) + elif char in (' ', '\n', '\t') and not in_str: + if word: + sexp[-1].append(word if not sexp[-1] else Tree(word)) + word = '' + elif char == '\"': + in_str = not in_str + else: + word += char + return sexp[0][0] + + def to_string(self, inline=False, depth=0): + if not self: + return str(self.label) + string = '(' + self.label + if not inline and len(self) > 1 and len(list(self.subtrees())) > 8: + depth += 1 + prefix = '\n' + ' '*depth + else: + prefix = ' ' + for i, child in enumerate(self): + string += prefix + string += child.to_string(inline, depth) + string += ')' + return string + + def to_graphviz(self): + string = 'digraph {\n' + string += 'ordering=out\n' + string += 'ranksep=0.3\n' + + nodes = {} + edges = [] + for i, (node, parent) in enumerate(self.subtrees_with_parents()): + nodes[id(node)] = i + if parent: + edges += [(id(parent), id(node))] + string += 'n{} [label="{}"];\n'.format(i, node.label) + + for a, b in edges: + string += 'n{} -> n{};\n'.format(nodes[a], nodes[b]) + string += '}\n' + return string + + def subtrees(self): + yield self + for child in self: + yield from child.subtrees() + + def subtrees_with_parents(self, parent=None): + yield (self, parent) + for child in self: + yield from child.subtrees_with_parents(self) |