author     Timotej Lazar <timotej.lazar@araneo.org>   2015-02-08 05:11:30 +0100
committer  Aleš Smodiš <aless@guru.si>                2015-08-11 14:26:02 +0200
commit     b85a6499d21beec2cb87830e63e6afef9569df1a (patch)
tree       4486f48751f8144e25adb162bd366dcdf3839830 /monkey
parent     25dfa090db39040a93e8186d228f70f0e0bd5f23 (diff)
Simplify get_edits_from_traces
Diffstat (limited to 'monkey')
-rw-r--r--  monkey/edits.py   | 76
-rwxr-xr-x  monkey/monkey.py  | 15
-rw-r--r--  monkey/util.py    |  9
3 files changed, 61 insertions(+), 39 deletions(-)
diff --git a/monkey/edits.py b/monkey/edits.py
index a920bf5..480af5f 100644
--- a/monkey/edits.py
+++ b/monkey/edits.py
@@ -1,11 +1,12 @@
#!/usr/bin/python3
import collections
+import math
from .action import expand, parse
from .graph import Node
from prolog.util import rename_vars, stringify, tokenize
-from .util import get_line
+from .util import get_line, avg, logistic
# A line edit is a contiguous sequence of actions within a single line. This
# function takes a sequence of actions and builds a directed acyclic graph
@@ -146,25 +147,31 @@ def get_paths(root, path=None, done=None):
# edits. Return a dictionary of edits and their frequencies, and also
# submissions and queries in [traces].
def get_edits_from_traces(traces):
- # Helper function to remove trailing punctuation from lines.
- def remove_punct(line):
+ # Helper function to remove trailing punctuation from lines and rename
+ # variables to A1,A2,A3,… (potentially using [var_names]). Return a tuple.
+ def normalize(line, var_names=None):
+ # Remove trailing punctuation.
i = len(line)
while i > 0:
if line[i-1].type not in ('COMMA', 'PERIOD', 'SEMI'):
break
i -= 1
- return line[:i]
+ return tuple(rename_vars(line[:i], var_names))
# Return values: counts for observed edits, lines, submissions and queries.
edits = collections.Counter()
- lines = collections.Counter()
submissions = collections.Counter()
queries = collections.Counter()
+ # Counts of traces where each line appears as a leaf / any node.
+ n_leaf = collections.Counter()
+ n_all = collections.Counter()
+
for trace in traces:
try:
actions = parse(trace)
except:
+ # Only a few traces fail to parse, so just ignore them.
continue
nodes, trace_submissions, trace_queries = edit_graph(actions)
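
To illustrate the new normalize() helper, a minimal standalone sketch follows. The Token stand-in and rename_vars_sketch are assumptions: the real tokenize and rename_vars come from prolog.util, and the A1, A2, … scheme follows the comment above.

    from collections import namedtuple

    Token = namedtuple('Token', ['type', 'val'])

    # Stand-in for prolog.util.rename_vars: rename variables to A1, A2, ...
    # in order of first appearance, reusing and updating [var_names].
    def rename_vars_sketch(tokens, var_names):
        out = []
        for t in tokens:
            if t.type == 'VARIABLE':
                if t.val not in var_names:
                    var_names[t.val] = 'A%d' % (len(var_names) + 1)
                t = Token(t.type, var_names[t.val])
            out.append(t)
        return out

    # normalize() strips trailing COMMA/PERIOD/SEMI tokens before renaming,
    # so "X = Y," and "Foo = Bar." normalize to the same tuple.
    line = [Token('VARIABLE', 'X'), Token('EQU', '='),
            Token('VARIABLE', 'Y'), Token('COMMA', ',')]
    print(rename_vars_sketch(line[:-1], {}))  # X -> A1, Y -> A2
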
@@ -178,41 +185,48 @@ def get_edits_from_traces(traces):
queries[code] += 1
# Get edits.
- seen_edits = set()
+ trace_edits = set()
for path in get_paths(nodes[0]):
- for i in range(len(path)):
+ for i in range(1, len(path)):
+ # Normalize path[i-1] → path[i] into start → end. Reuse
+ # variable names from start when normalizing end.
var_names = {}
- start = tuple(rename_vars(remove_punct(path[i]), var_names))
- for j in range(len(path[i+1:])):
- var_names_copy = {k: v for k, v in var_names.items()}
- end = tuple(rename_vars(remove_punct(path[i+1+j]), var_names_copy))
- if start == end:
- continue
+ start = normalize(path[i-1], var_names)
+ end = normalize(path[i], var_names)
+ # start and end should always differ here, but check anyway.
+ if start != end:
edit = (start, end)
- if edit not in seen_edits:
- seen_edits.add(edit)
- edits[edit] += 1
- lines[start] += 1
+ trace_edits.add(edit)
+ edits.update(trace_edits)
+
+ # Update node counts.
+ n_leaf.update(set([normalize(n.data[2]) for n in nodes if n.data[2] and not n.eout]))
+ n_all.update(set([normalize(n.data[2]) for n in nodes if n.data[2]]))
- # Discard rarely occurring edits. XXX only for testing
+ # Discard edits that only occur in one trace.
singletons = [edit for edit in edits if edits[edit] < 2]
for edit in singletons:
- lines[edit[0]] -= edits[edit]
del edits[edit]
- # Get the probability of each edit given its "before" or "after" part.
- max_insert_count = max([count for (before, after), count in edits.items() if not before])
- for before, after in edits:
- if before:
- edits[(before, after)] /= max(lines[before], 1)
- else:
- edits[(before, after)] /= max_insert_count
-
- # Normalize line frequencies.
- if len(lines) > 0:
- lines_max = max(max(lines.values()), 1)
- lines = {line: count/lines_max for line, count in lines.items()}
+ # Find the probability of each edit a → b.
+ for (a, b), count in edits.items():
+ p = 1.0
+ if a:
+ p *= 1 - (n_leaf[a] / (n_all[a]+1))
+ if b:
+ b_normal = normalize(b)
+ p *= n_leaf[b_normal] / (n_all[b_normal]+1)
+ if a and b:
+ p = math.sqrt(p)
+ edits[(a, b)] = p
+
+ # Tweak the edit distribution to improve search.
+ avg_p = avg(edits.values())
+ for edit, p in edits.items():
+ edits[edit] = logistic(p, k=3, x_0=avg_p)
+
+ lines = dict(n_leaf)
return edits, lines, submissions, queries
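
The loop above scores an edit a → b by how its endpoints behave across traces: a good source line rarely survives as a leaf, a good target line often does. A worked example with made-up counts (the names mirror the patch; the sketch skips the extra normalize(b) lookup):

    import math
    from collections import Counter

    n_leaf = Counter({'a': 2, 'b': 8})    # traces where the line is a leaf
    n_all = Counter({'a': 10, 'b': 10})   # traces where the line appears
    edits = {('a', 'b'): 5}               # edit a -> b, seen in 5 traces

    for (a, b), count in edits.items():
        p = 1.0
        if a:   # source lines that rarely remain as leaves score higher
            p *= 1 - (n_leaf[a] / (n_all[a] + 1))
        if b:   # target lines that often remain as leaves score higher
            p *= n_leaf[b] / (n_all[b] + 1)
        if a and b:
            p = math.sqrt(p)   # geometric mean of the two factors
        edits[(a, b)] = p

    print(edits[('a', 'b')])  # (1 - 2/11) * (8/11) ~ 0.595; sqrt ~ 0.77
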
diff --git a/monkey/monkey.py b/monkey/monkey.py
index e185630..e90d70a 100755
--- a/monkey/monkey.py
+++ b/monkey/monkey.py
@@ -59,7 +59,7 @@ def fix(name, code, edits, program_lines, aux_code='', timeout=30, debug=False):
if new_end > new_start:
new_rules.append((new_start, new_end))
new_step = ('remove_line', line_idx, (tuple(line), ()))
- new_cost = removes[line_normal] if line_normal in removes.keys() else 0.9
+ new_cost = removes.get(line_normal, 1.0) * 0.99
yield (new_lines, new_rules, new_step, new_cost)
@@ -74,8 +74,9 @@ def fix(name, code, edits, program_lines, aux_code='', timeout=30, debug=False):
new_rules.append((old_start + (0 if old_start < end else 1),
old_end + (0 if old_end < end else 1)))
new_step = ('add_subgoal', end, ((), after_real))
+ new_cost = cost * 0.7
- yield (new_lines, new_rules, new_step, cost)
+ yield (new_lines, new_rules, new_step, new_cost)
# Add a new fact at the end.
if len(rules) < 2:
@@ -83,8 +84,9 @@ def fix(name, code, edits, program_lines, aux_code='', timeout=30, debug=False):
new_lines = lines + (after,)
new_rules = rules + (((len(lines), len(lines)+1)),)
new_step = ('add_rule', len(new_lines)-1, (tuple(), tuple(after)))
+ new_cost = cost * 0.7
- yield (new_lines, new_rules, new_step, cost)
+ yield (new_lines, new_rules, new_step, new_cost)
# Return a cleaned-up list of steps:
# - multiple line edits in a single line are merged into one
@@ -103,9 +105,6 @@ def fix(name, code, edits, program_lines, aux_code='', timeout=30, debug=False):
i += 1
return new_steps
- # Prevents useless edits with cost=1.0.
- step_cost = 0.99
-
# Main loop: best-first search through generated programs.
todo = PQueue() # priority queue of candidate solutions
done = set() # programs we have already visited
@@ -146,8 +145,8 @@ def fix(name, code, edits, program_lines, aux_code='', timeout=30, debug=False):
# Otherwise generate new solutions.
prev_step = path[-1] if path else None
for new_lines, new_rules, new_step, new_cost in step(lines, rules, prev_step):
- new_path_cost = path_cost * new_cost * step_cost
- if new_path_cost < 0.005:
+ new_path_cost = path_cost * new_cost
+ if new_path_cost < 0.4:
continue
new_path = path + (new_step,)
todo.push(((tuple(new_lines), tuple(new_rules)), new_path, new_path_cost), -new_path_cost)
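
With the flat step_cost constant gone, pruning now depends only on the product of the per-step costs assigned above. A toy calculation of how many steps fit under the new 0.4 cutoff, assuming the bare 0.7 and 0.99 factors (real remove_line costs are further scaled by the removes table):

    # Path cost is the product of step costs; paths below 0.4 are pruned.
    def max_steps(step_cost, cutoff=0.4):
        cost, steps = 1.0, 0
        while cost * step_cost >= cutoff:
            cost *= step_cost
            steps += 1
        return steps

    print(max_steps(0.7))    # 2: at most two add_subgoal/add_rule steps
    print(max_steps(0.99))   # 91: remove_line steps stay cheap
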
diff --git a/monkey/util.py b/monkey/util.py
index 6d57d29..9648926 100644
--- a/monkey/util.py
+++ b/monkey/util.py
@@ -2,6 +2,7 @@
from heapq import heappush, heappop
import itertools
+import math
# A simple priority queue based on the heapq class.
class PQueue(object):
@@ -52,3 +53,11 @@ def get_line(text, n):
# Indent each line in [text] by [indent] spaces.
def indent(text, indent=2):
return '\n'.join([' '*indent+line for line in text.split('\n')])
+
+# Average of the values in [l].
+def avg(l):
+ return sum(l)/len(l)
+
+# Logistic function with maximum [L], growth rate [k] and midpoint [x_0].
+def logistic(x, L=1, k=1, x_0=0):
+ return L/(1+math.exp(-k*(x-x_0)))
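
A quick standalone check of the two new helpers, repeating their definitions so it runs on its own; the probabilities are made up, but k=3 and the mean as midpoint match the tweak in get_edits_from_traces:

    import math

    def avg(l):
        return sum(l)/len(l)

    def logistic(x, L=1, k=1, x_0=0):
        return L/(1+math.exp(-k*(x-x_0)))

    ps = [0.2, 0.5, 0.8]
    x_0 = avg(ps)   # 0.5 becomes the midpoint of the curve
    print([round(logistic(p, k=3, x_0=x_0), 3) for p in ps])
    # [0.289, 0.5, 0.711] -- each edit probability is rescaled relative
    # to the mean while staying strictly inside (0, 1)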