diff options
Diffstat (limited to 'monkey/patterns.py')
-rw-r--r-- | monkey/patterns.py | 511 |
1 files changed, 511 insertions, 0 deletions
# CodeQ: an online programming tutor.
# Copyright (C) 2015 UL FRI
#
# This program is free software: you can redistribute it and/or modify it under
# the terms of the GNU Affero General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
# details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Build pattern strings from parsed Prolog programs.

The functions in this module walk a parse tree and produce strings that
describe where selected variable/literal nodes occur.  When run as a script,
the module writes the patterns found in a set of submissions to tab-separated
attribute files.
"""

import collections
from itertools import combinations
import pickle
import random
import re
import sys

from nltk import ParentedTree, Tree
from nltk.tgrep import tgrep_positions

from prolog.util import Token, parse as prolog_parse

# Regexes used by postprocess() to find quoted variable names.
# NOTE(review): the first search requires exactly one capital letter before the
# underscore, later searches accept zero or more. This mirrors the original
# behaviour; confirm which of the two is actually intended.
_FIRST_VAR_RE = re.compile(r'"[A-Z]_[0-9]+"')
_NEXT_VAR_RE = re.compile(r'"([A-Z]*_[0-9]+)"')

# Prefixes that open a top-level element in postprocess_simple().
_ELEMENT_PREFIXES = ('(head', '(compound', '(binop')

# Parts of an element that postprocess_simple() keeps.
_RELEVANT_RE = re.compile(
    r'(head|\(functor.*?\)|\(binop.*?".*?"|args|variable|literal)')


def pattern(node, include):
    """Construct a pattern matching the structure of nodes given by [include].

    [node] is the (sub)tree to describe and [include] is a collection of
    variable nodes (compared by identity) that must appear in the result.
    Returns the pattern string, or None if nothing under [node] is relevant.
    The operators used ('<1', '<:', '..', '<<') are presumably nltk.tgrep
    syntax, since find_motif() passes such strings to tgrep_positions.
    """
    if isinstance(node, Token):
        if node.type == 'NAME':
            return '"{}"'.format(node.val)
        return None

    if any(n is node for n in include):
        if node.label() == 'variable':
            return 'variable <: "{}"'.format(node[0].val)
        return None

    label = node.label()
    subpats = [pattern(child, include) for child in node]

    if label == 'functor':
        return '{} <1 ({})'.format(label, subpats[0])

    pat = None
    if any(subpats):
        if label == 'and':
            if subpats[1]:
                pat = subpats[1]
            if subpats[0]:
                if pat:
                    pat = subpats[0] + ' .. (' + pat + ')'
                else:
                    pat = subpats[0]
        elif label == 'args':
            pat = label
            for i, subpat in enumerate(subpats):
                if subpat:
                    pat += ' <{} ({})'.format(i+1, subpat)
        elif label == 'binop':
            pat = label
            if subpats[0]:
                pat += ' <1 ({})'.format(subpats[0])
            pat += ' <2 ("{}")'.format(node[1].val)
            if subpats[2]:
                pat += ' <3 ({})'.format(subpats[2])
        elif label == 'clause':
            pat = label
            for i, subpat in enumerate(subpats):
                if subpat:
                    pat += ' {} ({})'.format('<1' if i == 0 else '<<', subpat)
        elif label == 'compound':
            # only keep a compound if something was selected in its arguments
            if len(subpats) > 1 and subpats[1]:
                pat = label
                for i, subpat in enumerate(subpats):
                    pat += ' <{} ({})'.format(i+1, subpat)
            else:
                return None
        elif label == 'head':
            pat = label
            pat += ' <1 ({})'.format(subpats[0])
        if not pat:
            # no dedicated rule for this label: pass the first child pattern up
            for s in subpats:
                if s:
                    pat = s
                    break
    return pat


def _ancestors(node):
    """Return the ancestors of [node], ordered from the root downwards."""
    path = []
    parent = node.parent()
    while parent:
        path.append(parent)
        parent = parent.parent()
    path.reverse()
    return path


def _node_label(node):
    """Return the pattern text describing a single node on a path."""
    label = node.label()
    if label == 'compound':
        # count arguments by following the right-nested 'args' nodes
        right = node[1]
        nargs = 1
        while right.label() == 'args' and len(right) == 2:
            right = right[1]
            nargs += 1
        return 'compound <1 (functor <: "{}{}")'.format(node[0][0].val, nargs)
    if label == 'binop':
        return 'binop <2 "{}"'.format(node[1].val)
    return label


def _edge_operator(node):
    """Return the operator linking [node] to the path element above it."""
    if node.parent().label() == 'and':
        return '<+(and)'
    index = node.parent_index()
    if index is not None:
        return '<{}'.format(index+1)
    return '<'


def _descend(path):
    """Render [path] as nested pattern text.

    'and' nodes are skipped.  Returns the text and the number of parentheses
    it leaves open.
    """
    text = ''
    depth = 0
    for node in path:
        if node.label() == 'and':
            continue
        text += ' {} ({}'.format(_edge_operator(node), _node_label(node))
        depth += 1
    return text, depth


def connect(a, b):
    """Construct a pattern connecting (variable) node [a] to [b].

    The pattern starts at the 'clause' ancestor of [a], follows the part of
    the path shared by both nodes and then branches to [a] and to [b].  The
    clause position is looked up on the path of [a] only, so both nodes are
    assumed to be in the same clause.  Raises IndexError if [a] has no
    'clause' ancestor.
    """
    path_a = _ancestors(a)
    path_b = _ancestors(b)

    i = 0
    while path_a[i].label() != 'clause':
        i += 1

    # end of the part of the path shared by both nodes
    split = i + 1
    shortest = min(len(path_a), len(path_b))
    while split < shortest and path_a[split] is path_b[split]:
        split += 1

    # path from top to common ancestor
    text, n_top = _descend(path_a[i+1:split])
    pat = path_a[i].label() + text

    # path from common ancestor to [a]
    text, n_a = _descend(path_a[split:])
    pat += text
    pat += ' <{} ({} <: "{}")'.format(a.parent_index()+1, a.label(), a[0].val)
    pat += ')' * n_a

    # path from common ancestor to [b]
    text, n_b = _descend(path_b[split:])
    pat += text
    pat += ' <{} ({} <: "{}")'.format(b.parent_index()+1, b.label(), b[0].val)

    pat += ')' * (n_top + n_b)
    return pat


def postprocess(pattern):
    """Replace variable names with patterns and backlinks.

    The first occurrence of each quoted variable name becomes '@VAR=<id>' and
    every later occurrence becomes '~<id>'.  The result is prefixed with the
    macro definition '@ VAR /[A-Z]/; '.
    """
    macros = '@ VAR /[A-Z]/; '
    #m = re.search(r'"[A-Z][^"]*_[0-9]+[^"]*"', pattern)
    m = _FIRST_VAR_RE.search(pattern)
    nvars = 0
    while m is not None:
        orig_name = m.group(0)
        n = orig_name.strip('"').split("_")[1]
        pat_name = 'v{}'.format(nvars)
        nvars += 1
        # [m] is the leftmost occurrence of this name, so the remaining ones
        # are all in the tail. Replacing them in one pass avoids applying
        # match offsets to a string whose length has already changed.
        tail = pattern[m.end():].replace(
            orig_name, '~{}{}'.format(pat_name, n))
        pattern = pattern[:m.start()] + '@VAR={}{}'.format(pat_name, n) + tail
        m = _NEXT_VAR_RE.search(pattern)
    return macros + pattern


def postprocess_simple(pattern):
    """Reduce a pattern to a sorted list of simplified elements.

    Runs postprocess() first.  Returns None for patterns whose clause body
    starts with 'or'; otherwise returns the simplified elements (the
    parenthesised groups starting with head, compound or binop) joined by
    spaces.
    """
    pattern = postprocess(pattern)
    if pattern.startswith("@ VAR /[A-Z]/; clause <2 (or"):
        return None

    # elements
    elements = []
    current = ""
    par = 0  # nesting depth inside the element being collected
    for i, char in enumerate(pattern):
        if par > 0 and char == ")":
            par -= 1
            if par == 0:
                elements.append(current)
                current = ""
        if par > 0 and char == "(":
            par += 1
        if par == 0 and pattern.startswith(_ELEMENT_PREFIXES, i):
            par = 1
        if par > 0:
            current += char

    # simplify variable
    elements = ["(" + " ".join(_RELEVANT_RE.findall(e)) + ")"
                for e in elements]
    return " ".join(sorted(elements))


def pattern2(node, include):
    """Construct a pattern matching the structure of nodes given by [include].

    Supports variables and literals.  [include] is a collection of nodes
    (compared by identity).  Returns a parenthesised pattern string, or None
    if nothing under [node] is relevant.
    """
    if isinstance(node, Token):
        return None

    label = node.label()
    if any(n is node for n in include):
        if label == 'literal':
            return '"{}"'.format(node[0].val)
        if label == 'variable':
            return '{}'.format(label)
        return None
    if label == 'functor':
        return '({} "{}")'.format(label, node[0].val)

    subpats = [pattern2(child, include) for child in node]
    pat = None
    if any(subpats):
        if label == 'and':
            if subpats[1]:
                pat = subpats[1]
            if subpats[0]:
                if pat:
                    pat = subpats[0] + ' ' + pat
                else:
                    pat = subpats[0]
        elif label == 'args':
            pat = label
            for subpat in subpats:
                if subpat:
                    pat += ' {}'.format(subpat)
            pat = '(' + pat + ')'
        elif label == 'unop':
            pat = '(' + label + ' ' + node[0].val + ' ' + subpats[1] + ')'
        elif label == 'binop':
            pat = label
            if subpats[0]:
                pat += ' {}'.format(subpats[0])
            pat += ' "{}"'.format(node[1].val)
            if subpats[2]:
                pat += ' {}'.format(subpats[2])
            pat = '(' + pat + ')'
        elif label == 'clause':
            pat = label
            for subpat in subpats:
                if subpat:
                    pat += ' {}'.format(subpat)
            return '(' + pat + ')'
        elif label == 'compound':
            # only keep a compound if something was selected in its arguments
            if len(subpats) > 1 and subpats[1]:
                pat = label
                for subpat in subpats:
                    pat += ' {}'.format(subpat)
                pat = '(' + pat + ')'
            else:
                return None
        elif label == 'head':
            pat = label
            pat += ' {}'.format(subpats[0])
            pat = '(' + pat + ')'
        elif label == 'list':
            pat = 'list '
            if subpats[0]:
                pat += '(h {})'.format(subpats[0])
            if subpats[0] and subpats[1]:
                pat += ' '
            if subpats[1]:
                pat += '(t {})'.format(subpats[1])
            pat = '(' + pat + ')'
        if not pat:
            # no dedicated rule for this label: pass the first child pattern up
            for s in subpats:
                if s:
                    pat = s
                    break
    return pat


def get_patterns(tree):
    """Yield (pattern, nodes) pairs for every clause in [tree].

    [tree] is either program source (a string, parsed with prolog_parse) or
    an already parsed tree.  Nothing is yielded if the program cannot be
    parsed.  For each clause, patterns are produced for pairs of occurrences
    of the same variable, for variables occurring only once, and for
    literal/variable pairs.
    """
    if isinstance(tree, str):
        tree = prolog_parse(tree)
    if tree is None:
        return
    tree = ParentedTree.convert(tree)

    # get patterns separately for each clause
    for clause in tree:
        # only used by the disabled strategies kept in the comments below
        all_variables = []
        variables = collections.defaultdict(list)

        def walk(node):
            if isinstance(node, Tree):
                if node.label() == 'variable':
                    name = node[0].val
                    variables[name].append(node)
                    all_variables.append(node)
                else:
                    for child in node:
                        walk(child)
        walk(clause)

        # connecting pairs of nodes with same variable
        for var, nodes in variables.items():
            for selected in combinations(nodes, 2):
                pat = pattern2(clause, selected)
                if pat:
                    yield pat, selected
                #pat = connect(*selected)
                #if pat:
                #    pp = postprocess_simple(pat)
                #    if pp:
                #        yield pp, selected

        # add singletons
        for var, nodes in variables.items():
            if len(nodes) == 1:
                pat = pattern2(clause, nodes)
                if pat:
                    yield pat, nodes
                #pat = pattern(tree, var)
                #if pat:
                #    pp = postprocess_simple(pat)
                #    if pp:
                #        yield pp, nodes

        # add patterns for literal/variable pairs
        literals = [node for node in clause.subtrees()
                    if node.label() == 'literal']
        for literal in literals:
            # climb to the outermost enclosing term, but never past the clause
            top = literal
            while (top is not clause and top.parent().label() in
                    {'compound', 'binop', 'unop', 'args', 'list'}):
                top = top.parent()
            nearby_vars = [node for node in top.subtrees()
                           if node.label() == 'variable']
            for var in nearby_vars:
                pat = pattern2(clause, [var, literal])
                if pat:
                    yield pat, [var, literal]

#        # connecting pairs of nodes with variables
#        for selected in combinations(all_variables, 2):
#            pat = connect(selected[0], selected[1])
#            if pat:
#                yield postprocess(pat), selected

#        # using pairs of nodes with variables
#        for selected in combinations(all_variables, 2):
#            pat = pattern(tree, selected)
#            if pat:
#                yield postprocess(pat), selected

#        # using pairs of nodes with same variable
#        for var, nodes in variables.items():
#            for selected in combinations(nodes, 2):
#                pat = pattern(tree, selected)
#                if pat:
#                    yield postprocess(pat), selected

#        # using each variable separately
#        for var, nodes in variables.items():
#            pat = pattern(tree, nodes)
#            if pat:
#                yield postprocess(pat), nodes

#        # using each goal to select variables FIXME
#        goals = [s for s in tree.subtrees() if s.label() in {'compound', 'binop'}]
#        for goal in goals:
#            goal_vars = {n.val for n in goal.leaves() if n.type == 'VARIABLE'}
#            pat = pattern(tree, goal_vars)
#            if pat:
#                yield postprocess(pat)


def _tgrep_prepare(tree):
    """Return a copy of [tree] in which every leaf is a childless tree.

    nltk.tgrep does not play nice with non-Tree leaves, so each Token or
    string leaf is replaced by a ParentedTree labelled with its str() value.
    """
    if not isinstance(tree, ParentedTree):
        tree = ParentedTree.convert(tree)

    def prepare(node):
        if isinstance(node, (Token, str)):
            return ParentedTree(str(node), [])
        return ParentedTree(node.label(), [prepare(child) for child in node])
    return prepare(tree)


def find_motif(tree, motif):
    """Return the result of tgrep_positions for pattern [motif] in [tree]."""
    tree = _tgrep_prepare(tree)
    return tgrep_positions(motif, [tree])


# Extract patterns from existing submissions for a problem and write them out
# as attributes.
# Run with: python3 -m monkey.patterns <problem ID> <name>
# Only programs whose source contains <name> are used. The submissions are
# read from pickle/programs-<problem ID>.pickle.
if __name__ == '__main__':
    # Ignore traces from these users.
    # NOTE(review): this list is not applied anywhere below, so no user is
    # actually ignored. Confirm whether the filter should be added.
    ignored_users = [
        1,    # admin
        231,  # test
        360,  # test2
        358,  # sasha
    ]

    pid = int(sys.argv[1])
    name = sys.argv[2]
    # NOTE: unpickling can run arbitrary code; only load trusted files here.
    with open('pickle/programs-{}.pickle'.format(pid), 'rb') as pickle_file:
        submissions = pickle.load(pickle_file)

    print('Analyzing programs for problem {}…'.format(pid))
    data = {
        'train': [],
        'test': []
    }

    # get all users and split them 70/30 (fixed seed, so the split is stable)
    users = set()
    for info in submissions.values():
        users |= info['users']
    users = sorted(users)
    random.Random(0).shuffle(users)
    split = int(len(users)*0.7)
    learn_users = set(users[:split])
    test_users = set(users[split:])
    print(test_users)

    # each program is counted once for every user that submitted it
    for code, info in submissions.items():
        if len(code) > 1000 or prolog_parse(code) is None:
            continue
        if name not in code:
            continue
        sample = (code, info['n_tests'] == info['n_passed'])
        data['train'] += [sample] * len(info['users'] & learn_users)
        data['test'] += [sample] * len(info['users'] & test_users)

    print('correct: {} ({} unique)'.format(
        len([code for code, correct in data['train'] if correct]),
        len({code for code, correct in data['train'] if correct})))
    print('incorrect: {} ({} unique)'.format(
        len([code for code, correct in data['train'] if not correct]),
        len({code for code, correct in data['train'] if not correct})))

    # The same program appears once per user, so remember its patterns
    # instead of parsing it again for every copy.
    pattern_cache = {}

    def code_patterns(code):
        if code not in pattern_cache:
            pattern_cache[code] = [pat for pat, nodes in get_patterns(code)]
        return pattern_cache[code]

    # extract attributes from training data
    patterns = collections.Counter()
    for code, correct in data['train']:
        patterns.update(code_patterns(code))

    attrs = []
    with open('data/attributes-{:03}.tab'.format(pid), 'w') as pattern_file:
        for i, (pat, count) in enumerate(patterns.most_common()):
            if count < 1:
                break
            attrs.append(pat)
            print('a{}\t{}'.format(i, pat), file=pattern_file)

    # check and write attributes for training/test data
    for t in ['train', 'test']:
        with open('data/programs-{:03}-{}.tab'.format(pid, t), 'w') as f:
            # print header
            print('\t'.join(['code', 'correct'] +
                            ['a'+str(i) for i in range(len(attrs))]), file=f)
            print('\t'.join(['d'] * (len(attrs)+2)), file=f)
            print('meta\tclass', file=f)
            for code, correct in data[t]:
                fields = [repr(code), 'T' if correct else 'F']

                ## check attributes by using tgrep to find patterns
                #tree = _tgrep_prepare(prolog_parse(code))
                #for pat in attrs:
                #    matches = list(tgrep_positions(pat, [tree]))
                #    record += '\t{}'.format(bool(matches[0]))

                # check attributes by enumerating patterns in code
                code_pats = set(code_patterns(code))
                fields += ['T' if pat in code_pats else 'F' for pat in attrs]

                print('\t'.join(fields), file=f)