summaryrefslogtreecommitdiff
path: root/regex/__init__.py
blob: 72323494d8dbc04459f0e7c333128dbdd31da37b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import ast
import collections
from itertools import chain, combinations, product
import re
import sys

from .asttokens import asttokens

from .tree import Tree

def make_tree(code):
    """Parse `code` and convert its AST into a Tree.

    The source is parsed through asttokens so that AST nodes carry
    first_token / last_token attributes.  In the resulting Tree:

    - an AST node becomes a node labelled with its class name, with one
      child per field whose value is not None;
    - a field becomes a node labelled with the field name, whose
      children are the converted value(s) of that field;
    - a field value that is an AST node without any fields of its own
      becomes a bare node labelled with its class name (no data);
    - any other value becomes a childless node labelled str(value).

    Nodes built with data get a dict with keys 'start', 'end' and 'ast',
    taken from the node itself when it has token attributes and from its
    parent otherwise.  Field nodes share the data of the node that owns
    the field.
    """
    def span(source):
        # character offsets of `source` in the code, plus the node itself
        return {
            'start': source.first_token.startpos,
            'end': source.last_token.endpos,
            'ast': source,
        }

    def make(node, parent=None):
        # values without token info (e.g. plain strings) borrow the
        # position of the enclosing AST node
        data = span(node if hasattr(node, 'first_token') else parent)

        # not an AST node: store its string form as a leaf
        if not hasattr(node, '_fields'):
            return Tree(str(node), data=data)

        children = []
        for field, value in ast.iter_fields(node):
            if value is None:
                continue
            if isinstance(value, list):
                kids = [make(item, node) for item in value]
            elif isinstance(value, ast.AST) and not value._fields:
                # field-less AST node: keep only its class name
                kids = [Tree(value.__class__.__name__)]
            else:
                kids = [make(value, node)]
            children.append(Tree(field, kids, data=data))
        return Tree(node.__class__.__name__, children, data=data)

    pytree = asttokens.ASTTokens(code, parse=True).tree
    return make(pytree)

def stringify(tree):
    """Return the repr of `tree` collapsed onto a single line.

    Each newline, together with any spaces directly following it, is
    replaced by a single space.
    """
    multiline = tree.__repr__()
    return re.sub('\n *', ' ', multiline)

def make_pattern(tree, nodes):
    """Induce a pattern from the specified nodes in the tree.

    `nodes` are compared with subtrees of `tree` by identity.  The result
    is a new Tree that keeps the selected nodes together with the nodes
    on the way down to them (plus some label-specific extras, see below),
    or None if nothing under `tree` is kept.

    A childless 'Return' node is always kept, as (Return $), whether or
    not it was selected; every other childless node yields None.
    """
    label = tree.label

    # childless node
    if not tree:
        return Tree(label, [Tree('$')]) if label == 'Return' else None

    if any(tree is chosen for chosen in nodes):
        # remove variable names
        if label == 'Name':
            return Tree(label, [Tree('id'), tree[1]])
        if label == 'arg':
            return Tree(label, [Tree('arg')])
        return tree

    subpats = [make_pattern(child, nodes) for child in tree]
    if not any(subpats):
        return None
    matched = [sub for sub in subpats if sub is not None]

    # extra nodes to include in the pattern
    if label in ('Call', 'FunctionDef'):
        # store function names (the first child)
        return Tree(label, [tree[0]] + matched)

    if label == 'args':
        # (args arg1 … argN)
        # store position of matching argument, by including (only) the
        # root node of all nonmatching arguments
        positional = []
        for i, sub in enumerate(subpats):
            positional.append(Tree(tree[i].label) if sub is None else sub)
        return Tree(label, positional)

    if label == 'UnaryOp':
        # store operator
        return Tree(label, [tree[0], subpats[1]])

    if label in ('BinOp', 'Compare'):
        # store left and/or right side
        # NOTE(review): the operator (second child) is stored only when
        # the left side matched -- confirm this is intended
        children = []
        if subpats[0]:
            children.extend((subpats[0], tree[1]))
        if subpats[2]:
            children.append(subpats[2])
        return Tree(label, children)

    # store children if they match some part of the pattern
    return Tree(label, matched)

def get_variables(tree):
    """Collect the nodes in `tree` that refer to each variable name.

    Returns a defaultdict(list) mapping a name to the list of nodes for
    it: 'Name' nodes (except those whose parent is labelled 'func') and
    the argument nodes of every 'FunctionDef'.  Childless nodes are
    skipped.
    """
    found = collections.defaultdict(list)
    for node, parent in tree.subtrees_with_parents():
        if not node:
            continue
        kind = node.label
        if kind == 'Name':
            # skip names that are the target of a call
            if parent.label != 'func':
                found[node[0][0].label].append(node)
        elif kind == 'FunctionDef':
            # (FunctionDef name (args (arguments (args arg0 arg1 … argN))))
            # NOTE(review): relies on `args` being the first child of
            # `arguments`; Python 3.8+ lists `posonlyargs` first -- confirm
            # which interpreter versions are supported
            params = node[1][0][0]
            for arg in params:
                if not arg:
                    continue
                found[arg[0][0].label].append(arg)
    return found

# labels of nodes treated as literals when generating patterns
literals = {
    'Num',
    'Str',
    'NameConstant',
}

# labels of statements and expressions within which literals are paired
# with each other and with variables (see get_patterns)
short_range = {
    'Return',
    'Delete',
    'Assign',
    'AugAssign',
    'Raise',
    'BoolOp',
    'BinOp',
    'UnaryOp',
    'Lambda',
    'Call',
    'Dict',
    'Set',
    'List',
}

def get_patterns(code):
    """Generate candidate patterns for the program `code`.

    Yields (pattern, selected) pairs, where `selected` holds the tree
    nodes the pattern was induced from (see make_pattern).  Candidates
    are produced in this order:

    1. every pair of nodes referring to the same variable;
    2. every 'Return' node on its own;
    3. inside each outermost node whose label is in `short_range`:
       every pair of literals, then every (literal, variable) pair.

    Selections for which make_pattern returns nothing are skipped.
    """
    tree = make_tree(code)
    variables = get_variables(tree)

    # identities of all variable nodes, computed once so that the
    # membership tests below need not rescan every variable list
    variable_ids = {id(n) for nodes in variables.values() for n in nodes}

    for nodes in variables.values():
        for selected in combinations(nodes, 2):
            pat = make_pattern(tree, selected)
            if pat:
                yield pat, selected

    # yield patterns for variable-literal / literal-literal pairs
    # (only within certain statements and expressions)
    def patterns_with_literals(node):
        if not node:
            return
        if node.label not in short_range:
            # keep descending until a short-range node is reached
            for child in node:
                yield from patterns_with_literals(child)
            return

        var_nodes = [n for n in node.subtrees() if id(n) in variable_ids]
        lits = [n for n in node.subtrees() if n.label in literals]
        for selected in chain(
                combinations(lits, 2),
                product(lits, var_nodes)):
            pat = make_pattern(tree, selected)
            if pat:
                yield pat, selected

    # extra nodes for Python
    for n in tree.subtrees():
        # empty return
        if n.label == 'Return':
            pat = make_pattern(tree, [n])
            if pat:
                yield pat, [n]

    yield from patterns_with_literals(tree)

def find_pattern(tree, pattern):
    """Search `tree` for occurrences of `pattern`.

    The search is run once for every variable name returned by
    get_variables(tree); each run that finds a match yields the list of
    matched nodes.
    """
    def find_pattern_aux(node, pattern, varname=None):
        # Returns the list of nodes matched by `pattern` at or below
        # `node`, or an empty list if there is no match.

        # labels differ: return the first match found under a child
        if pattern.label != node.label:
            for child in node:
                found = find_pattern_aux(child, pattern, varname)
                if found:
                    return found
            return []

        # a childless 'id' / 'arg' pattern stands for a variable: descend
        # to the leaf below `node` and compare its label to `varname`
        # (any name is accepted when `varname` is None)
        if node.label in ('id', 'arg') and not pattern:
            leaf = node
            while leaf:
                leaf = leaf[0]
            if varname is None or leaf.label == varname:
                return [leaf]
            return []

        matched = [node]
        # number of subpatterns found so far among node's subtrees; they
        # have to be found in order
        index = 0
        for subtree in node.subtrees():
            if index >= len(pattern):
                break
            new_nodes = find_pattern_aux(subtree, pattern[index], varname)
            if new_nodes:
                matched += new_nodes
                index += 1

        # a match requires all subpatterns to be found
        return matched if index >= len(pattern) else []

    for variable in get_variables(tree):
        hits = find_pattern_aux(tree, pattern, variable)
        if hits:
            yield hits

def get_attributes(programs):
    """Extract pattern attributes from the training `programs`.

    Every pattern yielded by get_patterns is mapped to the list of
    programs it came from; a program is added once for each time the
    pattern is yielded for it.  Patterns are then ranked by the length of
    that list, longest first, and those with at least 5 entries are
    returned as an OrderedDict

        'regex-N' -> {'desc': pattern, 'programs': [program, ...]}

    where N is the rank of the pattern, starting at 0.
    """
    # extract attributes from training data
    occurrences = collections.defaultdict(list)
    for code in programs:
        for pat, _ in get_patterns(code):
            occurrences[pat].append(code)

    ranked = sorted(occurrences.items(),
                    key=lambda item: len(item[1]),
                    reverse=True)

    attrs = collections.OrderedDict()
    for pat, progs in ranked:
        # sorted by frequency, so all remaining patterns are too rare
        if len(progs) < 5:
            break
        name = 'regex-{}'.format(len(attrs))
        attrs[name] = {'desc': pat, 'programs': progs}

    return attrs