summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--monkey/edits.py23
-rwxr-xr-xmonkey/monkey.py2
-rw-r--r--prolog/util.py11
3 files changed, 18 insertions, 18 deletions
diff --git a/monkey/edits.py b/monkey/edits.py
index 8747a6e..15734c4 100644
--- a/monkey/edits.py
+++ b/monkey/edits.py
@@ -5,7 +5,7 @@ import math
from .action import expand, parse
from .graph import Node
-from prolog.util import rename_vars, stringify, tokenize
+from prolog.util import normalized, rename_vars, stringify, tokenize
from .util import get_line, avg, logistic
# A line edit is a contiguous sequences of actions within a single line. This
@@ -154,17 +154,6 @@ def get_paths(root, path=None, done=None):
# edits. Return a dictionary of edits and their frequencies, and also
# submissions and queries in [traces].
def get_edits_from_traces(traces):
- # Helper function to remove trailing punctuation from lines and rename
- # variables to A1,A2,A3,… (potentially using [var_names]). Return a tuple.
- def normalize(line, var_names=None):
- # Remove trailing punctuation.
- i = len(line)
- while i > 0:
- if line[i-1].type not in ('COMMA', 'PERIOD', 'SEMI'):
- break
- i -= 1
- return tuple(rename_vars(line[:i], var_names))
-
# Return values: counts for observed edits, lines, submissions and queries.
edits = collections.Counter()
submissions = collections.Counter()
@@ -198,8 +187,8 @@ def get_edits_from_traces(traces):
# Normalize path[i-1] → path[i] into start → end. Reuse
# variable names from start when normalizing end.
var_names = {}
- start = normalize(path[i-1], var_names)
- end = normalize(path[i], var_names)
+ start = normalized(path[i-1], var_names)
+ end = normalized(path[i], var_names)
# Disallow edits that insert a whole rule (a → … :- …).
# TODO improve edit_graph to handle this.
@@ -213,8 +202,8 @@ def get_edits_from_traces(traces):
edits.update(trace_edits)
# Update node counts.
- n_leaf.update(set([normalize(n.data[2]) for n in nodes if n.data[2] and not n.eout]))
- n_all.update(set([normalize(n.data[2]) for n in nodes if n.data[2]]))
+ n_leaf.update(set([normalized(n.data[2]) for n in nodes if n.data[2] and not n.eout]))
+ n_all.update(set([normalized(n.data[2]) for n in nodes if n.data[2]]))
# Discard edits that only occur in one trace.
singletons = [edit for edit in edits if edits[edit] < 2]
@@ -227,7 +216,7 @@ def get_edits_from_traces(traces):
if a:
p *= 1 - (n_leaf[a] / (n_all[a]+1))
if b:
- b_normal = normalize(b)
+ b_normal = normalized(b)
p *= n_leaf[b_normal] / (n_all[b_normal]+1)
if a and b:
p = math.sqrt(p)
diff --git a/monkey/monkey.py b/monkey/monkey.py
index 0a934bc..080b317 100755
--- a/monkey/monkey.py
+++ b/monkey/monkey.py
@@ -5,7 +5,7 @@ import time
from .edits import classify_edits
from prolog.engine import test
-from prolog.util import Token, compose, decompose, map_vars, rename_vars, stringify
+from prolog.util import Token, compose, decompose, map_vars, normalized, rename_vars, stringify
from .util import PQueue
# Starting from [code], find a sequence of edits that transforms it into a
diff --git a/prolog/util.py b/prolog/util.py
index e5a93e2..48e3345 100644
--- a/prolog/util.py
+++ b/prolog/util.py
@@ -156,6 +156,17 @@ def rename_vars(tokens, names=None):
tokens[i] = Token('VARIABLE', names[cur_name])
return tokens
+# Helper function to remove trailing punctuation from lines and rename
+# variables to A1,A2,A3,… (potentially using [var_names]). Return a tuple.
+def normalized(line, var_names=None):
+ # Remove trailing punctuation.
+ i = len(line)
+ while i > 0:
+ if line[i-1].type not in ('COMMA', 'PERIOD', 'SEMI'):
+ break
+ i -= 1
+ return tuple(rename_vars(line[:i], var_names))
+
# transformation = before → after; applied on line which is part of rule
# return mapping from formal vars in before+after to actual vars in rule
# line and rule should of course not be normalized