summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--get-attributes.py97
1 files changed, 97 insertions, 0 deletions
diff --git a/get-attributes.py b/get-attributes.py
new file mode 100644
index 0000000..78258a3
--- /dev/null
+++ b/get-attributes.py
@@ -0,0 +1,97 @@
+import argparse
+import collections
+import os
+import random
+
+from monkey.patterns import get_patterns
+from prolog.util import parse as prolog_parse
+
+parser = argparse.ArgumentParser(description='Get patterns from student programs.')
+parser.add_argument('path', help='path to data directory')
+args = parser.parse_args()
+path = args.path.rstrip('/')
+
+name = os.path.basename(path)
+submissions = os.path.join(path, 'submissions')
+
+# select test/train users
+users = sorted([int(uid) for uid in os.listdir(submissions)])
+random.Random(0).shuffle(users)
+split = int(len(users)*0.7)
+learn_users = set(users[:split])
+test_users = set(users[split:])
+
+# save test users to file
+with open(path + '/users-test.txt', 'wt') as f:
+ for user in test_users:
+ print(user, file=f)
+
+# find test/train programs
+data = {
+ 'train': [],
+ 'test': []
+}
+for user in users:
+ user_dir = os.path.join(submissions, str(user))
+ user_subs = set()
+
+ # each submission is in a file named <seq. no>-<total tests>-<passed tests>
+ for submission in os.listdir(user_dir):
+ with open(os.path.join(user_dir, submission), 'r') as f:
+ code = f.read()
+
+ if code in user_subs: # do not add a program twice for the same user
+ continue
+ user_subs.add(code)
+
+ if prolog_parse(code) is None: # skip syntactically incorrect programs
+ continue
+ if name not in code: # only add programs with defined predicate
+ continue
+
+ seq, total, passed = submission.split('-')
+ data['test' if user in test_users else 'train'].append((code, total == passed))
+
+# print info about test users and test/train programs
+print('Test users:')
+print(test_users)
+print()
+for which in ['train', 'test']:
+ print('Programs ({}):'.format(which))
+ print('correct: {} ({} unique)'.format(
+ len([code for code, correct in data[which] if correct]),
+ len({code for code, correct in data[which] if correct})))
+ print('incorrect: {} ({} unique)'.format(
+ len([code for code, correct in data[which] if not correct]),
+ len({code for code, correct in data[which] if not correct})))
+ print()
+
+# extract attributes from training data
+patterns = collections.Counter()
+for code, correct in data['train']:
+ for pat, nodes in get_patterns(code):
+ patterns[pat] += 1
+
+attrs = []
+with open(path + '/attributes.tab', 'w') as pattern_file:
+ for i, (pat, count) in enumerate(patterns.most_common()):
+ if count < 5:
+ break
+ attrs.append(pat)
+ print('a{}\t{}'.format(i, pat), file=pattern_file)
+
+# check and write attributes for training/test data
+for t in ['train', 'test']:
+ with open(path + '/programs-{}.tab'.format(t), 'w') as f:
+ # print header
+ print('\t'.join(['code', 'correct'] + ['a'+str(i) for i in range(len(attrs))]), file=f)
+ print('\t'.join(['d'] * (len(attrs)+2)), file=f)
+ print('meta\tclass', file=f)
+
+ # print rows (program, correct, attr1, attr2, …)
+ for code, correct in data[t]:
+ record = '{}\t{}'.format(repr(code), 'T' if correct else 'F')
+ code_pats = [pat for pat, nodes in get_patterns(code)]
+ for pat in attrs:
+ record += '\t{}'.format('T' if pat in code_pats else 'F')
+ print(record, file=f)