diff options
Diffstat (limited to 'monkey/edits.py')
-rw-r--r-- | monkey/edits.py | 23 |
1 files changed, 6 insertions, 17 deletions
diff --git a/monkey/edits.py b/monkey/edits.py index 8747a6e..15734c4 100644 --- a/monkey/edits.py +++ b/monkey/edits.py @@ -5,7 +5,7 @@ import math from .action import expand, parse from .graph import Node -from prolog.util import rename_vars, stringify, tokenize +from prolog.util import normalized, rename_vars, stringify, tokenize from .util import get_line, avg, logistic # A line edit is a contiguous sequences of actions within a single line. This @@ -154,17 +154,6 @@ def get_paths(root, path=None, done=None): # edits. Return a dictionary of edits and their frequencies, and also # submissions and queries in [traces]. def get_edits_from_traces(traces): - # Helper function to remove trailing punctuation from lines and rename - # variables to A1,A2,A3,… (potentially using [var_names]). Return a tuple. - def normalize(line, var_names=None): - # Remove trailing punctuation. - i = len(line) - while i > 0: - if line[i-1].type not in ('COMMA', 'PERIOD', 'SEMI'): - break - i -= 1 - return tuple(rename_vars(line[:i], var_names)) - # Return values: counts for observed edits, lines, submissions and queries. edits = collections.Counter() submissions = collections.Counter() @@ -198,8 +187,8 @@ def get_edits_from_traces(traces): # Normalize path[i-1] → path[i] into start → end. Reuse # variable names from start when normalizing end. var_names = {} - start = normalize(path[i-1], var_names) - end = normalize(path[i], var_names) + start = normalized(path[i-1], var_names) + end = normalized(path[i], var_names) # Disallow edits that insert a whole rule (a → … :- …). # TODO improve edit_graph to handle this. @@ -213,8 +202,8 @@ def get_edits_from_traces(traces): edits.update(trace_edits) # Update node counts. - n_leaf.update(set([normalize(n.data[2]) for n in nodes if n.data[2] and not n.eout])) - n_all.update(set([normalize(n.data[2]) for n in nodes if n.data[2]])) + n_leaf.update(set([normalized(n.data[2]) for n in nodes if n.data[2] and not n.eout])) + n_all.update(set([normalized(n.data[2]) for n in nodes if n.data[2]])) # Discard edits that only occur in one trace. singletons = [edit for edit in edits if edits[edit] < 2] @@ -227,7 +216,7 @@ def get_edits_from_traces(traces): if a: p *= 1 - (n_leaf[a] / (n_all[a]+1)) if b: - b_normal = normalize(b) + b_normal = normalized(b) p *= n_leaf[b_normal] / (n_all[b_normal]+1) if a and b: p = math.sqrt(p) |