"""Learn Prolog-program rules for one predicate and compare them to baselines.

Usage: pass the predicate name as the single positional argument ``Name``.

Inputs (relative to the current working directory):
    data/<Name>/programs-train   training table
    data/<Name>/programs-test    test table

Outputs:
    data/<Name>/model.txt    learned rules in the order the classifier holds them
    data/<Name>/scores.txt   CA / AUC / LogLoss for the rule learner and baselines
    data/<Name>/rules.txt    learned rules sorted by descending quality

Provenance (taken from the patch header this script was extracted from):
    commit  651e2be4480b19ac486cb8a4dd2fb08b448ebc67
    author  Martin Možina
    date    Tue, 17 Jan 2017 20:12:29 +0100
    subject Added scripts for learning rules.
    file    abml/evaluate.py (new file, 75 insertions)
"""
import pickle
import argparse
import Orange
from Orange.evaluation import TestOnTestData, CA, AUC, LogLoss
import abml.rules_prolog as rp

# NOTE(review): the three imports below (and ``pickle`` above) are not
# referenced in this script; they are kept because removing file-level
# imports could drop import-time side effects -- confirm before deleting.
import orangecontrib.evcrules.logistic as logistic
import orangecontrib.abml.abrules as rules
import orangecontrib.abml.argumentation as arg


def _accuracy(model, table):
    """Return the fraction of rows in *table* for which *model* predicts the class.

    Returns NaN for an empty table instead of raising ZeroDivisionError.
    """
    if len(table) == 0:
        return float("nan")
    predictions = model(table)
    hits = sum(p == y for p, y in zip(predictions, table.Y))
    return hits / len(table)


def _dump_rules(path, rule_list, template):
    """Print every rule and write it to *path*, one formatted line per rule.

    *template* receives three positional fields: the rule, its
    ``curr_class_dist`` and its ``quality``.
    """
    # The context manager guarantees the file is flushed and closed even if
    # formatting one of the rules raises.
    with open(path, "wt") as out:
        for rule in rule_list:
            print(rule, rule.curr_class_dist, rule.quality)
            out.write(template.format(rule, rule.curr_class_dist, rule.quality))


# NOTE: the pipeline intentionally stays at module level (no ``main()`` guard)
# so that running or importing this script behaves exactly as before.
parser = argparse.ArgumentParser(description='Learn and test rules for prolog programs.')
parser.add_argument('Name', type=str, help='Predicate name.')
args = parser.parse_args()
name = args.Name

# load data
data = Orange.data.Table('data/{}/programs-train'.format(name))

# create learner
# NOTE(review): the meaning of the 0.9 argument is defined by
# rp.Rules4Prolog and is not visible here -- confirm before changing it.
rule_learner = rp.Rules4Prolog(name, 0.9)

# learn a classifier
classifier = rule_learner(data)

# save model (rules in the classifier's own order)
_dump_rules("data/{}/model.txt".format(name), classifier.rule_list,
            "{} dist={} quality={}\n")

# accuracy of model
testdata = Orange.data.Table('data/{}/programs-test'.format(name))
print("Accuracy on test data: ", _accuracy(classifier, testdata))
print("Accuracy on train data: ", _accuracy(classifier, data))

# test model + other methods
# Labels and learners are kept in one list so they cannot get out of step,
# and no local name shadows the imported ``logistic`` module.
named_learners = [
    ('logrules', rule_learner),
    ('logistic', Orange.classification.LogisticRegressionLearner()),
    ('naive-bayes', Orange.classification.NaiveBayesLearner()),
    ('cn2', Orange.classification.rules.CN2UnorderedLearner()),
    ('tree', Orange.classification.TreeLearner()),
    ('random-forest', Orange.classification.RandomForestLearner()),
    ('svm', Orange.classification.SVMLearner()),
]
learners = [learner for _, learner in named_learners]
res = TestOnTestData(data, testdata, learners)
ca = CA(res)
auc = AUC(res)
ll = LogLoss(res)

score_lines = ["CA\tAUC\tLogLoss\tMethod\n"]
score_lines.extend(
    "{}\t{}\t{}\t{}\n".format(ca[idx], auc[idx], ll[idx], label)
    for idx, (label, _) in enumerate(named_learners)
)
scores = "".join(score_lines)
print(scores)
with open("data/{}/scores.txt".format(name), "wt") as fscores:
    fscores.write(scores)

# Rules ordered by quality, best first.  ``sorted`` builds a new list so the
# trained classifier's own rule order is left untouched.
all_rules = sorted(classifier.rule_list, key=lambda r: r.quality, reverse=True)
_dump_rules("data/{}/rules.txt".format(name), all_rules, "{} {} {}\n")