summaryrefslogtreecommitdiff
path: root/get-attributes.py
blob: 78258a399131222f6d4a93c9b6238302bd7d8769 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import argparse
import collections
import os
import random

from monkey.patterns import get_patterns
from prolog.util import parse as prolog_parse

# Command line: one positional argument, the problem's data directory.
parser = argparse.ArgumentParser(description='Get patterns from student programs.')
parser.add_argument('path', help='path to data directory')
args = parser.parse_args()

# Trailing slashes are removed so that basename() yields the directory name;
# that name is later searched for in every program.
path = args.path.rstrip('/')
name = os.path.basename(path)
submissions = os.path.join(path, 'submissions')

# Deterministic 70/30 train/test split of users. User IDs are the integer
# names of the entries under <path>/submissions. They are sorted first so the
# seed-0 shuffle gives the same split regardless of directory listing order.
users = sorted(map(int, os.listdir(submissions)))
random.Random(0).shuffle(users)
split = int(0.7 * len(users))
learn_users, test_users = set(users[:split]), set(users[split:])

# Save the test users to <path>/users-test.txt, one ID per line.
# IDs are written in ascending order so the file does not depend on set
# iteration order. os.path.join matches how the submissions directory is
# located from the same `path`.
with open(os.path.join(path, 'users-test.txt'), 'w') as f:
    for user in sorted(test_users):
        print(user, file=f)

# find test/train programs
# data['train'] / data['test'] collect (code, correct) pairs. `correct` is True
# when the submission's file name records that all of its tests passed.
data = {
    'train': [],
    'test': []
}
for user in users:
    user_dir = os.path.join(submissions, str(user))
    # Every submission of a user goes to the same split, so choose it once.
    programs = data['test' if user in test_users else 'train']
    user_subs = set()  # programs already seen for this user

    # Each submission is in a file named <seq. no>-<total tests>-<passed tests>.
    # The numbers are compared as integers, so e.g. "05" and "5" are equal.
    # Files are visited in sequence order, so the result (row order, and which
    # copy of a duplicated program is kept) does not depend on the arbitrary
    # order returned by os.listdir.
    entries = []
    for submission in os.listdir(user_dir):
        seq, total, passed = (int(part) for part in submission.split('-'))
        entries.append((seq, submission, total == passed))
    entries.sort()

    for seq, submission, correct in entries:
        # Explicit encoding: student programs may contain non-ASCII text, and
        # the locale default varies between machines.
        with open(os.path.join(user_dir, submission), 'r', encoding='utf-8') as f:
            code = f.read()

        if code in user_subs:  # do not add a program twice for the same user
            continue
        user_subs.add(code)

        if prolog_parse(code) is None:  # skip syntactically incorrect programs
            continue
        if name not in code:  # only add programs with defined predicate
            continue

        programs.append((code, correct))

# Report the chosen test users, then, for each split, how many correct and
# incorrect programs were collected (total and distinct).
print('Test users:')
print(test_users)
print()
for which in ('train', 'test'):
    print('Programs ({}):'.format(which))
    for template, want in (('correct: {} ({} unique)', True),
                           ('incorrect: {} ({} unique)', False)):
        matching = [code for code, correct in data[which] if bool(correct) == want]
        print(template.format(len(matching), len(set(matching))))
    print()

# Extract attributes from the training data.
# Each (pattern, nodes) pair that get_patterns() yields for a training program
# adds one to that pattern's count. A pattern yielded several times for the
# same program is therefore counted several times.
patterns = collections.Counter(
    pat
    for code, correct in data['train']
    for pat, nodes in get_patterns(code)
)

# Patterns with a count of at least 5 become attributes a0, a1, ... (most
# frequent first). The index -> pattern mapping goes to attributes.tab.
attrs = []
with open(path + '/attributes.tab', 'w') as pattern_file:
    for pat, count in patterns.most_common():
        if count < 5:
            break  # most_common() is sorted by count, so every later one is rarer
        print('a{}\t{}'.format(len(attrs), pat), file=pattern_file)
        attrs.append(pat)

# Check and write attributes for the training/test data to
# <path>/programs-{train,test}.tab. The three header rows (names, a 'd' type per
# column, then 'meta'/'class' flags for the first two columns) look like
# Orange's .tab format. NOTE(review): confirm the intended consumer.
for t in ['train', 'test']:
    # Explicit encoding: rows contain repr(code), which may include non-ASCII text.
    with open(os.path.join(path, 'programs-{}.tab'.format(t)), 'w', encoding='utf-8') as f:
        # print header
        print('\t'.join(['code', 'correct'] + ['a'+str(i) for i in range(len(attrs))]), file=f)
        print('\t'.join(['d'] * (len(attrs)+2)), file=f)
        print('meta\tclass', file=f)

        # Print one row per program: repr(code) (newlines and tabs escaped, so it
        # stays one field), T/F for correctness, then T/F for each attribute.
        for code, correct in data[t]:
            # A set gives O(1) membership tests in the per-attribute loop.
            code_pats = {pat for pat, nodes in get_patterns(code)}
            fields = [repr(code), 'T' if correct else 'F']
            fields.extend('T' if pat in code_pats else 'F' for pat in attrs)
            print('\t'.join(fields), file=f)