summaryrefslogtreecommitdiff
path: root/regex/__init__.py
blob: 14de8a9400d75d86fce9f9e080657673d895dc82 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import ast
import collections
from itertools import chain, combinations, product
import re
import sys

from .asttokens import asttokens

from .tree import Tree

def make_tree(code):
    """Parse *code* and convert the resulting AST into a ``Tree``.

    The source is parsed with ``asttokens.ASTTokens(code, parse=True)``.
    Every AST node becomes a ``Tree`` labelled with the node's class name.
    Each field whose value is not ``None`` becomes a child labelled with the
    field name, and that child's own children are the converted field
    value(s).  Values that are not AST nodes become leaves labelled with
    ``str(value)``.

    Trees built here carry a ``data`` dict with the keys ``start`` and
    ``end`` (``startpos`` of the first token / ``endpos`` of the last token)
    and ``ast`` (the AST node the positions were taken from).  The only
    exception is the bare class-name leaf created for field-less AST nodes,
    which gets no ``data``.
    """
    def build(node, parent=None):
        # Values without token information borrow the position (and the AST
        # reference) of the node that contains them.
        anchor = node if hasattr(node, 'first_token') else parent
        data = {
            'start': anchor.first_token.startpos,
            'end': anchor.last_token.endpos,
            'ast': anchor,
        }

        # Anything that is not an AST node (no ``_fields``) becomes a leaf.
        if not hasattr(node, '_fields'):
            return Tree(str(node), data=data)

        children = []
        for field, value in ast.iter_fields(node):
            if value is None:
                continue
            if isinstance(value, list):
                subtrees = [build(item, node) for item in value]
            elif isinstance(value, ast.AST) and not value._fields:
                # AST nodes without fields are reduced to their class name.
                subtrees = [Tree(value.__class__.__name__)]
            else:
                subtrees = [build(value, node)]
            # The field wrapper shares the data of the node owning the field.
            children.append(Tree(field, subtrees, data=data))
        return Tree(node.__class__.__name__, children, data=data)

    root = asttokens.ASTTokens(code, parse=True).tree
    return build(root)

def stringify(tree):
    """Return the repr of *tree* on one line.

    Every newline, together with the run of spaces that follows it, is
    replaced by a single space.
    """
    multiline = tree.__repr__()
    return re.sub('\n *', ' ', multiline)

# induce a pattern from the specified nodes in the tree
def make_pattern(tree, nodes):
    """Build the pattern that the selected *nodes* induce in *tree*.

    *nodes* is an iterable of subtrees of *tree*; membership is decided by
    identity (``is``).  The result keeps the path from *tree* down to every
    selected node, plus a few extra context nodes for certain labels (see
    the comments below).

    Returns a ``Tree``, or ``None`` when nothing below *tree* contributes to
    the pattern.
    """
    label = tree.label

    # A falsy tree (presumably one without children) contributes nothing,
    # except for a Return, which is always kept as (Return $).
    if not tree:
        return Tree(label, [Tree('$')]) if label == 'Return' else None

    if any(selected is tree for selected in nodes):
        # remove variable names
        if label == 'Name':
            return Tree(label, [Tree('id'), tree[1]])
        if label == 'arg':
            return Tree(label, [Tree('arg')])
        return tree

    subpats = [make_pattern(child, nodes) for child in tree]
    if not any(subpats):
        return None

    # children that produced some pattern, in their original order
    matched = [sub for sub in subpats if sub is not None]

    # extra nodes to include in the pattern
    if label in ('Call', 'FunctionDef'):
        # store function names
        kept = [tree[0]] + matched
    elif label == 'args':
        # (args arg1 … argN)
        # store position of matching argument, by including (only) the
        # root node of all nonmatching arguments
        kept = [
            sub if sub is not None else Tree(tree[index].label)
            for index, sub in enumerate(subpats)
        ]
    elif label == 'UnaryOp':
        # store operator
        kept = [tree[0], subpats[1]]
    elif label in ('BinOp', 'Compare'):
        # store operator and left and/or right side
        kept = []
        if subpats[0]:
            kept.append(subpats[0])
        kept.append(tree[1])
        if subpats[2]:
            kept.append(subpats[2])
    else:
        # store children if they match some part of the pattern
        kept = matched
    return Tree(label, kept)

def get_variables(tree):
    """Collect, per variable name, the nodes in *tree* that refer to it.

    Two kinds of nodes are recorded:

    * every non-empty ``Name`` node whose parent is not labelled ``func``
      (so names in the function position of a call are skipped);
    * every non-empty ``arg`` node in the ``args`` field of a
      ``FunctionDef``'s ``arguments``.

    Returns:
        A ``collections.defaultdict(list)`` mapping each variable name to
        the list of nodes found for it, in traversal order.
    """
    variables = collections.defaultdict(list)
    for node, parent in tree.subtrees_with_parents():
        if not node:
            continue
        if node.label == 'Name' and (parent is None or parent.label != 'func'):
            # (Name (id <name>) (ctx …))
            variables[node[0][0].label].append(node)
        elif node.label == 'FunctionDef':
            # (FunctionDef name (args (arguments … (args arg0 arg1 … argN) …)))
            # Select the field by label rather than by position: since
            # Python 3.8 ``posonlyargs`` precedes ``args`` in ast.arguments,
            # so index 0 is not reliably the parameter list.
            for field in node[1][0]:
                if field.label != 'args':
                    continue
                for arg in field:
                    if arg:
                        # (arg (arg <name>) …)
                        variables[arg[0][0].label].append(arg)
    return variables

# Node labels treated as literals when get_patterns pairs up nodes.
# NOTE(review): Python 3.8+ parses literals as ``Constant`` nodes rather than
# Num/Str/NameConstant — confirm which interpreter version this targets.
literals = {
    'Num',
    'Str',
    'NameConstant',
}

# Labels of the statements and expressions within which get_patterns looks
# for literal pairs.
short_range = {
    'Return', 'Delete', 'Assign', 'AugAssign', 'Raise',
    'BoolOp', 'BinOp', 'UnaryOp', 'Lambda',
    'Call', 'Dict', 'Set', 'List',
}

def get_patterns(code):
    """Yield ``(pattern, selected)`` pairs induced from the program *code*.

    *code* is converted with ``make_tree`` and its variables collected with
    ``get_variables``.  Patterns are then produced, in this order, for:

    1. every pair of nodes that refer to the same variable;
    2. every ``Return`` node on its own;
    3. within each outermost node whose label is in ``short_range``: every
       pair of literal nodes, and every (literal, variable node) pair.

    ``selected`` holds the nodes the pattern was built from: a tuple for
    cases 1 and 3, a one-element list for case 2.  Selections for which
    ``make_pattern`` returns a falsy result are skipped.
    """
    tree = make_tree(code)
    variables = get_variables(tree)

    # pairs of nodes referring to the same variable
    for nodes in variables.values():
        for selected in combinations(nodes, 2):
            pat = make_pattern(tree, selected)
            if pat:
                yield pat, selected

    # extra nodes for Python: each Return statement by itself
    for n in tree.subtrees():
        if n.label == 'Return':
            pat = make_pattern(tree, [n])
            if pat:
                yield pat, [n]

    # Identity set of all variable nodes, built once so that membership is a
    # constant-time lookup instead of a scan over every variable node.  The
    # nodes stay alive in `variables`, so their ids remain unique.
    variable_ids = {
        id(var_node)
        for var_nodes in variables.values()
        for var_node in var_nodes
    }

    # yield patterns for literal-literal / literal-variable pairs
    # (only within certain statements and expressions)
    def patterns_with_literals(node):
        if not node:
            return
        if node.label in short_range:
            var_nodes = [n for n in node.subtrees() if id(n) in variable_ids]
            lits = [n for n in node.subtrees() if n.label in literals]
            for selected in chain(
                    combinations(lits, 2),
                    product(lits, var_nodes)):
                pat = make_pattern(tree, selected)
                if pat:
                    yield pat, selected
        else:
            # not a short-range node itself: look further down
            for child in node:
                yield from patterns_with_literals(child)

    yield from patterns_with_literals(tree)

def get_attributes(programs):
    """Turn the patterns that occur often in *programs* into attributes.

    Every pattern yielded by ``get_patterns`` for a program records that
    program once per occurrence.  Patterns are ranked by the length of that
    list, largest first, and those with fewer than 5 entries are dropped.

    Returns:
        A ``collections.OrderedDict`` mapping ``'regex-<rank>'`` to a dict
        with the pattern under ``'desc'`` and the recorded programs under
        ``'programs'``.
    """
    # extract attributes from training data
    occurrences = collections.defaultdict(list)
    for code in programs:
        for pat, _nodes in get_patterns(code):
            occurrences[pat].append(code)

    ranked = sorted(
        occurrences.items(),
        key=lambda item: len(item[1]),
        reverse=True,
    )

    attrs = collections.OrderedDict()
    for rank, (pat, progs) in enumerate(ranked):
        # ranked by descending count, so all remaining patterns are too rare
        if len(progs) < 5:
            break
        attrs['regex-{}'.format(rank)] = {'desc': pat, 'programs': progs}

    return attrs