summaryrefslogtreecommitdiff
path: root/test-rules.py
blob: c47ca1049c22eebdc953b739e14e9fbd63c01f5f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
#!/usr/bin/python3

import argparse
import collections
import json
import os.path
import re
from statistics import mean

from termcolor import colored

from monkey.patterns import get_patterns

# Command-line interface: a single positional argument naming the data
# directory, which this script reads (attributes.tab, rules.txt,
# users-test.txt, submissions/) and writes bugs.json into.
parser = argparse.ArgumentParser(
    description='Evaluate rules on student programs.',
)
parser.add_argument(
    'path',
    help='path to data directory',
)
args = parser.parse_args()
data_dir = args.path

class Rule(collections.namedtuple('Rule', ['klass', 'condition', 'distribution', 'quality'])):
    """A rule read from rules.txt.

    klass: 'T' or 'F' (the predicted value of `correct`)
    condition: tuple of patterns the rule tests for
    distribution: rule class distribution (pair of counts)
    quality: rule quality
    """

    def __str__(self):
        # header line, then one line per pattern in the condition
        header = 'Rule: class = {}, distribution = {}, quality = {}'.format(
            self.klass, self.distribution, self.quality)
        lines = [header] + [str(pattern) for pattern in self.condition]
        return '\n'.join(lines) + '\n'

class Submission(collections.namedtuple('Submission', ['program', 'correct', 'patterns', 'hint'])):
    """One attempt from a student's trace.

    program: submitted code
    correct: does this submission pass all tests?
    patterns: (pattern, nodes) pairs found in this submission
    hint: suggested hint (a Hint, or None when no rule applies)
    """

class Hint(collections.namedtuple('Hint', ['ok', 'remove', 'add', 'add_alternatives'])):
    """A hint suggested for one submission.

    ok: required patterns already in the program (unused; always [] in this script)
    remove: patterns that should be removed (condition of the matching buggy rule)
    add: the pattern to add first -- the one missing most often from the
        best-matching true rules (ties broken by attribute file order)
    add_alternatives: the other patterns missing from those rules

    Hints come either from a buggy rule (only `remove` is set) or from true
    rules (`add`, possibly together with `add_alternatives`).
    """

# read attributes: each line of attributes.tab is "<attribute id>\t<pattern>".
# `attributes` maps id -> pattern (used to decode rule conditions), and
# `attributes_ordered` lists the patterns in file order (exported to the tutor
# and used to break ties between equally good hints).
# The file is read once and closed; blank lines are skipped instead of
# making dict() fail.
attributes_file = os.path.join(data_dir, 'attributes.tab')
with open(attributes_file, 'r') as f:
    attribute_rows = [line.strip().split('\t') for line in f if line.strip()]
attributes = dict(attribute_rows)
attributes_ordered = [row[1] for row in attribute_rows]

# read rules: a line of rules.txt looks, for example, like
#   IF a3!=F AND a17!=F THEN correct=F  [ 12   0]  0.92
# Lines that do not match the expected format are reported and skipped.
rules_file = os.path.join(data_dir, 'rules.txt')
rule_regex = re.compile(r'IF ((?:a[0-9]*!=F(?: AND )*)*) THEN correct=([TF]) *\[ *([0-9]*) *([0-9]*)\] *([0-9.]*)')
rules = []
with open(rules_file, 'r') as f:
    for line in f:
        match = rule_regex.match(line.strip())
        if match:
            conditions, klass, count_a, count_b, quality = match.groups()
            # each term is "<attribute id>!=F"; dropping the 3-char "!=F" suffix
            # leaves the id, which attributes maps to its pattern
            condition = tuple(attributes[field[:-3]] for field in conditions.split(' AND '))
            rules.append(Rule(klass, condition, (int(count_a), int(count_b)), float(quality)))
        else:
            print('Did not understand rule:', line.strip())

# export rules for tutor: bugs.json holds the ordered pattern list and, for
# every rule, its condition, class (True for 'T' rules), distribution and quality
json_file = os.path.join(data_dir, 'bugs.json')
exported_rules = [
    {
        'class': rule.klass == 'T',
        'condition': rule.condition,
        'distribution': rule.distribution,
        'quality': rule.quality,
    }
    for rule in rules
]
json_data = dict(patterns=attributes_ordered, rules=exported_rules)
with open(json_file, 'w') as f:
    f.write(json.dumps(json_data, sort_keys=True, indent=2))

def color_print(text, ranges):
    """Print `text` followed by a newline, coloring the given spans.

    ranges: iterable of (start, length, color) triples. They are handled in
    sorted order; a span that starts before the end of the previously
    printed span is skipped.
    """
    pieces = []
    cursor = 0
    for start, length, color in sorted(ranges):
        if start < cursor:
            continue  # overlaps a span that was already emitted
        end = start + length
        pieces.append(text[cursor:start])
        pieces.append(colored(text[start:end], color))
        cursor = end
    pieces.append(text[cursor:])
    print(''.join(pieces))

# generate marks for selected patterns for color_print
def mark(patterns, selected, color):
    """Collect color_print ranges covering every node of the selected patterns.

    patterns: (pattern, nodes) pairs as returned by get_patterns; for each
        node n, n[0] must have `pos` (start offset) and `val` (text)
    selected: collection of patterns to highlight
    color: color name passed through to color_print

    Returns a set of (start, length, color) triples. Nodes without a position
    (pos is None -- presumably tokens with unknown location; confirm) are
    skipped. Offset 0 is a valid position and is kept; the previous truthiness
    test silently dropped it.
    """
    return {
        (node[0].pos, len(node[0].val), color)
        for pattern, nodes in patterns
        if pattern in selected
        for node in nodes
        if node[0].pos is not None
    }

# return a hint for the first applicable buggy rule
def suggest_buggy(rules, patterns):
    """Return a 'remove' Hint for the first buggy ('F') rule matching the program.

    A rule matches when every pattern in its condition occurs among the
    (pattern, nodes) pairs in `patterns`. Rules are tried in the given order,
    so "first" equals "best" only if the rules are ordered by quality
    (presumably true for rules.txt; confirm). Returns None if no buggy rule
    matches.
    """
    # build the list of program patterns once, not once per rule pattern
    present = [p[0] for p in patterns]
    for rule in rules:
        if rule.klass == 'F' and all(rule_pattern in present for rule_pattern in rule.condition):
            return Hint(ok=[], remove=rule.condition, add=[], add_alternatives=[])
    return None

# return a hint for the best applicable true rule
def suggest_true(rules, patterns):
    """Suggest patterns to add, based on the true ('T') rules closest to matching.

    Each true rule's condition is split into patterns found in the program
    and patterns missing from it; fully matched rules are ignored. Among the
    rules with the largest non-zero number of found patterns (no longer capped
    at 10), missing patterns are counted. `add` holds the most frequently
    missing pattern, with ties broken by the order of attributes_ordered.
    `add_alternatives` holds the remaining missing patterns, most frequent
    first.

    Returns None if no true rule has at least one found and one missing pattern.
    """
    present = [p[0] for p in patterns]

    # (found, missing) splits of partially matched true rules, keyed by len(found)
    rule_matches = collections.defaultdict(list)
    for rule in rules:
        if rule.klass != 'T':
            continue
        found = set()
        missing = set()
        for rule_pattern in rule.condition:
            if rule_pattern in present:
                found.add(rule_pattern)
            else:
                missing.add(rule_pattern)
        if missing:
            rule_matches[len(found)].append((found, missing))

    # only rules sharing at least one pattern with the program are considered
    candidates = [n_found for n_found in rule_matches if n_found > 0]
    if not candidates:
        return None
    best = max(candidates)

    missing_patterns = collections.Counter()
    for _, missing in rule_matches[best]:
        missing_patterns.update(missing)
    ranked = missing_patterns.most_common()  # sort once, reuse below
    top_count = ranked[0][1]
    best_missing_patterns = [pattern for pattern, count in ranked if count == top_count]

    add = []
    for pattern in attributes_ordered:
        if pattern in best_missing_patterns:
            add = [pattern]
            break
    add_alternatives = [pattern for pattern, _ in ranked if pattern not in add]
    return Hint(ok=[], remove=[], add=add, add_alternatives=add_alternatives)

# read traces: users-test.txt lists one user id per line (blank lines ignored);
# the file is closed once read
users_file = os.path.join(data_dir, 'users-test.txt')
with open(users_file, 'r') as f:
    users = [int(line.strip()) for line in f if line.strip()]

# evaluate hints on student traces
submissions = collections.defaultdict(list)
for user in users:
    user_subs = []
    user_dir = os.path.join(data_dir, 'submissions', str(user))
    # each submission is in a file named <seq. no>-<total tests>-<passed tests>
    for submission in sorted(os.listdir(user_dir), key=lambda x: int(x.split('-')[0])):
        seq, total, passed = submission.split('-')
        correct = total == passed
        with open(os.path.join(user_dir, submission), 'r') as f:
            code = f.read()

        # check rules for this submission; buggy rules take precedence
        program_patterns = list(get_patterns(code))
        hint = suggest_buggy(rules, program_patterns)
        if not hint:
            hint = suggest_true(rules, program_patterns)
        user_subs.append(Submission(code, correct, program_patterns, hint))

        # skip submissions after the first correct program
        if correct:
            break

    # ignore traces with no / only correct submissions; since the loop stops
    # at the first correct program, kept traces end in a correct submission
    # preceded by at least one failed one
    if (not any(s.correct for s in user_subs) or
        all(s.correct for s in user_subs)):
        continue

    submissions[user] = user_subs

    # print submissions with hints for debugging
    for s in user_subs:
        print('PASS' if s.correct else 'FAIL', end='\t')
        marks = set()
        if s.hint and s.hint.remove:
            marks = mark(s.patterns, s.hint.remove, 'red')
        color_print(s.program, marks)

        if s.hint:
            for x in s.hint.remove:
                print('buggy\t', x)
            for x in s.hint.add:
                print('missing\t', x)
            for x in s.hint.add_alternatives:
                print('alternative\t', x)
        print()
    print('-'*30)
    print()

# Classify every non-final submission of each kept trace by comparing its
# hint with the student's eventual solution (the trace's last submission):
#   good_hint   - buggy hint: at least one flagged pattern is gone from the
#                 solution; true-rule hint: every `add` pattern is in it
#   medium_hint - true-rule hint whose `add` is not fully implemented, but
#                 some alternative pattern appears in the solution
#   bad_hint    - any other hinted submission
#   no_hint     - submissions with no suggestions
good_hint = []
medium_hint = []
bad_hint = []
no_hint = []

# total number of submissions
n_subs = 0
for user, subs in submissions.items():
    *attempts, solution = subs
    solution_patterns = [entry[0] for entry in solution.patterns]
    for s in attempts:
        n_subs += 1
        hint = s.hint
        if not hint:
            bucket = no_hint
        elif hint.remove:
            # buggy rule: at least one pattern should be absent from the solution
            fixed = any(p not in solution_patterns for p in hint.remove)
            bucket = good_hint if fixed else bad_hint
        elif all(p in solution_patterns for p in hint.add):
            # true rule: best suggested pattern(s) implemented
            bucket = good_hint
        elif any(p in solution_patterns for p in hint.add_alternatives):
            # true rule: some alternative pattern implemented
            bucket = medium_hint
        else:
            bucket = bad_hint
        bucket.append(s)

# Summary of the hint evaluation. The averages are guarded because
# statistics.mean raises StatisticsError when no trace qualified.
print('Statistics')
print('----------')
traces = list(submissions.values())
if traces:
    print('avg. submissions per trace:', mean(len(subs) for subs in traces))
    # clause count approximated by the number of '.' characters in the solution
    print('avg. clauses in solution:', mean(subs[-1].program.count('.') for subs in traces))
else:
    print('avg. submissions per trace:', 'n/a')
    print('avg. clauses in solution:', 'n/a')
print('total submissions:', n_subs)
print('positive hints (best implemented):', len([s for s in good_hint if s.hint.add]))
print('positive hints (alternative implemented):', len([s for s in medium_hint if s.hint.add_alternatives]))
print('positive hints (not implemented):', len([s for s in bad_hint if s.hint.add]))
print('buggy hints (implemented):', len([s for s in good_hint if s.hint.remove]))
print('buggy hints (not implemented):', len([s for s in bad_hint if s.hint.remove]))
print('no hints:', len(no_hint))