From b85a6499d21beec2cb87830e63e6afef9569df1a Mon Sep 17 00:00:00 2001
From: Timotej Lazar
Date: Sun, 8 Feb 2015 05:11:30 +0100
Subject: Simplify get_edits_from_traces

---
 monkey/edits.py  | 76 +++++++++++++++++++++++++++++++++-----------------------
 monkey/monkey.py | 15 ++++++-----
 monkey/util.py   |  9 +++++++
 3 files changed, 61 insertions(+), 39 deletions(-)

diff --git a/monkey/edits.py b/monkey/edits.py
index a920bf5..480af5f 100644
--- a/monkey/edits.py
+++ b/monkey/edits.py
@@ -1,11 +1,12 @@
 #!/usr/bin/python3
 
 import collections
+import math
 
 from .action import expand, parse
 from .graph import Node
 from prolog.util import rename_vars, stringify, tokenize
-from .util import get_line
+from .util import get_line, avg, logistic
 
 # A line edit is a contiguous sequences of actions within a single line. This
 # function takes a sequence of actions and builds a directed acyclic graph
@@ -146,25 +147,31 @@ def get_paths(root, path=None, done=None):
 # edits. Return a dictionary of edits and their frequencies, and also
 # submissions and queries in [traces].
 def get_edits_from_traces(traces):
-    # Helper function to remove trailing punctuation from lines.
-    def remove_punct(line):
+    # Helper function to remove trailing punctuation from lines and rename
+    # variables to A1,A2,A3,… (potentially using [var_names]). Return a tuple.
+    def normalize(line, var_names=None):
+        # Remove trailing punctuation.
         i = len(line)
         while i > 0:
             if line[i-1].type not in ('COMMA', 'PERIOD', 'SEMI'):
                 break
             i -= 1
-        return line[:i]
+        return tuple(rename_vars(line[:i], var_names))
 
     # Return values: counts for observed edits, lines, submissions and queries.
     edits = collections.Counter()
-    lines = collections.Counter()
     submissions = collections.Counter()
     queries = collections.Counter()
 
+    # Counts of traces where each line appears as a leaf / any node.
+    n_leaf = collections.Counter()
+    n_all = collections.Counter()
+
     for trace in traces:
         try:
            actions = parse(trace)
         except:
+            # Only a few traces fail to parse, so just ignore them.
            continue
 
         nodes, trace_submissions, trace_queries = edit_graph(actions)
@@ -178,41 +185,48 @@ def get_edits_from_traces(traces):
             queries[code] += 1
 
         # Get edits.
-        seen_edits = set()
+        trace_edits = set()
         for path in get_paths(nodes[0]):
-            for i in range(len(path)):
+            for i in range(1, len(path)):
+                # Normalize path[i-1] → path[i] into start → end. Reuse
+                # variable names from start when normalizing end.
                 var_names = {}
-                start = tuple(rename_vars(remove_punct(path[i]), var_names))
-                for j in range(len(path[i+1:])):
-                    var_names_copy = {k: v for k, v in var_names.items()}
-                    end = tuple(rename_vars(remove_punct(path[i+1+j]), var_names_copy))
-                    if start == end:
-                        continue
+                start = normalize(path[i-1], var_names)
+                end = normalize(path[i], var_names)
+                # This should always succeed but check anyway.
+                if start != end:
                     edit = (start, end)
-                    if edit not in seen_edits:
-                        seen_edits.add(edit)
-                        edits[edit] += 1
-                        lines[start] += 1
+                    trace_edits.add(edit)
+        edits.update(trace_edits)
+
+        # Update node counts.
+        n_leaf.update(set([normalize(n.data[2]) for n in nodes if n.data[2] and not n.eout]))
+        n_all.update(set([normalize(n.data[2]) for n in nodes if n.data[2]]))
 
-    # Discard rarely occurring edits. XXX only for testing
+    # Discard edits that only occur in one trace.
     singletons = [edit for edit in edits if edits[edit] < 2]
     for edit in singletons:
-        lines[edit[0]] -= edits[edit]
         del edits[edit]
 
-    # Get the probability of each edit given its "before" or "after" part.
-    max_insert_count = max([count for (before, after), count in edits.items() if not before])
-    for before, after in edits:
-        if before:
-            edits[(before, after)] /= max(lines[before], 1)
-        else:
-            edits[(before, after)] /= max_insert_count
-
-    # Normalize line frequencies.
-    if len(lines) > 0:
-        lines_max = max(max(lines.values()), 1)
-        lines = {line: count/lines_max for line, count in lines.items()}
+    # Find the probability of each edit a → b.
+    for (a, b), count in edits.items():
+        p = 1.0
+        if a:
+            p *= 1 - (n_leaf[a] / (n_all[a]+1))
+        if b:
+            b_normal = normalize(b)
+            p *= n_leaf[b_normal] / (n_all[b_normal]+1)
+        if a and b:
+            p = math.sqrt(p)
+        edits[(a, b)] = p
+
+    # Tweak the edit distribution to improve search.
+    avg_p = avg(edits.values())
+    for edit, p in edits.items():
+        edits[edit] = logistic(p, k=3, x_0=avg_p)
+
+    lines = dict(n_leaf)
 
     return edits, lines, submissions, queries
 
diff --git a/monkey/monkey.py b/monkey/monkey.py
index e185630..e90d70a 100755
--- a/monkey/monkey.py
+++ b/monkey/monkey.py
@@ -59,7 +59,7 @@ def fix(name, code, edits, program_lines, aux_code='', timeout=30, debug=False):
                 if new_end > new_start:
                     new_rules.append((new_start, new_end))
                 new_step = ('remove_line', line_idx, (tuple(line), ()))
-                new_cost = removes[line_normal] if line_normal in removes.keys() else 0.9
+                new_cost = removes.get(line_normal, 1.0) * 0.99
 
                 yield (new_lines, new_rules, new_step, new_cost)
 
@@ -74,8 +74,9 @@ def fix(name, code, edits, program_lines, aux_code='', timeout=30, debug=False):
                 new_rules.append((old_start + (0 if old_start < end else 1),
                                   old_end + (0 if old_end < end else 1)))
             new_step = ('add_subgoal', end, ((), after_real))
+            new_cost = cost * 0.7
 
-            yield (new_lines, new_rules, new_step, cost)
+            yield (new_lines, new_rules, new_step, new_cost)
 
         # Add a new fact at the end.
         if len(rules) < 2:
@@ -83,8 +84,9 @@ def fix(name, code, edits, program_lines, aux_code='', timeout=30, debug=False):
             new_lines = lines + (after,)
             new_rules = rules + (((len(lines), len(lines)+1)),)
             new_step = ('add_rule', len(new_lines)-1, (tuple(), tuple(after)))
+            new_cost = cost * 0.7
 
-            yield (new_lines, new_rules, new_step, cost)
+            yield (new_lines, new_rules, new_step, new_cost)
 
     # Return a cleaned-up list of steps:
     # - multiple line edits in a single line are merged into one
@@ -103,9 +105,6 @@ def fix(name, code, edits, program_lines, aux_code='', timeout=30, debug=False):
             i += 1
         return new_steps
 
-    # Prevents useless edits with cost=1.0.
-    step_cost = 0.99
-
     # Main loop: best-first search through generated programs.
     todo = PQueue()  # priority queue of candidate solutions
     done = set()  # programs we have already visited
@@ -146,8 +145,8 @@ def fix(name, code, edits, program_lines, aux_code='', timeout=30, debug=False):
         # Otherwise generate new solutions.
         prev_step = path[-1] if path else None
         for new_lines, new_rules, new_step, new_cost in step(lines, rules, prev_step):
-            new_path_cost = path_cost * new_cost * step_cost
-            if new_path_cost < 0.005:
+            new_path_cost = path_cost * new_cost
+            if new_path_cost < 0.4:
                 continue
             new_path = path + (new_step,)
             todo.push(((tuple(new_lines), tuple(new_rules)), new_path, new_path_cost), -new_path_cost)
diff --git a/monkey/util.py b/monkey/util.py
index 6d57d29..9648926 100644
--- a/monkey/util.py
+++ b/monkey/util.py
@@ -2,6 +2,7 @@
 
 from heapq import heappush, heappop
 import itertools
+import math
 
 # A simple priority queue based on the heapq class.
 class PQueue(object):
@@ -52,3 +53,11 @@ def get_line(text, n):
 # Indent each line in [text] by [indent] spaces.
 def indent(text, indent=2):
     return '\n'.join([' '*indent+line for line in text.split('\n')])
+
+# Average in a list.
+def avg(l):
+    return sum(l)/len(l)
+
+# Logistic function.
+def logistic(x, L=1, k=1, x_0=0):
+    return L/(1+math.exp(-k*(x-x_0)))
-- 
cgit v1.2.1
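
A minimal standalone sketch (not part of the patch) of how the new edit weighting behaves. It redefines the avg/logistic helpers that the patch adds to monkey/util.py and applies the same scoring used in get_edits_from_traces; the n_leaf/n_all counts and the list passed to avg are made-up stand-ins for the per-line statistics collected from real traces.

import math

# Helpers as added to monkey/util.py by this patch.
def avg(l):
    return sum(l)/len(l)

def logistic(x, L=1, k=1, x_0=0):
    return L/(1+math.exp(-k*(x-x_0)))

# Made-up counts: how often a normalized line appears as a leaf (final
# version) vs. anywhere in the edit graph, across all traces.
n_leaf = {'a': 2, 'b': 30}
n_all = {'a': 40, 'b': 35}

# Score the edit a -> b the way get_edits_from_traces now does: a line that
# rarely survives as a leaf is a plausible "before" part, a line that often
# survives as a leaf is a plausible "after" part.
p = (1 - n_leaf['a']/(n_all['a']+1)) * (n_leaf['b']/(n_all['b']+1))
p = math.sqrt(p)    # geometric mean, since both parts are present

# Squash the score around the mean probability of all edits (a stand-in
# value here) to sharpen the distribution for the best-first search.
avg_p = avg([p, 0.2, 0.5])
print(round(logistic(p, k=3, x_0=avg_p), 3))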