summaryrefslogtreecommitdiff
path: root/regex/__init__.py
blob: 3738251cc20216b4f6e69f6f5ebaee787378d13f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
import ast
import collections
from itertools import chain, combinations, product
import re
import sys

from .asttokens import asttokens

from .tree import Tree

def make_tree(code):
    """Parse `code` and convert the resulting AST into a Tree.

    Every AST node becomes a Tree labelled with the node's class name, with
    one child Tree per field whose value is not None (labelled with the
    field name).  Field values that are not AST nodes become leaf Trees
    labelled with `str(value)`.

    Trees built from a node and from its fields share one `data` dict holding
    the 'start' and 'end' positions of the node's first and last token and
    the 'ast' node itself; values without token information use the dict
    built from the enclosing node instead.
    """
    def span(owner):
        # source extent of `owner`, as annotated by asttokens
        return {
            'start': owner.first_token.startpos,
            'end': owner.last_token.endpos,
            'ast': owner,
        }

    def convert(node, parent=None):
        info = span(node if hasattr(node, 'first_token') else parent)

        # plain values (anything without `_fields`) become leaves
        if not hasattr(node, '_fields'):
            return Tree(str(node), data=info)

        fields = []
        for name, value in ast.iter_fields(node):
            if value is None:
                continue
            if isinstance(value, list):
                kids = [convert(item, node) for item in value]
            elif isinstance(value, ast.AST) and not value._fields:
                # field-less nodes are represented by their class name only
                kids = [Tree(value.__class__.__name__)]
            else:
                kids = [convert(value, node)]
            fields.append(Tree(name, kids, data=info))
        return Tree(node.__class__.__name__, fields, data=info)

    return convert(asttokens.ASTTokens(code, parse=True).tree)

# build the pattern induced on the tree by the selected nodes
def make_pattern(tree, selected):
    """Return the pattern that `selected` induces on `tree`, or None.

    None is returned when neither `tree` nor any of its descendants is one
    of the `selected` nodes (nodes are compared by identity).  Otherwise the
    result contains the selected nodes (in simplified form), the ancestors
    leading to them, and some extra children chosen by the ancestor's label.
    """
    kind = tree.label

    def callee(func):
        # for a method call keep only the attribute (method) name and
        # whatever follows it; for a plain call keep the whole func field
        if func[0].label == 'Attribute':
            return Tree('func', [Tree('Attribute', func[0][1:])])
        return func

    # a selected node ends the recursion
    if any(tree is node for node in selected):
        if kind == 'Name':
            # replace all variable names with 'VAR'
            return Tree(kind, [Tree('id', [Tree('VAR')]), tree[1]])
        if kind == 'arg':
            # replace all variable names with 'VAR' (in function arguments)
            return Tree(kind, [Tree('arg', [Tree('VAR')])])
        if kind == 'Str':
            # omit string value
            return Tree(kind)
        if kind == 'Call':
            # omit arguments, keep the original function name
            return Tree(kind, [callee(tree[0])])
        return tree

    found = [make_pattern(child, selected) for child in tree]
    matched = [sub for sub in found if sub is not None]
    if not matched:
        # no selected nodes anywhere below this node
        return None

    # from here on at least one child leads to a selected node; some
    # labels pull additional children into the pattern

    if kind == 'FunctionDef':
        # include function name
        return Tree(kind, [tree[0]] + matched)

    if kind == 'Attribute':
        # include attribute name
        children = [] if found[0] is None else [found[0]]
        children.extend(tree[1:])
        return Tree(kind, children)

    if kind == 'Call':
        # include function/method name
        head = found[0] if found[0] is not None else callee(tree[0])
        return Tree(kind, [head] + [sub for sub in found[1:] if sub is not None])

    if kind == 'args':
        # (args arg1 … argN)
        # keep a placeholder 'arg' node for every argument without a
        # subpattern, so that subpatterns found at different argument
        # positions stay distinguishable
        return Tree(kind, [Tree('arg') if sub is None else sub for sub in found])

    if kind == 'UnaryOp':
        # store the operator
        return Tree(kind, [tree[0], found[1]])

    if kind in ('BinOp', 'Compare'):
        # store the operator and left and/or right side
        children = [found[0]] if found[0] else []
        children.append(tree[1])
        if found[2]:
            children.append(found[2])
        return Tree(kind, children)

    # default: keep only the children that matched some part of the pattern
    return Tree(kind, matched)

def get_variables(tree):
    """Group the nodes of `tree` that name a variable by that name.

    Returns a defaultdict mapping each name to a list of nodes, in the
    order produced by `tree.subtrees_with_parents()`:

    - 'Name' nodes, except those whose parent is labelled 'func' (the name
      of a called function);
    - 'arg' nodes found in the 'args' field of the arguments of a
      'FunctionDef' or 'Lambda' node.
    """
    variables = collections.defaultdict(list)

    def add_arguments(arguments):
        # AST form: (arguments … (args arg0 arg1 … argN) …)
        # The 'args' field is looked up by label rather than by position:
        # it is not always the first field of `arguments` (since Python 3.8
        # it is preceded by 'posonlyargs').
        for field in arguments:
            if field.label != 'args':
                continue
            for arg in field:
                if arg:
                    variables[arg[0][0].label].append(arg)

    for node, parent in tree.subtrees_with_parents():
        if node.label == 'Name' and parent.label != 'func':
            variables[node[0][0].label].append(node)
        elif node.label == 'FunctionDef':
            # AST form: (FunctionDef name (args (arguments …)) …)
            add_arguments(node[1][0])
        elif node.label == 'Lambda':
            # AST form: (Lambda (args (arguments …)) …)
            add_arguments(node[0][0])
    return variables

def get_literals(tree):
    """Return the subtrees of `tree` that represent literals.

    A literal is any node labelled 'Num', 'Str' or 'NameConstant', or
    'Constant' -- the single node class the parser produces for all
    literals since Python 3.8, without which nothing would match there.

    NOTE(review): make_pattern omits the value only for nodes labelled
    'Str'; string values stored under 'Constant' are kept in patterns.
    """
    literal_labels = {'Num', 'Str', 'NameConstant', 'Constant'}
    return [n for n in tree.subtrees() if n.label in literal_labels]

def get_calls(tree):
    """Return every subtree of `tree` labelled 'Call', in traversal order."""
    calls = []
    for node in tree.subtrees():
        if node.label == 'Call':
            calls.append(node)
    return calls

# labels of the nodes within which literals are combined into patterns
# (see patterns_with_literals in get_patterns)
short_range = {
    # statements
    'Return', 'Delete', 'Assign', 'AugAssign', 'Raise',
    # operators, lambdas and calls
    'BoolOp', 'BinOp', 'UnaryOp', 'Lambda', 'Call',
    # container displays
    'Dict', 'Set', 'List',
}

def get_patterns(code):
    """Yield `(pattern, nodes)` pairs describing the program in `code`.

    `code` is converted to a tree with make_tree; each yielded `pattern` is
    a Tree and `nodes` is the sequence of tree nodes it was built from.
    Patterns are produced, in this order, for:

    1. variables: a variable's single occurrence, and every pair of
       occurrences of the same variable;
    2. function calls: one pattern per 'Call' node;
    3. literals: every non-empty subset of the literals, and every
       literal/variable pair, found below a node whose label is in
       `short_range`;
    4. control structures: the node's label nested inside the labels of
       the control-structure nodes that enclose it.

    This is a generator: nothing is parsed until the first item is requested.
    """
    tree = make_tree(code)
    variables = get_variables(tree)

    # identities of all variable nodes; nodes are matched by identity, and
    # `variables` keeps them alive so the ids stay valid
    variable_ids = {id(n) for nodes in variables.values() for n in nodes}

    control_labels = (
        'FunctionDef', 'If', 'Else', 'For', 'While', 'Break', 'Continue',
        'Pass', 'With', 'Raise', 'Try', 'Return',
    )

    def subsets(iterable):
        # all non-empty combinations of the given items
        s = list(iterable)
        return chain.from_iterable(combinations(s, r) for r in range(1, len(s)+1))

    def patterns_with_variables():
        for nodes in variables.values():
            # singletons
            if len(nodes) == 1:
                pat = make_pattern(tree, nodes)
                if pat:
                    yield pat, nodes
            # pairs of nodes with the same variable
            for selected in combinations(nodes, 2):
                pat = make_pattern(tree, selected)
                if pat:
                    yield pat, selected

    # yield patterns for variable-literal / literal-literal pairs
    # yield patterns for singleton literals
    # (only within certain statements and expressions)
    def patterns_with_literals(node):
        if not node:
            return
        if node.label in short_range:
            var_nodes = [n for n in node.subtrees() if id(n) in variable_ids]
            literals = get_literals(node)
            for selected in chain(subsets(literals), product(literals, var_nodes)):
                pat = make_pattern(tree, selected)
                if pat:
                    yield pat, selected
        else:
            for child in node:
                yield from patterns_with_literals(child)

    # yield patterns describing function calls
    def patterns_with_calls(root):
        for node in get_calls(root):
            pat = make_pattern(root, [node])
            if pat:
                yield pat, [node]

    # yield patterns describing control structures
    def patterns_with_control_structures(node, path=()):
        if node.label in control_labels:
            pat = Tree(node.label)
            for ancestor in reversed(path):
                pat = Tree(ancestor.label, [pat])
            yield pat, [node]
            # Build a new path for this node's descendants only.  Extending
            # a shared list in place would leak this node into the paths of
            # its following siblings, which are not nested inside it.
            path = path + (node,)
        for child in node:
            yield from patterns_with_control_structures(child, path)

    yield from patterns_with_variables()
    yield from patterns_with_calls(tree)
    yield from patterns_with_literals(tree)
    yield from patterns_with_control_structures(tree)

def get_attributes(programs):
    """Turn the frequent patterns of `programs` into named attributes.

    Every pattern yielded by get_patterns for a program records that
    program's source once per time it is yielded.  Patterns are then ranked
    by the number of recorded entries, most frequent first, and those with
    at least five entries are returned in an OrderedDict keyed 'regex-0',
    'regex-1', …; each value is a dict holding the pattern under 'desc' and
    the recorded sources under 'programs'.
    """
    # pattern -> sources of the programs it was found in
    occurrences = collections.defaultdict(list)
    for code in programs:
        for pat, _ in get_patterns(code):
            occurrences[pat].append(code)

    ranked = sorted(occurrences.items(), key=lambda item: len(item[1]), reverse=True)

    attrs = collections.OrderedDict()
    for index, (pat, progs) in enumerate(ranked):
        if len(progs) < 5:
            # everything that follows is at least as rare
            break
        attrs['regex-{}'.format(index)] = {'desc': pat, 'programs': progs}

    return attrs