summaryrefslogtreecommitdiff
path: root/monkey/edits.py
blob: 56bf773bdd58994a7ce832b515bce18a438d4ab8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
# CodeQ: an online programming tutor.
# Copyright (C) 2015 UL FRI
#
# This program is free software: you can redistribute it and/or modify it under
# the terms of the GNU Affero General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
# details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import collections
import math

from .action import expand, parse
from prolog.util import normalized, parse as prolog_parse, rename_vars_ast, rename_vars_list, interesting_ranges, stringify, tokenize
from .util import avg, logistic

def get_edits_from_trace(trace, test, id):
    """Extract observed edits, submissions and queries from a single trace.

    The trace is replayed action by action.  Whenever the student runs a
    query or tests the program, every interesting range of the current parse
    tree is opened as a candidate edit (unless an open edit already covers
    exactly that range).  Later character insertions/removals shift or grow
    the open ranges.  If the student reaches a correct version, replay stops
    and every still-open range is recorded as an edit from its original
    tokens to the tokens now occupying it.

    :param trace: raw trace, as accepted by ``action.parse``.
    :param test: callable taking program code and returning
        ``(passed, total)``; the program is correct when they are equal.
    :param id: identifier recorded for each edit found here (callers pass a
        user id).
    :returns: ``(edits, submissions, queries)``:
        ``edits`` maps ``(path, start_tokens, end_tokens)`` (tuples with
        variables renamed) to a set containing ``id`` — it is only populated
        when a correct version was reached;
        ``submissions`` is a set of ``(code, correct)`` pairs for every test
        up to and including the first correct one;
        ``queries`` is a set of query strings with trailing spaces/dots
        stripped.
    """
    submissions = set()         # Program versions at 'test' actions.
    queries = set()             # Queries run by the student.

    # Map each observed edit (path, start tokens, end tokens) to the set of
    # ids of traces in which it was observed (here at most {id}).
    edits = collections.defaultdict(set)
    def add_edit(path, start, end, tree):
        # ``tree`` is accepted but currently unused.
        if start == end:
            return
        # Skip edits whose result is more than twice as long as the original.
        if len(end) > 2*len(start):
            return
        edits[(path, start, end)].add(id)

    # Parse trace actions and ensure there is a separate action for each
    # inserted/removed character.
    try:
        actions = parse(trace)
        expand(actions)
    except Exception:
        # Only a few traces fail to parse, so just skip them.  Catch
        # Exception rather than everything so Ctrl-C / sys.exit still work.
        actions = []

    # State variables.
    open_edits = []   # [start_tree, start_tokens, path, pos_start, pos_end]
    code_next = ''    # Program code after applying the current action.
    done = False      # Set to True on first correct version.
    prev_tree = None
    prev_action = None

    for action in actions:
        code = code_next
        code_next = action.apply(code)

        if action.type in {'prolog_solve', 'test'}:
            if action.type == 'prolog_solve':
                queries.add(action.query.rstrip(' .'))
            elif action.type == 'test':
                passed, total = test(code)
                correct = passed == total
                submissions.add((code, correct))
                if correct:
                    # Ignore actions after the first correct version.
                    done = True
                    break

            # Open a candidate edit for each interesting range of a new tree.
            tree = prolog_parse(code)
            if tree and tree.leaves() and tree != prev_tree:
                for terminals, path in interesting_ranges(tree):
                    pos_start = terminals[0].pos
                    pos_end = terminals[-1].pos + len(terminals[-1].val)
                    # If there is an open edit with the same range, don't add a new one.
                    if not any(e_pos_start == pos_start and e_pos_end == pos_end
                               for _, _, _, e_pos_start, e_pos_end in open_edits):
                        open_edits.append([tree, terminals, path, pos_start, pos_end])
            prev_tree = tree

        if action.type in {'insert', 'remove'}:
            # Shift/resize every open range to follow this one-character change;
            # ranges that become empty are dropped.
            new_open_edits = []
            for start_tree, start_tokens, path, pos_start, pos_end in open_edits:
                new_pos_start, new_pos_end = pos_start, pos_end
                if action.type == 'remove':
                    if action.offset < pos_end:
                        # Removed before the end: the range shrinks...
                        new_pos_end -= 1
                        if action.offset < pos_start:
                            # ...and moves left if removed before the start.
                            new_pos_start -= 1
                elif action.type == 'insert':
                    if action.offset < pos_start:
                        # Inserted before the range: shift it right.
                        new_pos_start += 1
                        new_pos_end += 1
                    elif action.offset == pos_start:
                        # Inserted at the start: grow the range.
                        new_pos_end += 1
                    elif action.offset < pos_end:
                        # Inserted inside the range: grow it.
                        new_pos_end += 1
                    elif action.offset == pos_end:
                        # Inserted right after the range: extend it only when
                        # the student is typing continuously here (first
                        # action, an insert just before this offset, or a
                        # remove at this offset).
                        if (prev_action is None or
                                prev_action.type == 'insert' and prev_action.offset == action.offset-1 or
                                prev_action.type == 'remove' and prev_action.offset == action.offset):
                            # Find the token that originally followed the range.
                            orig_next = None
                            for terminal in start_tree.leaves():
                                if terminal.pos >= start_tokens[-1].pos + len(start_tokens[-1].val):
                                    orig_next = terminal
                                    break
                            # Don't extend if the typed character matches the
                            # start of that following token (presumably the
                            # student is retyping it, not extending this range).
                            if not (orig_next and orig_next.val[0] == action.text):
                                new_pos_end += 1
                if new_pos_start != new_pos_end:
                    new_open_edits.append([start_tree, start_tokens, path, new_pos_start, new_pos_end])
            open_edits = new_open_edits
            prev_action = action

    if done:
        # ``code`` is the first correct version; close every open range and
        # record how its tokens changed.  A shared ``names`` mapping is used
        # so variables are renamed consistently in start and end tokens.
        for start_tree, start_tokens, path, pos_start, pos_end in open_edits:
            end_tokens = tokenize(code[pos_start:pos_end])
            names = {}
            start_normal = rename_vars_list(start_tokens, names)
            end_normal = rename_vars_list(end_tokens, names)
            norm_tree = rename_vars_ast(start_tree, names)
            add_edit(path, tuple(start_normal), tuple(end_normal), norm_tree)

    return edits, submissions, queries

def get_edits_from_solutions(solutions, test):
    """Merge the edits, submissions and queries of many solution traces.

    Every solution is replayed with ``get_edits_from_trace``.  Submitted
    programs and queries are normalized (variables renamed) before being
    recorded.  Edits associated with fewer than two distinct user ids are
    dropped.  Each remaining edit start -> end at ``path`` is weighted by
    the share of observations of ``(path, start)`` that ended in ``end``.
    Those weights are then passed through ``logistic`` centred on their mean.

    :param solutions: iterable of objects exposing ``trace`` and
        ``codeq_user_id``.
    :param test: callable forwarded unchanged to ``get_edits_from_trace``.
    :returns: ``(edits, submissions, queries)`` where ``edits`` maps
        ``(path, start, end)`` to ``(weight, user_ids)``, ``submissions``
        maps ``(normalized_code, correct)`` to a set of user ids, and
        ``queries`` is a Counter of normalized query strings (each distinct
        raw query contributes once per trace).
    """
    submissions = collections.defaultdict(set)
    queries = collections.Counter()
    observed = collections.defaultdict(set)   # (path, start, end) -> user ids

    def normalize(text):
        # Canonical text with variables renamed, so equivalent programs match.
        return stringify(rename_vars_list(tokenize(text)))

    for solution in solutions:
        trace = solution.trace
        user_id = solution.codeq_user_id
        trace_edits, trace_submissions, trace_queries = get_edits_from_trace(trace, test, user_id)

        for key, user_ids in trace_edits.items():
            observed[key] |= user_ids
        for code, correct in trace_submissions:
            submissions[(normalize(code), correct)].add(user_id)
        for query in trace_queries:
            queries[normalize(query)] += 1

    # Keep only edits seen for at least two distinct user ids.
    common = {key: user_ids for key, user_ids in observed.items() if len(user_ids) >= 2}

    # Number of observations per starting point, summed over all end variants.
    totals = collections.Counter()
    for (path, start, _), user_ids in common.items():
        totals[(path, start)] += len(user_ids)

    # Conditional share of each start -> end rewrite.
    weighted = {}
    for (path, start, end), user_ids in common.items():
        if start == end:
            continue
        weighted[(path, start, end)] = (len(user_ids) / totals[(path, start)], user_ids)

    if not weighted:
        return weighted, submissions, queries

    # Reshape the distribution around its mean to improve search.
    mean = avg([share for share, _ in weighted.values()])
    tweaked = {key: (logistic(share, k=3, x_0=mean), user_ids)
               for key, (share, user_ids) in weighted.items()}
    return tweaked, submissions, queries

def classify_edits(edits):
    """Split ``{(before, after): cost}`` into insertions, removals and changes.

    An entry with an empty (falsy) ``before`` and a non-empty ``after`` is an
    insertion, keyed by ``after``.  An entry with a non-empty ``before`` and
    an empty ``after`` is a removal, keyed by ``before``.  Everything else
    (including both sides empty) is a change, keyed by ``(before, after)``.

    :returns: ``(inserts, removes, changes)`` dicts mapping keys to costs.
    """
    inserts, removes, changes = {}, {}, {}
    for pair, cost in edits.items():
        before, after = pair
        if after and not before:
            inserts[after] = cost
            continue
        if before and not after:
            removes[before] = cost
            continue
        changes[(before, after)] = cost
    return inserts, removes, changes


# Extract edits and other data from existing traces for each problem.
if __name__ == '__main__':
    import pickle
    import traceback
    from db.models import Problem, Solution
    from db.util import make_identifier
    from prolog.util import used_predicates
    from server.problems import get_facts, load_problem, solutions_for_problems

    # Ignore traces from these users.
    ignored_users = {
        1, # admin
        231, # test
        360, # test2
        358, # sasha
    }

    edits, submissions, queries = {}, {}, {}

    # Test results cached by previous runs of this script:
    # test_results[pid][(normalized code, dependencies)] = (n_correct, n_all).
    # NOTE: pickle.load can execute arbitrary code; only load a cache file
    # this script wrote itself.
    try:
        with open('test_results.pickle', 'rb') as f:
            test_results = pickle.load(f)
    except Exception:
        # Missing or unreadable cache: start from scratch.
        test_results = collections.defaultdict(dict)

    for problem in Problem.list():
        pid = problem.id
        solutions = [s for s in Solution.filter(problem_id=pid, done=True)
                       if s.codeq_user_id not in ignored_users]
        if not solutions:
            print('No traces for {}'.format(problem.identifier))
            continue

        # Testing function.
        problem_module = load_problem(problem.language, problem.group, problem.identifier, 'common')
        other_problems = [p for p in Problem.filter_language(problem.language)
                            if p.identifier != problem.identifier]
        facts = get_facts(problem.language, problem_module)

        # NOTE: ``test`` closes over this iteration's loop variables; that is
        # only safe because it is fully used before the next iteration.
        def test(code):
            """Return ``(n_correct, n_all)`` for ``code``, using the cache if possible."""
            # Find solutions to other problems that are used by this program.
            used_predicate_identifiers = {make_identifier(name) for name in used_predicates(code)}
            # NOTE(review): p[2:] presumably yields (group, identifier) as
            # expected by solutions_for_problems — confirm against Problem.
            dependencies = sorted([p[2:] for p in other_problems
                                         if p.identifier in used_predicate_identifiers])

            # Check for cached results; setdefault also works if the loaded
            # cache is a plain dict rather than a defaultdict.
            cache = test_results.setdefault(pid, {})
            normal_code = stringify(rename_vars_list(tokenize(code)))
            code_key = (normal_code, tuple(dependencies))
            if code_key not in cache:
                aux_code = '\n' + solutions_for_problems(problem.language, dependencies) + '\n' + facts
                n_correct, n_all, _ = problem_module.test(code, aux_code)
                cache[code_key] = (n_correct, n_all)
            return cache[code_key]

        print('Analyzing traces for {}… '.format(problem.identifier), end='', flush=True)
        print('{} traces… '.format(len(solutions)), end='', flush=True)
        try:
            edits[pid], submissions[pid], queries[pid] = get_edits_from_solutions(solutions, test)
            print('{} edits, {} submissions, {} queries'.format(
                len(edits[pid]), len(submissions[pid]), len(queries[pid])))
        except Exception:
            # Report the failure and continue with the next problem.
            traceback.print_exc()

    with open('edits.pickle', 'wb') as f:
        pickle.dump((edits, submissions, queries), f)
    with open('test_results.pickle', 'wb') as f:
        pickle.dump(test_results, f)