summaryrefslogtreecommitdiff
path: root/monkey
diff options
context:
space:
mode:
authorTimotej Lazar <timotej.lazar@araneo.org>2015-01-15 12:10:22 +0100
committerAleš Smodiš <aless@guru.si>2015-08-11 14:26:01 +0200
commit819ab10281c9bd6c000364c3a243959edd18abf7 (patch)
tree5ca3452418b49781563221bb56cf70e1d0fb1bb8 /monkey
parentd86793039957aa408a98806aecfb5964bda5fb87 (diff)
Move pymonkey stuff to monkey/
Importing pymonkey into webmonkey, let's see how this works.
Diffstat (limited to 'monkey')
-rwxr-xr-xmonkey/action.py284
-rw-r--r--monkey/db.py44
-rw-r--r--monkey/edits.py307
-rw-r--r--monkey/graph.py67
-rwxr-xr-xmonkey/monkey.py313
-rw-r--r--monkey/prolog/engine.py135
-rw-r--r--monkey/prolog/lexer.py90
-rw-r--r--monkey/prolog/util.py156
-rw-r--r--monkey/util.py81
9 files changed, 1477 insertions, 0 deletions
diff --git a/monkey/action.py b/monkey/action.py
new file mode 100755
index 0000000..057ed0e
--- /dev/null
+++ b/monkey/action.py
@@ -0,0 +1,284 @@
+#!/usr/bin/python3
+
+class Action:
+ # type ∈ ['remove', 'insert', 'solve', 'solve_all']
+ # time: absolute elapsed time since the attempt started, in ms
+ # offset: position of the first inserted/removed character
+ # text: inserted/removed text or query
+ # total, passed: number of test cases
+ def __init__(self, type, time, offset=0, text='', total=0, passed=0):
+ self.type = type
+ self.time = time
+ if type == 'insert' or type == 'remove':
+ self.offset = offset
+ self.length = len(text)
+ self.text = text
+ elif type == 'solve' or type == 'solve_all':
+ self.query = text
+ elif type == 'test':
+ self.total = total
+ self.passed = passed
+
+ def __str__(self):
+ s = 't = ' + str(self.time/1000.0) + ' ' + self.type
+ if self.type in ('insert', 'remove'):
+ s += ' "' + self.text.replace('\n', '\\n').replace('\t', '\\t') + '" at ' + str(self.offset)
+ elif self.type in ('solve', 'solve_all'):
+ s += ' "' + self.query + '"'
+ elif self.type == 'test':
+ s += ' {0} / {1}'.format(self.passed, self.total)
+ return s
+
+ # apply this action to text
+ def apply(self, text):
+ if self.type == 'insert':
+ return text[:self.offset] + self.text + text[self.offset:]
+ elif self.type == 'remove':
+ return text[:self.offset] + text[self.offset+self.length:]
+ else:
+ return text
+
+ # reverse the application of this action
+ def unapply(self, text):
+ if self.type == 'insert':
+ return text[:self.offset] + text[self.offset+self.length:]
+ elif self.type == 'remove':
+ return text[:self.offset] + self.text + text[self.offset:]
+ else:
+ return text
+
+# parse log from database into a list of actions, cleaning up some fluff.
+# ignore non-text actions (queries and tests)
+def parse(data):
+ if data == None:
+ return [], []
+
+ actions = []
+ incorrect = set()
+
+ time = 0
+ code = ''
+ i = 0
+ while i < len(data):
+ # parse one action
+ type = data[i]
+ i += 1
+ dt = int(((data[i] << 8) + (data[i+1])) * 100.0)
+ time += dt
+ i += 2
+ if type == 1: # insert
+ offset = (data[i] << 8) + data[i+1]
+ i += 2
+ length = (data[i] << 8) + data[i+1]
+ i += 2
+ text = data[i:i+length].decode()
+ i += length
+ action = Action('insert', time, offset=offset, text=text)
+ elif type == 2: # remove
+ offset = (data[i] << 8) + data[i+1]
+ i += 2
+ length = (data[i] << 8) + data[i+1]
+ i += 2
+ text = code[offset:offset+length]
+ action = Action('remove', time, offset=offset, text=text)
+ elif type == 3 or type == 4: # solve / solve all
+ length = (data[i] << 8) + data[i+1]
+ i += 2
+ query = data[i:i+length].decode()
+ i += length
+ act_type = 'solve' + ('_all' if type == 4 else '')
+ action = Action(act_type, time, text=query)
+ elif type == 8: # test
+ total = data[i]
+ i += 1
+ passed = data[i]
+ i += 1
+ action = Action('test', time, total=total, passed=passed)
+ else:
+ # unsupported action type
+ continue
+
+ # skip normalization if this is the first action
+ if actions == []:
+ actions.append(action)
+ code = action.apply(code)
+ continue
+
+ # add to list of actions; modify previously added action if necessary
+ prev = actions[-1]
+
+ # remove superfluous REMOVE action when newline is inserted (due to editor auto-indent)
+ if prev.type == 'remove' and action.type == 'insert' and \
+ action.time == prev.time and \
+ action.offset == prev.offset and action.length > prev.length and \
+ action.text[action.length-prev.length:] == prev.text:
+ # discard last REMOVE action
+ code = prev.unapply(code)
+ actions.pop()
+
+ # replace current action with something better
+ length = action.length - prev.length
+ new = Action('insert', prev.time, offset = prev.offset, text = action.text[:length])
+ actions.append(new)
+ code = new.apply(code)
+
+ # remove superfluous INSERT action when newline is removed (due to editor auto-indent)
+ elif prev.type == 'remove' and action.type == 'insert' and \
+ action.time == prev.time and \
+ action.offset == prev.offset and action.length < prev.length and \
+ prev.text[prev.length-action.length:] == action.text:
+ # discard last INSERT action
+ code = prev.unapply(code)
+ actions.pop()
+
+ # replace current action with something better
+ length = prev.length - action.length
+ new = Action('remove', prev.time, offset = prev.offset, text = prev.text[:length])
+ actions.append(new)
+ code = new.apply(code)
+
+ # discard INSERT/REMOVE pairs (typos)
+ elif prev.type == 'insert' and action.type == 'remove' and \
+ action.time - prev.time < 10000 and \
+ action.offset == prev.offset and action.text == prev.text:
+ # discard last and current actions
+ code = prev.unapply(code)
+ actions.pop()
+
+ # discard REMOVE/INSERT pairs (deleted char then typed back)
+ elif prev.type == 'remove' and action.type == 'insert' and \
+ action.offset == prev.offset and action.text == prev.text:
+ # discard last and current actions
+ code = prev.unapply(code)
+ actions.pop()
+
+ # otherwise, simply append the current action
+ else:
+ actions.append(action)
+ code = action.apply(code)
+
+ return actions
+
+# expand any multi-char actions (does not do anything for the 99%)
+def expand(actions):
+ i = 0
+ while i < len(actions):
+ if actions[i].type == 'insert' and len(actions[i].text) > 1:
+ a = actions.pop(i)
+ for offset in range(len(a.text)):
+ actions.insert(i+offset, Action('insert', a.time, a.offset+offset, a.text[offset]))
+ i += len(a.text)
+ elif actions[i].type == 'remove' and len(actions[i].text) > 1:
+ a = actions.pop(i)
+ for offset in range(len(a.text)):
+ actions.insert(i, Action('remove', a.time, a.offset+offset, a.text[offset]))
+ i += len(a.text)
+ else:
+ i += 1
+
+# each action in parse() result corresponds to single insertion/deletion.
+# this function merges related adjacent actions
+def compress(actions):
+ # first make each edit change exactly one character, for easier handling
+ expand(actions)
+
+ i = 0
+ while i < len(actions)-1:
+ a = actions[i]
+ b = actions[i+1]
+
+ # merge adjacent INSERT actions
+ # +a +b → +ab
+ if a.type == 'insert' and b.type == 'insert' and \
+ b.offset == a.offset + a.length: #and b.time - a.time < 10000:
+ a.text += b.text
+ a.length += b.length
+ del actions[i+1]
+
+ # merge adjacent REMOVE actions (two cases: backspace & delete)
+ # -b -a → -ab
+ elif a.type == 'remove' and b.type == 'remove' and \
+ b.offset == a.offset - b.length: #and b.time - a.time < 10000:
+ a.text = b.text + a.text
+ a.offset = b.offset
+ a.length += b.length
+ del actions[i+1]
+ # -a -b → -ab
+ elif a.type == 'remove' and b.type == 'remove' and \
+ b.offset == a.offset and b.time - a.time < 10000:
+ a.text += b.text
+ a.length += b.length
+ del actions[i+1]
+
+ # merge adjacent INSERT/REMOVE actions
+ # +ab -b → +a
+ elif a.type == 'insert' and b.type == 'remove' and \
+ b.offset >= a.offset and b.offset < a.offset + a.length and \
+ b.length == a.offset + a.length - b.offset and b.time - a.time < 10000:
+ del_start = b.offset - a.offset
+ del_end = del_start + b.length
+ a.text = a.text[:del_start] + a.text[del_end:]
+ a.length -= b.length
+ del actions[i+1]
+
+ else:
+ i += 1
+
+# some sample code
+if __name__ == '__main__':
+ import sys
+ if len(sys.argv) < 2:
+ print('usage: ' + sys.argv[0] + ' <database>')
+ sys.exit(1)
+
+ import sqlite3
+ conn = sqlite3.connect(sys.argv[1])
+ conn.text_factory = bytes
+ c = conn.cursor()
+
+ # print all problem ids
+ print('problems:')
+ c.execute('select * from problems')
+ for problem in c.fetchall():
+ # problem = (id, name, description, details, solution, library)
+ # name: predicate name + arity (e.g. conc/2)
+ # desc: one-line problem description
+ # details: detailed problem description
+ # solution: official solution
+ # library: fact database for testing (e.g. for parent, brother, … relations)
+ print(' ' + str(problem[0]) + '\t' + str(problem[1], encoding='utf-8'))
+ print()
+
+ pid = input('enter problem id: ')
+ c.execute('select id from attempts where problem=?', (pid,))
+ attempts = list(c.fetchall())
+
+ # print all attempt ids for the selected problem
+ print('attempts for problem ' + str(pid) + ':')
+ print(', '.join([str(attempt[0]) for attempt in attempts]))
+ print()
+
+ aid = input('enter attempt id: ')
+ c.execute('select * from attempts where id=?', (aid,))
+ attempt = c.fetchone()
+ # attempt = (id, problem_id, user_id, log, content, done, session)
+ # log: action sequence log
+ # content: final version for this attempt
+ # done: did any version of the program pass all tests?
+ # session: irrelevant
+ try:
+ actions = parse(attempt[3])
+ print('read ' + str(len(actions)) + ' actions from log')
+ compress(actions)
+ print('after compression: ' + str(len(actions)) + ' actions')
+ print()
+
+ print('code versions for this attempt:')
+ code = ''
+ for action in actions:
+ code = action.apply(code)
+ print(action)
+ print(code.strip())
+ print()
+ except Exception as ex:
+ sys.stderr.write('Error parsing action log: ' + str(ex))
diff --git a/monkey/db.py b/monkey/db.py
new file mode 100644
index 0000000..0634098
--- /dev/null
+++ b/monkey/db.py
@@ -0,0 +1,44 @@
+#!/usr/bin/python3
+
+import sqlite3
+
+db = sqlite3.connect('misc/solutions.db')
+db.row_factory = sqlite3.Row
+db.text_factory = bytes
+cursor = db.cursor()
+
+def b_to_utf8(bytestring):
+ return str(bytestring, encoding='utf-8')
+
+def get_problem_ids():
+ cursor.execute('SELECT id FROM problems ORDER BY id ASC')
+ return [row['id'] for row in cursor]
+
+def get_problem(pid):
+ cursor.execute('SELECT name, solution, library FROM problems WHERE id=?', (pid,))
+ row = cursor.fetchone()
+ name = b_to_utf8(row['name'])
+ solution = b_to_utf8(row['solution']).replace('\r', '')
+ lib_id = row['library'] if row['library'] else None
+ return name, solution, lib_id
+
+def get_depends(pid):
+ cursor.execute('SELECT dependency FROM depends WHERE problem=?', (pid,))
+ return [r['dependency'] for r in cursor.fetchall()]
+
+def get_library(lid):
+ cursor.execute('SELECT facts FROM libraries WHERE id=?', (lid,))
+ row = cursor.fetchone()
+ return b_to_utf8(row['facts']).replace('\r', '') if row else None
+
+def get_tests(pid):
+ cursor.execute('SELECT query FROM tests WHERE problem=?', (pid,))
+ return [b_to_utf8(row['query']) for row in cursor]
+
+def get_traces(pid):
+ cursor.execute('SELECT * FROM attempts WHERE problem=? AND done=1 ORDER BY id ASC', (pid,))
+ return {(pid, attempt['user']): attempt['log'] for attempt in cursor}
+
+def get_solved(uid):
+ cursor.execute('SELECT problem FROM attempts WHERE user=? AND done=1 ORDER BY problem ASC', (uid,))
+ return [row['problem'] for row in cursor.fetchall()]
diff --git a/monkey/edits.py b/monkey/edits.py
new file mode 100644
index 0000000..fef591a
--- /dev/null
+++ b/monkey/edits.py
@@ -0,0 +1,307 @@
+#!/usr/bin/python3
+
+import collections
+
+from action import expand, parse
+from graph import Node
+from prolog.util import rename_vars, stringify, tokenize
+from util import get_line
+
+# A line edit is a contiguous sequences of actions within a single line. This
+# function takes a sequence of actions and builds a directed acyclic graph
+# where each edit represents one line edit and each node represents a version
+# of some line. The function returns a list of nodes (first element is the
+# root), and sets of submissions (program versions tested by the user) and
+# queries in this attempt.
+def edit_graph(actions, debug=False):
+ # Return values.
+ nodes = [Node([0, 0, ()])] # Node data: rank (Y), line no. (X), and tokens.
+ submissions = set() # Program versions at 'test' actions.
+ queries = set() # Queries run by the student.
+
+ # State variables.
+ leaves = {0: nodes[0]} # Current leaf node for each line.
+ rank = 1 # Rank (order / y-position) for the next node.
+ code_next = '' # Program code after applying the current action.
+
+ # Ensure there is a separate action for each inserted/removed character.
+ expand(actions)
+ for action_id, action in enumerate(actions):
+ code = code_next
+ code_next = action.apply(code)
+
+ if action.type == 'test':
+ submissions.add(code)
+
+ elif action.type == 'solve' or action.type == 'solve_all':
+ queries.add(action.query)
+
+ elif action.type == 'insert' or action.type == 'remove':
+ # Number of the changed line.
+ line = code[:action.offset].count('\n')
+ # Current leaf node for this line.
+ parent = leaves[line]
+ # Tokens in this line after applying [action].
+ tokens_next = tuple(tokenize(get_line(code_next, line)))
+
+ # If a new node is inserted, clone each leaf into the next rank.
+ # This makes it easier to print the graph for graphviz; when
+ # analyzing the graph, duplicate nodes without siblings should be
+ # ignored.
+ new_leaves = {}
+
+ if action.text == '\n':
+ if action.type == 'insert':
+ tokens_next_right = tuple(tokenize(get_line(code_next, line+1)))
+
+ child_left = Node([rank, line, tokens_next])
+ parent.add_out(child_left)
+
+ child_right = Node([rank, line+1, tokens_next_right])
+ parent.add_out(child_right)
+
+ # Create new leaf nodes.
+ for i, leaf in leaves.items():
+ if i < line:
+ new_leaves[i] = Node([rank, i, leaf.data[2]])
+ leaf.add_out(new_leaves[i])
+ elif i > line:
+ new_leaves[i+1] = Node([rank, i+1, leaf.data[2]])
+ leaf.add_out(new_leaves[i+1])
+ new_leaves[line] = child_left
+ new_leaves[line+1] = child_right
+
+ elif action.type == 'remove':
+ parent_right = leaves[line+1]
+
+ child = Node([rank, line, tokens_next])
+ parent_right.add_out(child)
+ parent.add_out(child)
+
+ # Create new leaf nodes.
+ for i, leaf in leaves.items():
+ if i < line:
+ new_leaves[i] = Node([rank, i, leaf.data[2]])
+ leaf.add_out(new_leaves[i])
+ elif i > line+1:
+ new_leaves[i-1] = Node([rank, i-1, leaf.data[2]])
+ leaf.add_out(new_leaves[i-1])
+ new_leaves[line] = child
+ else:
+ # Skip the node if the next action is insert/remove (except \n)
+ # on the same line.
+ if action_id < len(actions)-1:
+ action_next = actions[action_id+1]
+ if action_next.type in ('insert', 'remove'):
+ line_next = code_next[:action_next.offset].count('\n')
+ if action_next.text != '\n' and line == line_next:
+ continue
+
+ # Skip the node if it is the same as the parent.
+ if tokens_next == parent.data[2]:
+ continue
+
+ child = Node([rank, line, tokens_next])
+ parent.add_out(child)
+
+ # Create new leaf nodes.
+ for i, leaf in leaves.items():
+ if i != line:
+ new_leaves[i] = Node([rank, i, leaf.data[2]])
+ leaf.add_out(new_leaves[i])
+ new_leaves[line] = child
+
+ leaves = new_leaves
+ nodes += leaves.values()
+ rank += 1
+
+ return nodes, submissions, queries
+
+# Return all interesting edit paths in the edit graph rooted at [root].
+def get_paths(root, path=tuple(), done=None):
+ if done is None:
+ done = set()
+
+ cur_path = list(path)
+ if len(path) == 0 or path[-1] != root.data[2]:
+ cur_path.append(root.data[2])
+
+ # leaf node
+ if len(root.eout) == 0:
+ yield tuple(cur_path)
+ # empty node
+ elif len(path) > 1 and len(root.data[2]) == 0:
+ yield tuple(cur_path)
+
+ if len(root.data[2]) > 0:
+ new_path = cur_path
+ else:
+ new_path = [root.data[2]]
+ done.add(root)
+
+ for node in root.eout:
+ if node not in done:
+ yield from get_paths(node, tuple(new_path), done)
+
+# Build an edit graph for each trace and find "meaningful" (to be defined)
+# edits. Return a dictionary of edits and their frequencies, and also
+# submissions and queries in [traces].
+def get_edits_from_traces(traces):
+ # Helper function to remove trailing punctuation from lines. This is a
+ # rather ugly performance-boosting hack.
+ def remove_punct(line):
+ if line and line[-1].type in ('COMMA', 'PERIOD', 'SEMI', 'FROM'):
+ return line[:-1]
+ return line
+
+ # Return values: counts for observed edits, lines, submissions and queries.
+ edits = collections.Counter()
+ lines = collections.Counter()
+ submissions = collections.Counter()
+ queries = collections.Counter()
+
+ for trace in traces:
+ try:
+ actions = parse(trace)
+ except:
+ continue
+ nodes, trace_submissions, trace_queries = edit_graph(actions)
+
+ # Update the submissions/queries counters; rename variables first to
+ # remove trivial differences.
+ for submission in trace_submissions:
+ tokens = tokenize(submission)
+ rename_vars(tokens)
+ code = stringify(tokens)
+ submissions[code] += 1
+
+ for query in trace_queries:
+ tokens = tokenize(query)
+ rename_vars(tokens)
+ code = stringify(tokens)
+ queries[code] += 1
+
+ # Get edits.
+ for path in get_paths(nodes[0]):
+ for i in range(len(path)):
+ start = list(remove_punct(path[i]))
+ var_names = rename_vars(start)
+ start_t = tuple(start)
+
+ for j in range(len(path[i+1:])):
+ end = list(remove_punct(path[i+1+j]))
+ rename_vars(end, var_names)
+ end_t = tuple(end)
+
+ if start_t != end_t:
+ edit = (start_t, end_t)
+ edits[edit] += 1
+ lines[start_t] += 1
+
+ # Discard rarely occurring edits. XXX only for testing
+ singletons = [edit for edit in edits if edits[edit] < 2]
+ for edit in singletons:
+ lines[edit[0]] -= edits[edit]
+ del edits[edit]
+
+ # Get the probability of each edit given its 'before' line.
+ for before, after in edits:
+ edits[(before, after)] /= lines[before]
+
+ # Normalize line frequencies.
+ if len(lines) > 0:
+ lines_max = max(lines.values())
+ lines = {line: count/lines_max for line, count in lines.items()}
+
+ return edits, lines, submissions, queries
+
+def classify_edits(edits):
+ inserts = {}
+ removes = {}
+ changes = {}
+ for (before, after), cost in edits.items():
+ if after and not before:
+ inserts[after] = cost
+ elif before and not after:
+ removes[before] = cost
+ else:
+ changes[(before, after)] = cost
+ return inserts, removes, changes
+
+# Simplify an edit graph with given nodes: remove empty leaf nodes and other
+# fluff. The function is called recursively until no more changes are done.
+def clean_graph(nodes):
+ changed = False
+
+ # A
+ # | --> A (when B is an empty leaf)
+ # B
+ for node in nodes:
+ if len(node.eout) == 0 and len(node.ein) == 1 and len(node.data[2]) == 0:
+ parent = node.ein[0]
+ parent.eout.remove(node)
+ nodes.remove(node)
+ changed = True
+ break
+
+ # A
+ # | --> A
+ # A
+ for node in nodes:
+ if len(node.ein) == 1:
+ parent = node.ein[0]
+ if len(parent.eout) == 1 and node.data[2] == parent.data[2]:
+ parent.eout = node.eout
+ for child in parent.eout:
+ child.ein = [parent if node == node else node for node in child.ein]
+ nodes.remove(node)
+ changed = True
+ break
+
+ # A A
+ # |\ |
+ # | C --> | (when C is empty)
+ # |/ |
+ # B B
+ for node in nodes:
+ if len(node.data[2]) == 0 and len(node.ein) == 1 and len(node.eout) == 1:
+ parent = node.ein[0]
+ child = node.eout[0]
+ if len(parent.eout) == 2 and len(child.ein) == 2:
+ parent.eout = [n for n in parent.eout if n != node]
+ child.ein = [n for n in child.ein if n != node]
+ nodes.remove(node)
+ changed = True
+ break
+
+ # A
+ # |
+ # C --> A
+ # |
+ # A
+ for node in nodes:
+ if len(node.data[2]) == 0 and len(node.ein) == 1 and len(node.eout) == 1:
+ parent = node.ein[0]
+ child = node.eout[0]
+ if len(parent.eout) == 1 and len(child.ein) == 1 and parent.data[2] == child.data[2]:
+ parent.eout = [child]
+ child.ein = [parent]
+ nodes.remove(node)
+ changed = True
+ break
+
+ if changed:
+ # go again until nothing changes
+ clean_graph(nodes)
+ else:
+ # compact node ranks
+ ranks = set([node.data[0] for node in nodes])
+ missing = set(range(1,max(ranks)+1)) - ranks
+
+ for node in nodes:
+ diff = 0
+ for rank in sorted(missing):
+ if rank >= node.data[0]:
+ break
+ diff += 1
+ node.data[0] -= diff
diff --git a/monkey/graph.py b/monkey/graph.py
new file mode 100644
index 0000000..5bf78ec
--- /dev/null
+++ b/monkey/graph.py
@@ -0,0 +1,67 @@
+#!/usr/bin/python3
+
+class Node(object):
+ def __init__(self, data):
+ self.data = data
+ self.ein = []
+ self.eout = []
+
+ # (Re-)insert a child node [target] to [self] at index [idx] (or as the
+ # rightmost child if index is not given). Also append [self] to the list of
+ # parents of [target].
+ def add_out(self, target, idx=None):
+ if target in self.eout:
+ self.eout.remove(target)
+ if idx is None:
+ self.eout.append(target)
+ else:
+ self.eout.insert(idx, target)
+ if self not in target.ein:
+ target.ein.append(self)
+ return target
+
+ def __repr__(self):
+ return str(self.data)
+
+ def __lt__(self, other):
+ return self.data < other.data
+
+# Print the edit graph containing [nodes] in graphviz dot format. The [label]
+# and [pos] functions determine node labels and coordinates (x,y), and the
+# [node_attr] and [edge_attr] functions specify additional attributes for each
+# node and edge. To actually use the coordinates returned by [pos], generate
+# the image using neato -n1.
+def graphviz(nodes, label=str, pos=None, node_attr=None, edge_attr=None):
+ # Generate node descriptions.
+ node_str = ''
+ node_id = {}
+ for node in nodes:
+ node_id[node] = len(node_id)
+ node_str += '\t{} [label="{}"'.format(node_id[node], label(node).replace('"', '\\"'))
+ if pos:
+ node_str += ', ' + 'pos="{},{}"'.format(*pos(node))
+ if node_attr:
+ node_str += ', ' + node_attr(node)
+ node_str += '];\n'
+
+ # Generate edge descriptions (breadth-first).
+ edge_str = ''
+ for node in nodes:
+ a = node_id[node]
+ for child in node.eout:
+ b = node_id[child]
+ edge_str += '\t{} -> {}'.format(a, b)
+ if edge_attr:
+ edge_str += ' [' + edge_attr(node, child) + ']'
+ edge_str += ';\n'
+
+ output = 'digraph G {\n'
+ output += '\tordering="out";\n'
+ output += '\tnode [shape="box", margin="0.05,0", fontname="sans", fontsize=13.0];\n'
+ output += '\n'
+ output += node_str
+ output += '\n'
+ output += edge_str
+ output += '}\n'
+
+ return output
diff --git a/monkey/monkey.py b/monkey/monkey.py
new file mode 100755
index 0000000..42c81f4
--- /dev/null
+++ b/monkey/monkey.py
@@ -0,0 +1,313 @@
+#!/usr/bin/python3
+
+import collections
+import math
+import pickle
+import sys
+import time
+
+from termcolor import colored
+
+import db
+from action import parse
+from edits import classify_edits, clean_graph, edit_graph, get_edits_from_traces
+from graph import Node, graphviz
+from prolog.engine import PrologEngine
+from prolog.util import compose, decompose, map_vars, rename_vars, stringify
+from util import PQueue, Token, indent
+
+# score a program (a list of lines) according to lines distribution
+def score(program, lines):
+ result = 1
+ for line in program:
+ line_normal = list(line)
+ rename_vars(line_normal)
+ line_normal = tuple(line_normal)
+ result *= lines.get(line_normal, 0.01)
+
+ if len(program) == 0 or result == 0:
+ return 0.01
+ return math.pow(result, 1/len(program))
+
+# find a sequence of edits that fixes [code]
+def fix(name, code, edits, timeout=30, debug=False):
+ todo = PQueue() # priority queue of candidate solutions
+ done = set() # set of already-analyzed solutions
+
+ # Add a new candidate solution ([lines]+[rules]) to the priority queue.
+ # This solution is generated by applying [step] with [cost] to [prev] task.
+ def add_task(lines, rules, prev=None, step=None, cost=None):
+ if prev is None:
+ path = ()
+ path_cost = 1.0
+ else:
+ path = tuple(list(prev[1]) + [step])
+ path_cost = prev[2] * cost
+ todo.push(((tuple(lines), tuple(rules)), path, path_cost), -path_cost)
+
+ lines, rules = decompose(code)
+ add_task(lines, rules)
+
+ inserts, removes, changes = classify_edits(edits)
+ start_time = time.monotonic()
+ n_tested = 0
+ while True:
+ total_time = time.monotonic() - start_time
+ if total_time > timeout:
+ break
+
+ task = todo.pop()
+ if task == None:
+ break
+
+ (lines, rules), path, path_cost = task
+ code = compose(lines, rules)
+ if code in done:
+ continue
+ done.add(code)
+
+ if debug:
+ print('Cost {:.12f}'.format(path_cost))
+ for line, (before, after) in path:
+ print('line ' + str(line) + ':\t' + stringify(before) + ' → ' + stringify(after))
+
+ # if the code is correct, we are done
+ try:
+ if test(name, code):
+ return code, path, total_time, n_tested
+ except:
+ pass
+ n_tested += 1
+
+ # otherwise generate new solutions
+ rule_no = 0
+ for start, end in rules:
+ rule = lines[start:end]
+ rule_tokens = [t for line in rule for t in line]
+
+ for line_idx in range(start, end):
+ line = lines[line_idx]
+
+ line_normal = list(line)
+ rename_vars(line_normal)
+ line_normal = tuple(line_normal)
+
+ seen = False
+ for (before, after), cost in changes.items():
+ if line_normal == before:
+ seen = True
+ mapping = map_vars(before, after, line, rule_tokens)
+ after_real = tuple([t if t.type != 'VARIABLE' else Token('VARIABLE', mapping[t.val]) for t in after])
+ new_lines = lines[:line_idx] + (after_real,) + lines[line_idx+1:]
+ new_step = ((rule_no, line_idx-start), (tuple(line), after_real))
+
+ add_task(new_lines, rules, prev=task, step=new_step, cost=cost)
+
+ # if nothing could be done with this line, try removing it
+ # (maybe try removing in any case?)
+ if line_normal in removes.keys() or not seen:
+ new_lines = lines[:line_idx] + lines[line_idx+1:]
+ new_rules = []
+ for old_start, old_end in rules:
+ new_start, new_end = (old_start - (0 if old_start <= line_idx else 1),
+ old_end - (0 if old_end <= line_idx else 1))
+ if new_end > new_start:
+ new_rules.append((new_start, new_end))
+ new_step = ((rule_no, line_idx-start), (tuple(line), ()))
+ new_cost = removes[line_normal] if line_normal in removes.keys() else 0.9
+
+ add_task(new_lines, new_rules, prev=task, step=new_step, cost=new_cost)
+
+ # try adding a line to this rule… would need to distinguish between
+ # head/body lines in transforms
+ for after, cost in inserts.items():
+ mapping = map_vars([], after, [], rule_tokens)
+ after_real = [t if t.type != 'VARIABLE' else Token('VARIABLE', mapping[t.val]) for t in after]
+ after_real = tuple(after_real)
+ new_lines = lines[:end] + (after_real,) + lines[end:]
+ new_rules = []
+ for old_start, old_end in rules:
+ new_rules.append((old_start + (0 if old_start < end else 1),
+ old_end + (0 if old_end < end else 1)))
+ new_step = ((rule_no, end-start), ((), after_real))
+
+ add_task(new_lines, new_rules, prev=task, step=new_step, cost=cost)
+ rule_no += 1
+
+ # try adding a new fact
+ if len(rules) < 2:
+ for after, cost in inserts.items():
+ new_lines = lines + (after,)
+ new_rules = rules + (((len(lines), len(lines)+1)),)
+ new_step = ((len(new_rules)-1, 0), (tuple(), tuple(after)))
+
+ add_task(new_lines, new_rules, prev=task, step=new_step, cost=cost)
+
+ return '', [], total_time, n_tested
+
+def print_hint(solution, steps, fix_time, n_tested):
+ if solution:
+ print(colored('Hint found! Tested {} programs in {:.1f} s.'.format(n_tested, fix_time), 'green'))
+ print(colored(' Edits', 'blue'))
+ for line, (before, after) in steps:
+ print(' {}:\t{} → {}'.format(line, stringify(before), stringify(after)))
+ print(colored(' Final version', 'blue'))
+ print(indent(compose(*decompose(solution)), 2))
+ else:
+ print(colored('Hint not found! Tested {} programs in {:.1f} s.'.format(n_tested, fix_time), 'red'))
+
+# Find official solutions to all problems.
+def init_problems():
+ names = {}
+ codes = {}
+ libraries = {}
+
+ pids = db.get_problem_ids()
+ for pid in pids:
+ names[pid], codes[pid], libraries[pid] = db.get_problem(pid)
+
+ return names, codes, libraries
+
+# Submit code to Prolog server for testing.
+def test(name, code):
+ # TODO also load fact library and solved predicates
+ engine = PrologEngine(code=code)
+ result = engine.ask("run_tests({}, '{}')".format(name, engine.id))
+ engine.destroy()
+ return result['event'] == 'success'
+
+if __name__ == '__main__':
+ # Get problem id from commandline.
+ if len(sys.argv) < 2:
+ print('usage: ' + sys.argv[0] + ' <pid>')
+ sys.exit(1)
+ pid = int(sys.argv[1])
+
+ names, codes, libraries = init_problems()
+
+ # Analyze traces for this problem to get edits, submissions and queries.
+ traces = db.get_traces(pid)
+ edits, lines, submissions, queries = get_edits_from_traces(traces.values())
+
+ # Find incorrect submissions.
+ incorrect = []
+ for submission, count in sorted(submissions.items()):
+ if not test(names[pid], submission):
+ # This incorrect submission appeared in [count] attempts.
+ incorrect += [submission]*count
+
+ # XXX only for testing
+ try:
+ done = pickle.load(open('status-'+str(pid)+'.pickle', 'rb'))
+ except:
+ done = []
+
+ # test fix() on incorrect student submissions
+ if len(sys.argv) >= 3 and sys.argv[2] == 'test':
+ timeout = int(sys.argv[3]) if len(sys.argv) >= 4 else 10
+
+ print('Fixing {}/{} programs (timeout={})…'.format(
+ len([p for p in incorrect if p not in done]), len(incorrect), timeout))
+
+ for i, program in enumerate(incorrect):
+ if program in done:
+ continue
+ print(colored('Analyzing program {0}/{1}…'.format(i+1, len(incorrect)), 'yellow'))
+ print(indent(compose(*decompose(program)), 2))
+
+ solution, steps, fix_time, n_tested = fix(names[pid], program, edits, timeout=timeout)
+ if solution:
+ done.append(program)
+ print_hint(solution, steps, fix_time, n_tested)
+ print()
+
+ pickle.dump(done, open('status-'+str(pid)+'.pickle', 'wb'))
+
+ print('Found hints for ' + str(len(done)) + ' of ' + str(len(incorrect)) + ' incorrect programs')
+
+ # print info for this problem
+ elif len(sys.argv) >= 3 and sys.argv[2] == 'info':
+ # with no additional arguments, print some stats
+ if len(sys.argv) == 3:
+ print('Problem {} ({}): {} edits in {} traces, fixed {}/{} ({}/{} unique)'.format(
+ pid, colored(names[pid], 'yellow'),
+ colored(str(len(edits)), 'yellow'), colored(str(len(traces)), 'yellow'),
+ colored(str(len([p for p in incorrect if p in done])), 'yellow'),
+ colored(str(len(incorrect)), 'yellow'),
+ colored(str(len(set(done))), 'yellow'),
+ colored(str(len(set(incorrect))), 'yellow')))
+ else:
+ if sys.argv[3] == 'users':
+ print(' '.join([str(uid) for (pid, uid) in sorted(traces.keys())]))
+ # print all observed edits and their costs
+ elif sys.argv[3] == 'edits':
+ inserts, removes, changes = classify_edits(edits)
+ print('Inserts')
+ for after, cost in sorted(inserts.items(), key=lambda x: x[1]):
+ print(' {:.2f}\t{}'.format(cost, stringify(after)))
+ print('Removes')
+ for before, cost in sorted(removes.items(), key=lambda x: x[1]):
+ print(' {:.2f}\t{}'.format(cost, stringify(before)))
+ print('Changes')
+ for (before, after), cost in sorted(changes.items(), key=lambda x: x[1]):
+ print(' {:.2f}\t{} → {}'.format(cost,
+ stringify(before if before else [('INVALID', 'ε')]),
+ stringify(after if after else [('INVALID', 'ε')])))
+ # print all student submissions not (yet) corrected
+ elif sys.argv[3] == 'unsolved':
+ for p in sorted(set(incorrect)):
+ if p in done:
+ continue
+ print(indent(compose(*decompose(p)), 2))
+ print()
+ # print all student queries and their counts
+ elif sys.argv[3] == 'queries':
+ for query, count in queries.most_common():
+ print(' ' + str(count) + '\t' + query)
+
+ # Print the edit graph in graphviz dot syntax.
+ elif len(sys.argv) == 4 and sys.argv[2] == 'graph':
+ uid = int(sys.argv[3])
+ actions = parse(traces[(pid, uid)])
+
+ nodes, submissions, queries = edit_graph(actions)
+
+ def position(node):
+ return (node.data[1]*150, node.data[0]*-60)
+
+ def label(node):
+ return stringify(node.data[2])
+
+ def node_attr(node):
+ if node.ein and node.data[2] == node.ein[0].data[2]:
+ return 'color="gray", shape="point"'
+ return ''
+
+ def edge_attr(a, b):
+ if a.data[2] == b.data[2]:
+ return 'arrowhead="none"'
+ return ''
+
+ graphviz_str = graphviz(nodes, pos=position, label=label,
+ node_attr=node_attr, edge_attr=edge_attr)
+ print(graphviz_str)
+
+ # run interactive loop
+ else:
+ while True:
+ # read the program from stdin
+ print('Enter program, end with empty line:')
+ code = ''
+ try:
+ while True:
+ line = input()
+ if not line:
+ break
+ code += line + '\n'
+ except EOFError:
+ break
+
+ # try finding a fix
+ print(colored('Analyzing program…', 'yellow'))
+ solution, steps, fix_time, n_tested = fix(names[pid], code, edits, debug=True)
+ print_hint(solution, steps, fix_time, n_tested)
diff --git a/monkey/prolog/engine.py b/monkey/prolog/engine.py
new file mode 100644
index 0000000..dff577c
--- /dev/null
+++ b/monkey/prolog/engine.py
@@ -0,0 +1,135 @@
+#!/usr/bin/python3
+
+import http.client
+import json
+import re
+import urllib
+
+address, port = 'localhost', 3030
+
+class PrologEngine(object):
+ def __init__(self, address=address, port=port, code='', destroy=False, id=None):
+ self.conn = http.client.HTTPConnection(address, port, timeout=10)
+
+ # If existing engine ID is given, use it.
+ if id:
+ self.id = id
+ return
+
+ # Otherwise, create a new engine.
+ hdrs = {'Content-Type': 'application/json;charset=utf-8'}
+ opts = json.dumps({'destroy': destroy, 'src_text': code, 'format': 'json-s'})
+ reply, outputs = self.request('POST', '/pengine/create', body=opts, headers=hdrs)
+
+ failed = (reply['event'] != 'create')
+ warnings = []
+ errors = []
+ for output in outputs:
+ print(output)
+ message = PrologEngine.parse_prolog_output(output)
+ if output['message'] == 'warning':
+ warnings.append(message)
+ elif output['message'] == 'error':
+ failed = True
+ errors.append(message)
+
+ if failed:
+ raise Exception('\n'.join(errors))
+
+ self.id = reply['id']
+
+ def send(self, event):
+ params = urllib.parse.urlencode({
+ 'id': self.id,
+ 'event': event,
+ 'format': 'json-s'})
+ reply, outputs = self.request('GET', '/pengine/send?' + params)
+ return reply
+
+ def ask(self, query):
+ event = 'ask(({}),[])'.format(query)
+ reply = self.send(event)
+ return reply
+
+ def next(self, n=1):
+ event = 'next({})'.format(n)
+ reply = self.send(event)
+ return reply
+
+ def stop(self):
+ return self.send('stop')
+
+ def destroy(self):
+ reply = self.send('destroy')
+ self.id = None
+ self.conn.close()
+ self.conn = None
+
+ # Return the main reply and possible output replies.
+ def request(self, method, path, body=None, headers={}):
+ self.conn.request(method, path, body, headers)
+ outputs = []
+ while True:
+ response = self.conn.getresponse()
+ if response.status != http.client.OK:
+ raise Exception('server returned {}'.format(response.status))
+ reply = json.loads(response.read().decode('utf-8'))
+ self.id = reply['id']
+ if reply['event'] == 'output':
+ outputs.append(reply)
+ params = urllib.parse.urlencode({
+ 'id': self.id,
+ 'format': 'json-s'})
+ self.conn.request('GET', '/pengine/pull_response?' + params)
+ else:
+ return reply, outputs
+
+ # Check if output is an error message and return a prettified version of it.
+ def parse_prolog_output(output):
+ match = re.match(r'.*<pre class="[^"]*">(.*)</pre>.*',
+ output['data'], flags=re.DOTALL)
+ data = match.group(1).strip()
+ message = ''
+ if output['message'] == 'error':
+ if 'location' in output:
+ loc = output['location']
+ message += 'near line ' + str(loc['line'])
+ if 'ch' in loc:
+ message += ', character ' + str(loc['ch'])
+ message += ': '
+
+ if output.get('code') == 'syntax_error':
+ match = re.match(r'^.*Syntax error: (.*)$', data, flags=re.DOTALL)
+ message += match.group(1)
+ elif output.get('code') == 'permission_error':
+ match = re.match(r'^.*(No permission [^\n]*)', data, flags=re.DOTALL)
+ message += match.group(1)
+ elif output.get('code') == 'type_error':
+ match = re.match(r'^.*(Type error: [^\n]*)', data, flags=re.DOTALL)
+ message += match.group(1)
+ else:
+ message += data
+
+ # Replace anonymous variable names with _.
+ message = re.sub(r'_G[0-9]*', '_', message)
+ return message
+
+def test(name, code):
+ engine = PrologEngine(code=code)
+ reply = engine.ask("run_tests({}, '{}', Results)".format(name, engine.id))
+ engine.destroy()
+
+ if reply['event'] != 'success':
+ raise Exception('testing procedure failed')
+
+ results = re.findall(r'(?:success|failure)\([^)]*\)', reply['data'][0]['Results'])
+ n_total = len(results)
+ n_passed = len([r for r in results if r.startswith('success')])
+ return (n_passed, n_total)
+
+# Basic sanity check.
+if __name__ == '__main__':
+ engine = PrologEngine(code='dup([],[]). dup([H|T],[H,H|TT]) :- dup(T,TT).')
+ print('engine id is ' + engine.id)
+ print(engine.ask("run_tests({},'{}',Result)".format('dup/2', engine.id)))
+ engine.destroy()
diff --git a/monkey/prolog/lexer.py b/monkey/prolog/lexer.py
new file mode 100644
index 0000000..971e8a6
--- /dev/null
+++ b/monkey/prolog/lexer.py
@@ -0,0 +1,90 @@
+#!/usr/bin/python3
+
+import ply.lex as lex
+
+# LEXER
+
+#states = (
+# ('comment', 'exclusive'),
+#)
+
+# tokens; treat operators as names if followed by (
+operators = {
+ r':-': 'FROM',
+ r'->': 'IMPLIES',
+ r'\+': 'NOT',
+ r'not': 'NOT',
+ r'=': 'EQU',
+ r'\=': 'NEQU',
+ r'==': 'EQ',
+ r'\==': 'NEQ',
+ r'=..': 'UNIV',
+ r'is': 'IS',
+ r'=:=': 'EQA',
+ r'=\=': 'NEQA',
+ r'<': 'LT',
+ r'=<': 'LE',
+ r'>': 'GT',
+ r'>=': 'GE',
+ r'@<': 'LTL',
+ r'@=<': 'LEL',
+ r'@>': 'GTL',
+ r'@>=': 'GEL',
+ r'+': 'PLUS',
+ r'-': 'MINUS',
+ r'*': 'STAR',
+ r'/': 'DIV',
+ r'//': 'IDIV',
+ r'mod': 'MOD',
+ r'**': 'POW',
+ r'.': 'PERIOD',
+ r',': 'COMMA',
+ r';': 'SEMI'
+}
+tokens = list(operators.values()) + [
+ 'UINTEGER', 'UREAL',
+ 'NAME', 'VARIABLE', 'STRING',
+ 'LBRACKET', 'RBRACKET', 'LPAREN', 'RPAREN', 'PIPE', 'LBRACE', 'RBRACE',
+ 'INVALID'
+]
+
+# punctuation
+t_LBRACKET = r'\['
+t_RBRACKET = r'\]'
+t_LPAREN = r'\('
+t_RPAREN = r'\)'
+t_PIPE = r'\|'
+t_LBRACE = r'{'
+t_RBRACE = r'}'
+
+t_UINTEGER = r'[0-9]+'
+t_UREAL = r'[0-9]+\.[0-9]+([eE][-+]?[0-9]+)?|inf|nan'
+t_VARIABLE = r'(_|[A-Z])[a-zA-Z0-9_]*'
+t_STRING = r'"(""|\\.|[^\"])*"'
+
+# no support for nested comments yet
+def t_comment(t):
+ r'(/\*(.|\n)*?\*/)|(%.*)'
+ pass
+
+def t_NAME(t):
+ r"'(''|\\.|[^\\'])*'|[a-z][a-zA-Z0-9_]*|[-+*/\\^<>=~:.?@#$&]+|!|;|,"
+ if t.lexer.lexpos >= len(t.lexer.lexdata) or t.lexer.lexdata[t.lexer.lexpos] != '(':
+ t.type = operators.get(t.value, 'NAME')
+ return t
+
+t_ignore = ' \t'
+
+def t_newline(t):
+ r'\n+'
+ t.lexer.lineno += len(t.value)
+
+def t_error(t):
+ # TODO send this to stderr
+ #print("Illegal character '" + t.value[0] + "'")
+ t.type = 'INVALID'
+ t.value = t.value[0]
+ t.lexer.skip(1)
+ return t
+
+lexer = lex.lex()
diff --git a/monkey/prolog/util.py b/monkey/prolog/util.py
new file mode 100644
index 0000000..0ab3c8b
--- /dev/null
+++ b/monkey/prolog/util.py
@@ -0,0 +1,156 @@
+#!/usr/bin/python3
+
+import itertools
+import math
+import re
+
+from .lexer import lexer
+from util import Token
+
+def tokenize(text):
+ lexer.input(text)
+ return [Token(t.type, t.value, t.lexpos) for t in lexer]
+
+operators = set([
+ 'FROM', 'IMPLIES', 'NOT',
+ 'EQU', 'NEQU', 'EQ', 'NEQ', 'UNIV', 'IS', 'EQA', 'NEQA',
+ 'LT', 'LE', 'GT', 'GE', 'LTL', 'LEL', 'GTL', 'GEL',
+ 'PLUS', 'MINUS', 'STAR', 'DIV', 'IDIV', 'MOD',
+ 'POW', 'SEMI'
+])
+def stringify(tokens):
+ def token_str(t):
+ if t.type in ('PERIOD', 'COMMA'):
+ return str(t) + ' '
+ if t.type in operators:
+ return ' ' + str(t) + ' '
+ return str(t)
+ return ''.join(map(token_str, tokens))
+
+# Yield the sequence of rules in [code].
+def split(code):
+ tokens = tokenize(code)
+ start = 0
+ for idx, token in enumerate(tokens):
+ if token.type == 'PERIOD' and idx - start > 1:
+ yield stringify(tokens[start:idx])
+ start = idx + 1
+
+# return a list of lines in 'code', and a list of rule indexes
+def decompose(code):
+ lines = []
+ rules = []
+ tokens = tokenize(code)
+ tokens.append(Token('EOF'))
+
+ line = []
+ parens = []
+ rule_start = 0
+ for t in tokens:
+ if t.type == 'SEMI':
+ if line != []:
+ lines.append(tuple(line))
+ line = []
+ lines.append((t,))
+ continue
+ if not parens:
+ if t.type in ('PERIOD', 'FROM', 'COMMA', 'EOF'):
+ if line != []:
+ lines.append(tuple(line))
+ line = []
+ if t.type in ('PERIOD', 'EOF') and rule_start < len(lines):
+ rules.append((rule_start, len(lines)))
+ rule_start = len(lines)
+ continue
+ if t.type in ('LPAREN', 'LBRACKET', 'LBRACE'):
+ parens.append(t.type)
+ elif parens:
+ if t.type == 'RPAREN' and parens[-1] == 'LPAREN':
+ parens.pop()
+ elif t.type == 'RBRACKET' and parens[-1] == 'LBRACKET':
+ parens.pop()
+ elif t.type == 'RBRACE' and parens[-1] == 'LBRACE':
+ parens.pop()
+ line.append(t)
+ return tuple(lines), tuple(rules)
+
+# pretty-print a list of rules
+def compose(lines, rules):
+ code = ''
+ for start, end in rules:
+ for i in range(start, end):
+ line = lines[i]
+ if i > start:
+ code += ' '
+ code += stringify(line)
+ if i == end-1:
+ code += '.\n'
+ elif i == start:
+ code += ' :-\n'
+ else:
+ if line and line[-1].type != 'SEMI' and lines[i+1][-1].type != 'SEMI':
+ code += ','
+ code += '\n'
+ return code.strip()
+
+# standardize variable names in order of appearance
+def rename_vars(tokens, names={}):
+ # copy names so we don't fuck it up
+ names = {k: v for k, v in names.items()}
+ next_id = len(names)
+ for i in range(len(tokens)):
+ if tokens[i].type == 'PERIOD':
+ names.clear()
+ next_id = 0
+ elif tokens[i] == Token('VARIABLE', '_'):
+ tokens[i] = Token('VARIABLE', 'A' + str(next_id))
+ next_id += 1
+ elif tokens[i].type == 'VARIABLE':
+ cur_name = tokens[i].val
+ if cur_name not in names:
+ names[cur_name] = next_id
+ next_id += 1
+ tokens[i] = Token('VARIABLE', 'A' + str(names[cur_name]))
+ return names
+
+# transformation = before → after; applied on line which is part of rule
+# return mapping from formal vars in before+after to actual vars in rule
+# line and rule should of course not be normalized
+def map_vars(before, after, line, rule):
+ mapping = {}
+ new_index = 0
+ for i in range(len(before)):
+ if line[i].type == 'VARIABLE':
+ formal_name = before[i].val
+ if line[i].val != '_':
+ actual_name = line[i].val
+ else:
+ actual_name = 'New'+str(new_index)
+ new_index += 1
+ mapping[formal_name] = actual_name
+
+ remaining_formal = [t.val for t in after if t.type == 'VARIABLE' and t.val not in mapping.keys()]
+ remaining_actual = [t.val for t in rule if t.type == 'VARIABLE' and t.val != '_' and t.val not in mapping.values()]
+
+ while len(remaining_actual) < len(remaining_formal):
+ remaining_actual.append('New'+str(new_index))
+ new_index += 1
+
+ for i, formal_name in enumerate(remaining_formal):
+ mapping[formal_name] = remaining_actual[i]
+
+ return mapping
+
+# Basic sanity check.
+if __name__ == '__main__':
+ print(compose(*decompose('dup([H|T], [H1|T1]) :- dup(T1, T2). ')))
+
+ rule = tokenize('dup([H|T], [H1|T1]) :- dup(T1, T2). ')
+ line = tokenize('dup([H|T], [H1|T1]) :-')
+ before = tokenize("dup([A0|A1], [A2|A3])")
+ after = tokenize("dup([A0|A1], [A5, A4|A3])")
+ var_names = rename_vars(before)
+ rename_vars(after, var_names)
+
+ mapping = map_vars(before, after, line, rule)
+ print(mapping)
diff --git a/monkey/util.py b/monkey/util.py
new file mode 100644
index 0000000..b8be2bb
--- /dev/null
+++ b/monkey/util.py
@@ -0,0 +1,81 @@
+#!/usr/bin/python3
+
+from collections import namedtuple
+from heapq import heappush, heappop
+import itertools
+
+# A simple priority queue based on the heapq class.
+class PQueue(object):
+ REMOVED = '<removed-task>' # placeholder for a removed task
+
+ def __init__(self):
+ self.pq = [] # list of entries arranged in a heap
+ self.entry_finder = {} # mapping of tasks to entries
+ self.counter = itertools.count() # unique sequence count
+ self.size = 0
+
+ def push(self, task, priority=0):
+ 'Add a new task or update the priority of an existing task'
+ if task in self.entry_finder:
+ self.remove(task)
+ else:
+ self.size += 1
+ entry = [priority, next(self.counter), task]
+ self.entry_finder[task] = entry
+ heappush(self.pq, entry)
+
+ def remove(self, task):
+ 'Mark an existing task as REMOVED. Raise KeyError if not found.'
+ entry = self.entry_finder.pop(task)
+ entry[-1] = self.REMOVED
+ self.size -= 1
+
+ def pop(self):
+ 'Remove and return the lowest priority task. Raise KeyError if empty.'
+ while self.pq:
+ priority, count, task = heappop(self.pq)
+ if task is not self.REMOVED:
+ del self.entry_finder[task]
+ self.size -= 1
+ return task
+ return None
+
+ def __len__(self):
+ return self.size
+
+# Stores a token's type and value, and optionally the position of the first
+# character in the lexed stream.
+class Token(namedtuple('Token', ['type', 'val', 'pos'])):
+ __slots__ = ()
+
+ # Custom constructor to support default parameters.
+ def __new__(cls, type, val='', pos=None):
+ return super(Token, cls).__new__(cls, type, val, pos)
+
+ def __str__(self):
+ return self.val
+
+ # Ignore position when comparing tokens. There is probably a cleaner way of
+ # doing these.
+ __eq__ = lambda x, y: x[0] == y[0] and x[1] == y[1]
+ __ne__ = lambda x, y: x[0] != y[0] or x[1] != y[1]
+ __lt__ = lambda x, y: tuple.__lt__(x[0:2], y[0:2])
+ __le__ = lambda x, y: tuple.__le__(x[0:2], y[0:2])
+ __ge__ = lambda x, y: tuple.__ge__(x[0:2], y[0:2])
+ __gt__ = lambda x, y: tuple.__gt__(x[0:2], y[0:2])
+
+ # Only hash token's value (we don't care about position, and types are
+ # determined by values).
+ def __hash__(self):
+ return hash(self[1])
+
+# Return [n]th line in [text].
+def get_line(text, n):
+ lines = text.split('\n')
+ if n >= 0 and n < len(lines):
+ return lines[n]
+ return None
+
+# Indent each line in [text] by [indent] spaces.
+def indent(text, indent=2):
+ return '\n'.join([' '*indent+line for line in text.split('\n')])