summaryrefslogtreecommitdiff
path: root/monkey/edits.py
diff options
context:
space:
mode:
Diffstat (limited to 'monkey/edits.py')
-rw-r--r--monkey/edits.py23
1 files changed, 6 insertions, 17 deletions
diff --git a/monkey/edits.py b/monkey/edits.py
index 8747a6e..15734c4 100644
--- a/monkey/edits.py
+++ b/monkey/edits.py
@@ -5,7 +5,7 @@ import math
from .action import expand, parse
from .graph import Node
-from prolog.util import rename_vars, stringify, tokenize
+from prolog.util import normalized, rename_vars, stringify, tokenize
from .util import get_line, avg, logistic
# A line edit is a contiguous sequence of actions within a single line. This
@@ -154,17 +154,6 @@ def get_paths(root, path=None, done=None):
# edits. Return a dictionary of edits and their frequencies, and also
# submissions and queries in [traces].
def get_edits_from_traces(traces):
- # Helper function to remove trailing punctuation from lines and rename
- # variables to A1,A2,A3,… (potentially using [var_names]). Return a tuple.
- def normalize(line, var_names=None):
- # Remove trailing punctuation.
- i = len(line)
- while i > 0:
- if line[i-1].type not in ('COMMA', 'PERIOD', 'SEMI'):
- break
- i -= 1
- return tuple(rename_vars(line[:i], var_names))
-
# Return values: counts for observed edits, lines, submissions and queries.
edits = collections.Counter()
submissions = collections.Counter()
@@ -198,8 +187,8 @@ def get_edits_from_traces(traces):
# Normalize path[i-1] → path[i] into start → end. Reuse
# variable names from start when normalizing end.
var_names = {}
- start = normalize(path[i-1], var_names)
- end = normalize(path[i], var_names)
+ start = normalized(path[i-1], var_names)
+ end = normalized(path[i], var_names)
# Disallow edits that insert a whole rule (a → … :- …).
# TODO improve edit_graph to handle this.
@@ -213,8 +202,8 @@ def get_edits_from_traces(traces):
edits.update(trace_edits)
# Update node counts.
- n_leaf.update(set([normalize(n.data[2]) for n in nodes if n.data[2] and not n.eout]))
- n_all.update(set([normalize(n.data[2]) for n in nodes if n.data[2]]))
+ n_leaf.update(set([normalized(n.data[2]) for n in nodes if n.data[2] and not n.eout]))
+ n_all.update(set([normalized(n.data[2]) for n in nodes if n.data[2]]))
# Discard edits that only occur in one trace.
singletons = [edit for edit in edits if edits[edit] < 2]
@@ -227,7 +216,7 @@ def get_edits_from_traces(traces):
if a:
p *= 1 - (n_leaf[a] / (n_all[a]+1))
if b:
- b_normal = normalize(b)
+ b_normal = normalized(b)
p *= n_leaf[b_normal] / (n_all[b_normal]+1)
if a and b:
p = math.sqrt(p)