summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- main.py 19
1 files changed, 13 insertions, 6 deletions
diff --git a/main.py b/main.py
index 610901f..777f770 100644
--- a/main.py
+++ b/main.py
@@ -33,6 +33,7 @@ def get_programs(path: str, names: str, do_canonicalize: bool = False):
continue
# canonicalize
+ original = code
if do_canonicalize:
code = canonicalize.canonicalize(code, given_names=names)
@@ -43,7 +44,7 @@ def get_programs(path: str, names: str, do_canonicalize: bool = False):
seq, total, passed = submission.split('-')
if code not in programs:
- programs[code] = {'users': set(), 'correct': total == passed}
+ programs[code] = {'users': set(), 'correct': total == passed, 'original': original}
programs[code]['users'].add(user)
return programs
@@ -64,21 +65,27 @@ if __name__ == '__main__':
attrs = collections.OrderedDict()
attrs.update(regex.get_attributes(programs))
attrs.update(dynamic.get_attributes(programs, args.exec, args.inputs))
- print('Attributes:', attrs.keys())
+
+ print('Attributes:')
+ for attr in attrs:
+ print(attr, attrs[attr]['desc'].to_string(inline=True))
for program in programs:
for attr in attrs:
programs[program][attr] = program in attrs[attr]['programs']
data = pandas.DataFrame.from_dict(programs, orient='index')
- y = data['correct']
- X = data.drop(['users', 'correct'], axis='columns')
+
+ train = data.sample(frac=0.7, random_state=0)
+ Y = train['correct']
+ X = train.drop(['users', 'correct', 'original'], axis='columns')
+ X_train, X_test, Y_train, Y_test = sklearn.model_selection.train_test_split(X, Y, test_size=0.33, random_state=0)
learners = collections.OrderedDict([
('major', sklearn.dummy.DummyClassifier()),
('tree', sklearn.tree.DecisionTreeClassifier()),
- ('rf', sklearn.ensemble.RandomForestClassifier()),
+ ('rf', sklearn.ensemble.RandomForestClassifier(n_estimators=100)),
])
for name, learner in learners.items():
- scores = sklearn.model_selection.cross_val_score(learner, X, y, cv=10)
+ scores = sklearn.model_selection.cross_val_score(learner, X, Y, cv=10)
print('{}:\t{}'.format(name, scores.mean()))