summaryrefslogtreecommitdiff
path: root/monkey/monkey.py
blob: 8e805f5cdf684a0bbc59e668650b8ec43ca635aa (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
#!/usr/bin/python3

import collections
import math
import pickle
import sys
import time

from termcolor import colored

from . import db
from .action import parse
from .edits import classify_edits, clean_graph, edit_graph, get_edits_from_traces
from .graph import Node, graphviz
from .prolog.engine import PrologEngine
from .prolog.util import compose, decompose, map_vars, rename_vars, stringify
from .util import PQueue, Token, indent

# score a program (a list of lines) according to lines distribution
def score(program, lines):
    """Return the geometric mean of the per-line scores of [program].

    Every line of [program] is normalized with rename_vars() and looked up in
    the [lines] mapping; lines missing from the mapping contribute 0.01.
    An empty program, or one whose product of line scores is zero, scores 0.01.
    """
    def normalized(line):
        # rename_vars works in place, so normalize a copy of the line
        tokens = list(line)
        rename_vars(tokens)
        return tuple(tokens)

    product = 1
    for line in program:
        product *= lines.get(normalized(line), 0.01)

    n_lines = len(program)
    if product == 0 or not n_lines:
        return 0.01
    return math.pow(product, 1 / n_lines)

# find a sequence of edits that fixes [code]
def fix(name, code, edits, timeout=30, debug=False):
    """Search for a sequence of line edits that turns [code] into a correct program.

    Candidates are kept in a priority queue. Each candidate carries the product
    of the costs of the edits that produced it and is pushed with priority
    -product (presumably PQueue pops the smallest priority first, i.e. the
    candidate with the largest product — confirm against util.PQueue).
    Each popped candidate is tested with test(); if it fails, new candidates
    are generated by changing, removing or inserting lines according to the
    edits observed in [edits] (see classify_edits()).

    Args:
        name: problem name passed to test() to select the test suite.
        code: the program to fix.
        edits: observed edits, classified into inserts/removes/changes.
        timeout: search time limit in seconds.
        debug: if true, print each candidate's cost and edit path.

    Returns:
        (solution, path, total_time, n_tested). On success, solution is the
        fixed program and path is a tuple of steps, each of the form
        ((rule index, line index within rule), (old line, new line)); a removal
        has () as the new line and an insertion has () as the old line.
        If no fix is found before the queue empties or the timeout expires,
        solution is '' and path is [].
    """
    todo = PQueue()  # priority queue of candidate solutions
    done = set()     # set of already-analyzed solutions

    # Add a new candidate solution ([lines]+[rules]) to the priority queue.
    # This solution is generated by applying [step] with [cost] to [prev] task.
    def add_task(lines, rules, prev=None, step=None, cost=None):
        if prev is None:
            path = ()
            path_cost = 1.0
        else:
            path = (*prev[1], step)
            path_cost = prev[2] * cost
        todo.push(((tuple(lines), tuple(rules)), path, path_cost), -path_cost)

    lines, rules = decompose(code)
    add_task(lines, rules)

    inserts, removes, changes = classify_edits(edits)
    start_time = time.monotonic()
    n_tested = 0
    while True:
        total_time = time.monotonic() - start_time
        if total_time > timeout:
            break

        task = todo.pop()
        if task is None:
            break

        (lines, rules), path, path_cost = task
        code = compose(lines, rules)
        if code in done:
            continue
        done.add(code)

        if debug:
            print('Cost {:.12f}'.format(path_cost))
            for line, (before, after) in path:
                print('line ' + str(line) + ':\t' + stringify(before) + ' → ' + stringify(after))

        # if the code is correct, we are done
        try:
            if test(name, code):
                return code, path, total_time, n_tested
        except Exception:
            # A program that makes the test harness fail counts as incorrect.
            # Only ordinary errors are caught so that KeyboardInterrupt and
            # SystemExit still abort the search.
            pass
        n_tested += 1

        # otherwise generate new solutions
        for rule_no, (start, end) in enumerate(rules):
            rule = lines[start:end]
            rule_tokens = [t for line in rule for t in line]

            for line_idx in range(start, end):
                line = lines[line_idx]

                # rename_vars works in place, so normalize a copy of the line
                line_normal = list(line)
                rename_vars(line_normal)
                line_normal = tuple(line_normal)

                # try every observed change whose "before" matches this line
                seen = False
                for (before, after), cost in changes.items():
                    if line_normal == before:
                        seen = True
                        # map the normalized variables of [after] back to the
                        # variable names used in this rule
                        mapping = map_vars(before, after, line, rule_tokens)
                        after_real = tuple([t if t.type != 'VARIABLE' else Token('VARIABLE', mapping[t.val]) for t in after])
                        new_lines = lines[:line_idx] + (after_real,) + lines[line_idx+1:]
                        new_step = ((rule_no, line_idx-start), (tuple(line), after_real))

                        add_task(new_lines, rules, prev=task, step=new_step, cost=cost)

                # if nothing could be done with this line, try removing it
                # (maybe try removing in any case?)
                if line_normal in removes or not seen:
                    new_lines = lines[:line_idx] + lines[line_idx+1:]
                    new_rules = []
                    for old_start, old_end in rules:
                        # boundaries after the removed line move up by one
                        new_start = old_start if old_start <= line_idx else old_start - 1
                        new_end = old_end if old_end <= line_idx else old_end - 1
                        if new_end > new_start:  # drop rules that became empty
                            new_rules.append((new_start, new_end))
                    new_step = ((rule_no, line_idx-start), (tuple(line), ()))
                    # unobserved removals get a fixed default cost
                    new_cost = removes.get(line_normal, 0.9)

                    add_task(new_lines, new_rules, prev=task, step=new_step, cost=new_cost)

            # try adding a line to this rule… would need to distinguish between
            # head/body lines in transforms
            for after, cost in inserts.items():
                mapping = map_vars([], after, [], rule_tokens)
                after_real = tuple([t if t.type != 'VARIABLE' else Token('VARIABLE', mapping[t.val]) for t in after])
                # the new line goes right after the last line of this rule
                new_lines = lines[:end] + (after_real,) + lines[end:]
                new_rules = [(old_start if old_start < end else old_start + 1,
                              old_end if old_end < end else old_end + 1)
                             for old_start, old_end in rules]
                new_step = ((rule_no, end-start), ((), after_real))

                add_task(new_lines, new_rules, prev=task, step=new_step, cost=cost)

        # try adding a new fact
        if len(rules) < 2:
            for after, cost in inserts.items():
                new_lines = lines + (after,)
                new_rules = rules + (((len(lines), len(lines)+1)),)
                new_step = ((len(new_rules)-1, 0), (tuple(), tuple(after)))

                add_task(new_lines, new_rules, prev=task, step=new_step, cost=cost)

    return '', [], total_time, n_tested

def print_hint(solution, steps, fix_time, n_tested):
    """Print the result of fix(): the edits and final program, or a failure note.

    [solution], [steps], [fix_time] and [n_tested] are the four values
    returned by fix().
    """
    if not solution:
        print(colored('Hint not found! Tested {} programs in {:.1f} s.'.format(n_tested, fix_time), 'red'))
        return

    print(colored('Hint found! Tested {} programs in {:.1f} s.'.format(n_tested, fix_time), 'green'))
    print(colored(' Edits', 'blue'))
    for position, (old, new) in steps:
        print('  {}:\t{} → {}'.format(position, stringify(old), stringify(new)))
    print(colored(' Final version', 'blue'))
    print(indent(compose(*decompose(solution)), 2))

# Find official solutions to all problems.
def init_problems():
    """Load all problems from the database.

    Returns three dicts keyed by problem id, filled from the
    (name, code, library) triple returned by db.get_problem().
    """
    names, codes, libraries = {}, {}, {}
    for pid in db.get_problem_ids():
        problem_name, problem_code, problem_library = db.get_problem(pid)
        names[pid] = problem_name
        codes[pid] = problem_code
        libraries[pid] = problem_library
    return names, codes, libraries

# Submit code to Prolog server for testing.
def test(name, code):
    """Run the tests for problem [name] against [code] on a Prolog engine.

    Returns True if the engine answers the query
    run_tests(<name>, '<engine id>') with a 'success' event.
    The engine is destroyed even when the query raises. Without that, every
    failing test would leak an engine, and fix() calls this repeatedly.
    """
    # TODO also load fact library and solved predicates
    engine = PrologEngine(code=code)
    try:
        result = engine.ask("run_tests({}, '{}')".format(name, engine.id))
    finally:
        engine.destroy()
    return result['event'] == 'success'

if __name__ == '__main__':
    # Get problem id from commandline.
    if len(sys.argv) < 2:
        print('usage: ' + sys.argv[0] + ' <pid>')
        sys.exit(1)
    pid = int(sys.argv[1])

    names, codes, libraries = init_problems()

    # Analyze traces for this problem to get edits, submissions and queries.
    traces = db.get_traces(pid)
    edits, lines, submissions, queries = get_edits_from_traces(traces.values())

    # Find incorrect submissions.
    incorrect = []
    for submission, count in sorted(submissions.items()):
        if not test(names[pid], submission):
            # This incorrect submission appeared in [count] attempts.
            incorrect += [submission]*count

    # XXX only for testing
    # Programs fixed in earlier runs are cached in a per-problem pickle file
    # written by this script. Never point it at untrusted data: unpickling can
    # execute arbitrary code.
    status_file = 'status-' + str(pid) + '.pickle'
    try:
        with open(status_file, 'rb') as f:
            done = pickle.load(f)
    except (OSError, EOFError, pickle.UnpicklingError):
        # missing, empty or corrupt status file: start from scratch
        done = []

    # test fix() on incorrect student submissions
    if len(sys.argv) >= 3 and sys.argv[2] == 'test':
        timeout = int(sys.argv[3]) if len(sys.argv) >= 4 else 10

        print('Fixing {}/{} programs (timeout={})…'.format(
            len([p for p in incorrect if p not in done]), len(incorrect), timeout))

        for i, program in enumerate(incorrect):
            if program in done:
                continue
            print(colored('Analyzing program {0}/{1}…'.format(i+1, len(incorrect)), 'yellow'))
            print(indent(compose(*decompose(program)), 2))

            solution, steps, fix_time, n_tested = fix(names[pid], program, edits, timeout=timeout)
            if solution:
                done.append(program)
            print_hint(solution, steps, fix_time, n_tested)
            print()

            # save progress after every program so an interrupted run can resume
            with open(status_file, 'wb') as f:
                pickle.dump(done, f)

        # [done] holds each fixed program once, while [incorrect] repeats a
        # program once per attempt. Count fixed attempts with multiplicity so
        # both numbers are on the same scale, as the 'info' command does.
        n_fixed = len([p for p in incorrect if p in done])
        print('Found hints for ' + str(n_fixed) + ' of ' + str(len(incorrect)) + ' incorrect programs')

    # print info for this problem
    elif len(sys.argv) >= 3 and sys.argv[2] == 'info':
        # with no additional arguments, print some stats
        if len(sys.argv) == 3:
            print('Problem {} ({}): {} edits in {} traces, fixed {}/{} ({}/{} unique)'.format(
                pid, colored(names[pid], 'yellow'),
                colored(str(len(edits)), 'yellow'), colored(str(len(traces)), 'yellow'),
                colored(str(len([p for p in incorrect if p in done])), 'yellow'),
                colored(str(len(incorrect)), 'yellow'),
                colored(str(len(set(done))), 'yellow'),
                colored(str(len(set(incorrect))), 'yellow')))
        else:
            if sys.argv[3] == 'users':
                # trace keys are (pid, uid) pairs
                print(' '.join([str(uid) for (_pid, uid) in sorted(traces.keys())]))
            # print all observed edits and their costs
            elif sys.argv[3] == 'edits':
                inserts, removes, changes = classify_edits(edits)
                print('Inserts')
                for after, cost in sorted(inserts.items(), key=lambda x: x[1]):
                    print(' {:.2f}\t{}'.format(cost, stringify(after)))
                print('Removes')
                for before, cost in sorted(removes.items(), key=lambda x: x[1]):
                    print(' {:.2f}\t{}'.format(cost, stringify(before)))
                print('Changes')
                for (before, after), cost in sorted(changes.items(), key=lambda x: x[1]):
                    print(' {:.2f}\t{} → {}'.format(cost,
                                                   stringify(before if before else [('INVALID', 'ε')]),
                                                   stringify(after if after else [('INVALID', 'ε')])))
            # print all student submissions not (yet) corrected
            elif sys.argv[3] == 'unsolved':
                for p in sorted(set(incorrect)):
                    if p in done:
                        continue
                    print(indent(compose(*decompose(p)), 2))
                    print()
            # print all student queries and their counts
            elif sys.argv[3] == 'queries':
                for query, count in queries.most_common():
                    print('  ' + str(count) + '\t' + query)

    # Print the edit graph in graphviz dot syntax.
    elif len(sys.argv) == 4 and sys.argv[2] == 'graph':
        uid = int(sys.argv[3])
        actions = parse(traces[(pid, uid)])

        nodes, submissions, queries = edit_graph(actions)

        def position(node):
            return (node.data[1]*150, node.data[0]*-60)

        def label(node):
            return stringify(node.data[2])

        def node_attr(node):
            # nodes whose content did not change from their first predecessor
            # are drawn as small gray points
            if node.ein and node.data[2] == node.ein[0].data[2]:
                return 'color="gray", shape="point"'
            return ''

        def edge_attr(a, b):
            if a.data[2] == b.data[2]:
                return 'arrowhead="none"'
            return ''

        graphviz_str = graphviz(nodes, pos=position, label=label,
                                node_attr=node_attr, edge_attr=edge_attr)
        print(graphviz_str)

    # run interactive loop
    else:
        while True:
            # read the program from stdin
            print('Enter program, end with empty line:')
            code = ''
            try:
                while True:
                    line = input()
                    if not line:
                        break
                    code += line + '\n'
            except EOFError:
                break

            # try finding a fix
            print(colored('Analyzing program…', 'yellow'))
            solution, steps, fix_time, n_tested = fix(names[pid], code, edits, debug=True)
            print_hint(solution, steps, fix_time, n_tested)