# cgit navigation: summary | refs | log | tree | commit | diff
# path: root/regex/__init__.py
# blob: ccc3c05c706cc7787bae0e51cb88ef8847a0cb72 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import ast
import collections
from itertools import chain, combinations, product
import re
import sys

from .asttokens import asttokens

from .tree import Tree

def make_tree(code):
    """Parse *code* and return its AST converted to a Tree.

    Every AST node becomes a Tree labelled with the node's class name; its
    children are one Tree per non-None field (labelled with the field name),
    which in turn hold the converted field value(s).  Values that are not AST
    nodes become leaf Trees labelled with ``str(value)``.

    Trees built here carry a ``data`` dict with keys 'start', 'end' and 'ast',
    taken from the ``first_token.startpos`` / ``last_token.endpos`` of the
    node, or of its parent when the node itself has no ``first_token``.
    """
    def span_of(owner):
        # position info and originating AST node for one tree node
        return {'start': owner.first_token.startpos, 'end': owner.last_token.endpos, 'ast': owner}

    def make(node, parent=None):
        # values without token info (e.g. plain strings) borrow the parent's span
        data = span_of(node if hasattr(node, 'first_token') else parent)

        # non-AST value: a leaf labelled with its string form
        if not hasattr(node, '_fields'):
            return Tree(str(node), data=data)

        children = []
        for field, value in ast.iter_fields(node):
            if value is None:
                continue
            if isinstance(value, list):
                subtrees = [make(item, node) for item in value]
            elif isinstance(value, ast.AST) and not value._fields:
                # field-less AST node: keep only its class name, no data
                subtrees = [Tree(value.__class__.__name__)]
            else:
                subtrees = [make(value, node)]
            # field trees share the data dict of the node that owns them
            children.append(Tree(field, subtrees, data=data))
        return Tree(node.__class__.__name__, children, data=data)

    return make(asttokens.ASTTokens(code, parse=True).tree)

def stringify(tree):
    """Return the repr of *tree* with each newline and its following spaces replaced by one space."""
    rendered = repr(tree)
    return re.sub('\n *', ' ', rendered)

# induce a pattern from the specified nodes in the tree
def make_pattern(tree, nodes):
    """Induce a pattern from *tree* that covers the selected *nodes*.

    Recursively copies the parts of *tree* that lead to any of *nodes*
    (compared by identity), dropping everything else.

    Args:
        tree: the (sub)tree to build a pattern for.
        nodes: iterable of tree nodes that the pattern must contain.

    Returns:
        A Tree describing the pattern, or None if nothing below *tree*
        belongs to the pattern.  A childless 'Return' node always
        contributes ``(Return $)``, whether or not it is in *nodes*.
    """
    label = tree.label

    # leaf: only a bare Return is kept
    if not tree:
        if label == 'Return':
            return Tree(label, [Tree('$')])
        return None

    # found a node, do not recurse (much) further
    if any(n is tree for n in nodes):
        # omit variable names
        if label == 'Name':
            return Tree(label, [Tree('id'), tree[1]])
        elif label == 'arg':
            return Tree(label, [Tree('arg')])
        # omit string value
        elif label == 'Str':
            return Tree(label)
        # omit everything but function name
        elif label == 'Call':
            return Tree(label, [Tree('func', [Tree('Name', [Tree('id', [Tree(tree[0][0][0][0].label)])])])])
        else:
            return tree

    subpats = [make_pattern(child, nodes) for child in tree]
    # Compare against None instead of relying on truthiness: this file treats
    # a childless Tree as falsy (see the leaf check above), so the childless
    # (Str) pattern produced for a selected string literal would otherwise be
    # mistaken for "no match" and dropped.
    if all(s is None for s in subpats):
        return None

    # extra nodes to include in the pattern
    if label in ('Call', 'FunctionDef'):
        # store function names
        pat = Tree(label, [tree[0]] + [s for s in subpats if s is not None])
    elif label == 'args':
        # (args arg1 … argN)
        # keep only the matching arguments; nonmatching ones are dropped, so
        # the position of a match within the argument list is not recorded
        pat = Tree(label, [s for s in subpats if s is not None])
    elif label == 'UnaryOp':
        # store operator
        pat = Tree(label, [tree[0], subpats[1]])
    elif label == 'BinOp' or label == 'Compare':
        # store operator and left and/or right side
        children = []
        if subpats[0] is not None:
            children.append(subpats[0])
        children.append(tree[1])
        if subpats[2] is not None:
            children.append(subpats[2])
        pat = Tree(label, children)
    else:
        # store children if they match some part of the pattern
        pat = Tree(label, [s for s in subpats if s is not None])
    return pat

def get_variables(tree):
    """Collect the nodes that refer to each variable name in *tree*.

    Args:
        tree: root of a tree as produced by make_tree; it must provide
            ``subtrees_with_parents()`` yielding (node, parent) pairs.

    Returns:
        A ``defaultdict(list)`` mapping each variable name to the list of
        nodes for it: every 'Name' node whose parent is not a 'func' field
        (so called function names are excluded), plus every non-empty 'arg'
        node in the 'args' field of a 'FunctionDef'.
    """
    variables = collections.defaultdict(list)
    for node, parent in tree.subtrees_with_parents():
        # childless nodes carry no name to extract
        if not node:
            continue
        if node.label == 'Name' and parent.label != 'func':
            # (Name (id <name>) (ctx ...))
            variables[node[0][0].label].append(node)
        elif node.label == 'FunctionDef':
            # (FunctionDef name (args (arguments (args arg0 arg1 … argN))))
            # Find the 'args' field by label rather than by position: since
            # Python 3.8 the first field of ast.arguments is 'posonlyargs'.
            arguments = node[1][0]
            for field in arguments:
                if field.label != 'args':
                    continue
                for arg in field:
                    if arg:
                        # (arg (arg <name>) ...)
                        variables[arg[0][0].label].append(arg)
                break
    return variables

# node labels treated as literals when inducing literal patterns
literals = {
    'Num',
    'Str',
    'NameConstant',
}

# node labels of the statements and expressions inside which literal
# patterns are collected (see patterns_with_literals in get_patterns)
short_range = {
    'Return',
    'Delete',
    'Assign',
    'AugAssign',
    'Raise',
    'BoolOp',
    'BinOp',
    'UnaryOp',
    'Lambda',
    'Call',
    'Dict',
    'Set',
    'List',
}

def get_patterns(code):
    """Yield (pattern, nodes) pairs induced from the program text *code*.

    The program is converted with make_tree and patterns are produced, in
    this order, for: variables, function calls, literals and control
    structures.  ``pattern`` is a Tree; ``nodes`` is the list or tuple of
    tree nodes the pattern was induced from.  The same pattern may be
    yielded more than once.
    """
    tree = make_tree(code)
    variables = get_variables(tree)
    # identities of all variable nodes, for constant-time "is this a
    # variable node?" checks in patterns_with_literals
    variable_ids = {id(vn) for var_nodes in variables.values() for vn in var_nodes}

    # yield patterns for single-occurrence variables and for every pair of
    # occurrences of the same variable
    def patterns_with_variables():
        for nodes in variables.values():
            if len(nodes) == 1:
                pat = make_pattern(tree, nodes)
                if pat:
                    yield pat, nodes
            for selected in combinations(nodes, 2):
                pat = make_pattern(tree, selected)
                if pat:
                    yield pat, selected

    # yield patterns for variable-literal / literal-literal pairs
    # yield patterns for singleton literals
    # (only within certain statements and expressions)
    def patterns_with_literals(node):
        if not node:
            return
        if node.label in short_range:
            var_nodes = [n for n in node.subtrees() if id(n) in variable_ids]
            lits = [n for n in node.subtrees() if n.label in literals]
            for selected in chain(
                    combinations(lits, 1),
                    combinations(lits, 2),
                    product(lits, var_nodes)):
                pat = make_pattern(tree, selected)
                if pat:
                    yield pat, selected
        else:
            for child in node:
                yield from patterns_with_literals(child)

    # yield patterns describing function calls
    def patterns_with_calls(root):
        for node in root.subtrees():
            if node.label == 'Call':
                pat = make_pattern(root, [node])
                if pat:
                    yield pat, [node]

    control_labels = ('FunctionDef', 'If', 'Else', 'For', 'While', 'Break',
                      'Continue', 'Pass', 'With', 'Raise', 'Try', 'Return')

    # yield patterns describing control structures: for every leaf, the
    # chain of control-structure nodes enclosing it, outermost first
    def patterns_with_control_structures(node, path=None):
        if path is None:
            path = []
        if node.label in control_labels:
            # Extend a copy: appending in place would leak this node into
            # the paths of siblings visited later, which share the list.
            path = path + [node]
        if not node and path:
            # nest the labels on the path, innermost node first
            children = []
            for parent in path[::-1]:
                pat = Tree(parent.label, children)
                children = [pat]
            yield pat, [node]
        for child in node:
            yield from patterns_with_control_structures(child, path)

    yield from patterns_with_variables()
    yield from patterns_with_calls(tree)
    yield from patterns_with_literals(tree)
    yield from patterns_with_control_structures(tree)

def get_attributes(programs):
    """Extract pattern attributes from the training *programs*.

    Returns an OrderedDict mapping 'regex-0', 'regex-1', … to
    ``{'desc': pattern, 'programs': [...]}``, ordered by decreasing length
    of the programs list; patterns with fewer than 5 entries are left out.

    NOTE(review): a program is appended once per (pattern, nodes) pair that
    get_patterns yields, so the same program can appear several times in one
    pattern's list — confirm this is the intended count.
    """
    # pattern -> programs it was induced from
    occurrences = collections.defaultdict(list)
    for code in programs:
        for pat, _nodes in get_patterns(code):
            occurrences[pat].append(code)

    ranked = sorted(occurrences.items(), key=lambda item: len(item[1]), reverse=True)

    attrs = collections.OrderedDict()
    for pat, progs in ranked:
        # ranked by descending count, so everything after this is rarer still
        if len(progs) < 5:
            break
        name = 'regex-{}'.format(len(attrs))
        attrs[name] = {'desc': pat, 'programs': progs}

    return attrs