From 1ee2b0252602b6176ee02db5c4a77fb2daac0a0f Mon Sep 17 00:00:00 2001
From: Timotej Lazar <timotej.lazar@fri.uni-lj.si>
Date: Fri, 29 Dec 2017 18:23:42 +0100
Subject: Import AST pattern stuff

---
 regex/__init__.py | 208 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 regex/tree.py     |  95 +++++++++++++++++++++++++
 2 files changed, 303 insertions(+)
 create mode 100644 regex/__init__.py
 create mode 100644 regex/tree.py

(limited to 'regex')

diff --git a/regex/__init__.py b/regex/__init__.py
new file mode 100644
index 0000000..7232349
--- /dev/null
+++ b/regex/__init__.py
@@ -0,0 +1,208 @@
+import ast
+import collections
+from itertools import chain, combinations, product
+import re
+import sys
+
+from .asttokens import asttokens
+
+from .tree import Tree
+
+def make_tree(code):  # parse Python source `code` and convert its AST into a Tree
+    def make(node, parent=None):  # convert one AST node (or a plain field value) into a Tree
+        if hasattr(node, 'first_token'):
+            data = {'start': node.first_token.startpos, 'end': node.last_token.endpos, 'ast': node}
+        else:  # value carries no token info: fall back to the parent's source span
+            data = {'start': parent.first_token.startpos, 'end': parent.last_token.endpos, 'ast': parent}
+        # values without _fields (i.e. not AST nodes) become leaves labelled with their str()
+        if not hasattr(node, '_fields'):
+            return Tree(str(node), data=data)
+        # one child per non-None field, labelled with the field name
+        children = []
+        for field, value in ast.iter_fields(node):
+            if value is None:
+                continue
+            if isinstance(value, ast.AST):
+                if not value._fields:  # field-less AST node: a leaf named after its class
+                    child = Tree(field, [Tree(value.__class__.__name__)], data=data)
+                else:
+                    child = Tree(field, [make(value, node)], data=data)
+            elif isinstance(value, list):
+                child = Tree(field, [make(v, node) for v in value], data=data)
+            else:
+                child = Tree(field, [make(value, node)], data=data)
+            children.append(child)
+        return Tree(node.__class__.__name__, children, data=data)
+    # NOTE(review): make() gets parent=None here, so the root node must carry first_token
+    pytree = asttokens.ASTTokens(code, parse=True).tree
+    return make(pytree)
+
+def stringify(tree):  # one-line rendering: every newline plus its indentation becomes one space
+    return re.sub('\n *', ' ', repr(tree))
+
+# induce a pattern (a pruned copy of `tree`) keeping only the paths that lead to the given `nodes`
+def make_pattern(tree, nodes):
+    label = tree.label
+    # childless nodes yield no pattern, except a bare Return, which is encoded as (Return $)
+    if not tree:
+        if label == 'Return':
+            return Tree(label, [Tree('$')])
+        return None
+    # identity test: `nodes` holds the selected tree objects themselves
+    if any(n is tree for n in nodes):
+        # remove variable names
+        if label == 'Name':
+            return Tree(label, [Tree('id'), tree[1]])
+        elif label == 'arg':
+            return Tree(label, [Tree('arg')])
+        else:
+            return tree
+    # recurse into the children; a None subpattern means nothing selected lies below that child
+    subpats = [make_pattern(child, nodes) for child in tree]
+    if not any(subpats):
+        return None
+
+    # extra nodes to include in the pattern
+    if label in ('Call', 'FunctionDef'):
+        # store function names
+        pat = Tree(label, [tree[0]] + [s for s in subpats if s is not None])
+    elif label == 'args':
+        # (args arg1 … argN)
+        # store position of matching argument, by including (only) the
+        # root node of all nonmatching arguments
+        pat = Tree(label, [s if s is not None else Tree(tree[i].label) for i, s in enumerate(subpats)])
+    elif label == 'UnaryOp':
+        # store operator
+        pat = Tree(label, [tree[0], subpats[1]])
+    elif label == 'BinOp' or label == 'Compare':
+        # store operator and left and/or right side
+        children = []
+        if subpats[0]:
+            children += [subpats[0]]
+            children += [tree[1]]  # NOTE(review): operator is kept only when the left side matched — intended?
+        if subpats[2]:
+            children += [subpats[2]]
+        pat = Tree(label, children)
+    else:
+        # store children if they match some part of the pattern
+        pat = Tree(label, [s for s in subpats if s is not None])
+    return pat
+
+def get_variables(tree):  # map each variable name to the list of nodes that mention it
+    variables = collections.defaultdict(list)
+    for node, parent in tree.subtrees_with_parents():
+        if not node:
+            continue
+        # a Name sitting directly under a `func` field is skipped; the root node has no parent
+        if node.label == 'Name' and (parent is None or parent.label != 'func'):
+            variables[node[0][0].label].append(node)  # (Name (id x) (ctx …)) -> 'x'
+        elif node.label == 'FunctionDef':
+            # (FunctionDef name (args (arguments … (args arg0 arg1 … argN) …))): pick the `args`
+            # list by label, because Python 3.8+ puts `posonlyargs` first among the fields
+            for arg in chain.from_iterable(f for f in node[1][0] if f.label == 'args'):
+                if arg:
+                    variables[arg[0][0].label].append(arg)  # (arg (arg x) …) -> 'x'
+    return variables
+
+literals = {'Num', 'Str', 'NameConstant'}  # node labels that get_patterns treats as literals
+short_range = {  # node labels inside which get_patterns pairs up literals and variables
+    'Return', 'Delete', 'Assign', 'AugAssign', 'Raise', 'BoolOp',
+    'BinOp', 'UnaryOp', 'Lambda', 'Call', 'Dict', 'Set', 'List'
+}  # NOTE(review): Python 3.8+ parses literals as `Constant` nodes, absent from `literals` — confirm target version
+
+def get_patterns(code):  # generate (pattern, selected nodes) pairs induced from the source `code`
+    tree = make_tree(code)
+    variables = get_variables(tree)
+    # identities of all variable nodes (Tree equality is structural, so membership goes by id)
+    var_ids = {id(vn) for var_nodes in variables.values() for vn in var_nodes}
+    for nodes in variables.values():
+        for selected in combinations(nodes, 2):  # every pair of mentions of the same variable
+            pat = make_pattern(tree, selected)
+            if pat:
+                yield pat, selected
+
+    # yield patterns for variable-literal / literal-literal pairs
+    # (singleton literals are not emitted by this helper)
+    # (only within certain statements and expressions)
+    def patterns_with_literals(node):
+        if not node:
+            return
+        if node.label in short_range:
+            var_nodes = [n for n in node.subtrees() if id(n) in var_ids]
+            lits = [n for n in node.subtrees() if n.label in literals]
+            for selected in chain(
+                    combinations(lits, 2),
+                    product(lits, var_nodes)):
+                pat = make_pattern(tree, selected)
+                if pat:
+                    yield pat, selected
+        else:
+            for child in node:
+                yield from patterns_with_literals(child)
+
+    # extra nodes for Python
+    for n in tree.subtrees():
+        # each Return node selected on its own (make_pattern encodes an empty one as (Return $))
+        if n.label == 'Return':
+            pat = make_pattern(tree, [n])
+            if pat:
+                yield pat, [n]
+
+    yield from patterns_with_literals(tree)
+
+def find_pattern(tree, pattern):  # yield, once per variable name in `tree`, the nodes matching `pattern`
+    def find_pattern_aux(node, pattern, varname=None):  # list of matched nodes under `node`, [] if none
+        nodes = []
+        if pattern.label == node.label:
+            nodes = [node]
+
+            if node.label in ('id', 'arg'):
+                if not pattern:  # childless pattern node: matches a variable name, checked against varname
+                    while node:  # follow first children down to the leaf
+                        node = node[0]
+                    if varname is None:
+                        varname = node.label  # NOTE(review): rebinds the local only; the caller never sees it
+                        return [node]
+                    elif node.label == varname:
+                        return [node]
+                    else:
+                        return []
+
+            # subpatterns found so far among node’s children
+            index = 0
+            for subtree in node.subtrees():  # subtrees() yields `node` itself first, then its descendants
+                if index >= len(pattern):
+                    break
+                new_nodes = find_pattern_aux(subtree, pattern[index], varname)
+                if new_nodes:
+                    nodes += new_nodes
+                    index += 1
+            if index < len(pattern):  # some subpattern stayed unmatched: reject this node
+                nodes = []
+
+        else:  # labels differ: take the first match found below any child
+            for child in node:
+                nodes = find_pattern_aux(child, pattern, varname)
+                if nodes:
+                    break
+        return nodes
+
+    for variable in get_variables(tree):  # try the pattern once for every variable name
+        result = find_pattern_aux(tree, pattern, variable)
+        if result:
+            yield result
+
+def get_attributes(programs, min_count=5):
+    # extract attributes from training data: patterns collected at least `min_count` times
+    patterns = collections.defaultdict(list)
+    for code in programs:
+        for pat, _ in get_patterns(code):
+            patterns[pat].append(code)  # NOTE(review): once per occurrence, so a program may repeat
+
+    attrs = collections.OrderedDict()  # 'regex-N' -> {'desc': pattern, 'programs': [...]}, most frequent first
+    for pat, progs in sorted(patterns.items(), key=lambda x: len(x[1]), reverse=True):
+        if len(progs) < min_count:  # sorted by descending count, so no later pattern can qualify
+            break
+        attrs['regex-{}'.format(len(attrs))] = {'desc': pat, 'programs': progs}
+
+    return attrs
diff --git a/regex/tree.py b/regex/tree.py
new file mode 100644
index 0000000..1ff1b90
--- /dev/null
+++ b/regex/tree.py
@@ -0,0 +1,95 @@
+#!/usr/bin/python3
+
+class Tree(tuple):  # tuple of child Trees plus a `label` and optional `data` (data is ignored by ==, < and hash)
+    def __new__(cls, label, children=None, data=None):
+        return super(Tree, cls).__new__(cls, tuple(children) if children else tuple())
+
+    def __init__(self, label, children=None, data=None):
+        self.label = label
+        self.data = data
+
+    def __eq__(self, other):  # same class, same label and pairwise-equal children
+        return (self.__class__ is other.__class__ and
+                self.label == other.label and tuple.__eq__(self, other))
+    def __lt__(self, other):  # order by (label, children); objects of another class order by class name
+        if isinstance(other, Tree) and self.__class__ is other.__class__:
+            return (self.label, tuple(self)) < (other.label, tuple(other))
+        return self.__class__.__name__ < other.__class__.__name__
+    __ne__ = lambda self, other: not self == other
+    __gt__ = lambda self, other: not (self < other or self == other)
+    __le__ = lambda self, other: self < other or self == other
+    __ge__ = lambda self, other: not self < other
+
+    def __hash__(self):  # consistent with __eq__: label and children only
+        return hash((self.label, tuple(self)))
+
+    def __repr__(self):
+        return self.to_string()
+
+    # adapted from https://en.wikipedia.org/wiki/S-expression#Parsing
+    @staticmethod
+    def from_string(string):  # parse S-expression text such as '(a b (c d))' into a Tree
+        sexp = [[]]
+        word = ''
+        in_str = False
+        for char in string:
+            if char == '(' and not in_str:
+                sexp.append([])
+            elif char == ')' and not in_str:
+                if word:
+                    sexp[-1].append(word if not sexp[-1] else Tree(word))
+                    word = ''
+                temp = sexp.pop()
+                sexp[-1].append(Tree(temp[0], temp[1:]))
+            elif char in (' ', '\n', '\t') and not in_str:
+                if word:
+                    sexp[-1].append(word if not sexp[-1] else Tree(word))
+                    word = ''
+            elif char == '\"':
+                in_str = not in_str
+            else:
+                word += char
+        if word and len(sexp) == 1:  # text ended in a top-level bare word, e.g. a lone atom
+            sexp[0].append(word)
+        root = sexp[0][0]
+        return root if isinstance(root, Tree) else Tree(root)  # a lone atom parses to a leaf
+
+    def to_string(self, inline=False, depth=0):  # S-expression text; a leaf prints as its bare label
+        if not self:
+            return str(self.label)
+        string = '(' + str(self.label)  # str() here as well, so non-string labels work at every level
+        if not inline and len(self) > 1 and len(list(self.subtrees())) > 8:
+            depth += 1  # nodes with more than 8 subtrees put each child on its own indented line
+            prefix = '\n' + '  '*depth
+        else:
+            prefix = ' '
+        for child in self:
+            string += prefix + child.to_string(inline, depth)
+        string += ')'
+        return string
+
+    def to_graphviz(self):  # DOT digraph text: one numbered node per subtree, then parent -> child edges
+        string = 'digraph {\nordering=out\nranksep=0.3\n'
+        nodes = {}
+        edges = []
+        for i, (node, parent) in enumerate(self.subtrees_with_parents()):
+            nodes[id(node)] = i
+            if parent is not None:
+                edges += [(id(parent), id(node))]
+            label = str(node.label).replace('\\', '\\\\').replace('"', '\\"')  # keep it one valid quoted DOT string
+            string += 'n{} [label="{}"];\n'.format(i, label)
+
+        for a, b in edges:
+            string += 'n{} -> n{};\n'.format(nodes[a], nodes[b])
+        string += '}\n'
+        return string
+
+    def subtrees(self):  # pre-order walk: this node first, then every descendant
+        yield self
+        for child in self:
+            yield from child.subtrees()
+
+    def subtrees_with_parents(self, parent=None):  # same walk, yielding (node, parent); the root's parent is None
+        yield (self, parent)
+        for child in self:
+            yield from child.subtrees_with_parents(self)
-- 
cgit v1.2.1