summaryrefslogtreecommitdiff
path: root/monkey/edits.py
diff options
context:
space:
mode:
Diffstat (limited to 'monkey/edits.py')
-rw-r--r--monkey/edits.py75
1 files changed, 29 insertions, 46 deletions
diff --git a/monkey/edits.py b/monkey/edits.py
index f34f4b4..c625946 100644
--- a/monkey/edits.py
+++ b/monkey/edits.py
@@ -132,30 +132,33 @@ def trace_graph(trace, debug=False):
return nodes, submissions, queries
-# Generate all interesting paths in the edit graph rooted at [root].
-def get_paths(root, path=None, done=None):
- if done is None:
- done = set()
-
- # Add [root] to [path] if it is the first node or different than previous.
- if not path:
- path = (root.data[2],)
- elif root.data[2] != path[-1]:
- path = path + (root.data[2],)
-
- # Return the current path if [root] is a leaf or an empty node.
- if len(path) > 1:
- if not root.eout or not root.data[2]:
- yield path
-
- # If [root] is an empty node, start a new path.
- if not root.data[2]:
- path = (root.data[2],)
- done.add(root)
-
- for node in root.eout:
- if node not in done:
- yield from get_paths(node, path, done)
+# Return a set of edits that appear in the trace_graph given by [nodes].
+def graph_edits(nodes):
+ edits = set()
+ for a in nodes:
+ a_data = a.data[2]
+ for b in a.eout:
+ b_data = b.data[2]
+
+ # Normalize a → b into start → end. Reuse variable names from a
+ # when normalizing b.
+ var_names = {}
+ start = normalized(a_data, var_names)
+ end = normalized(b_data, var_names)
+
+ if start == end:
+ continue
+
+ if not end and len(a.eout) > 1:
+ continue
+
+ # Disallow edits that insert a whole rule (a → …:-…).
+ # TODO improve trace_graph to handle this.
+ if 'FROM' in [t.type for t in end[:-1]]:
+ continue
+
+ edits.add((start, end))
+ return edits
# Build an edit graph for each trace and find "meaningful" (to be defined)
# edits. Return a dictionary of edits and their frequencies, and also
@@ -181,28 +184,8 @@ def get_edits_from_traces(traces):
code = stringify(rename_vars(tokenize(query)))
queries[code] += 1
- # Get edits.
- trace_edits = set()
- for path in get_paths(nodes[0]):
- for i in range(1, len(path)):
- # Normalize path[i-1] → path[i] into start → end. Reuse
- # variable names from start when normalizing end.
- var_names = {}
- start = normalized(path[i-1], var_names)
- end = normalized(path[i], var_names)
-
- # Disallow edits that insert a whole rule (a → … :- …).
- # TODO improve trace_graph to handle this.
- if 'FROM' in [t.type for t in end[:-1]]:
- continue
-
- # This should always succeed but check anyway.
- if start != end:
- edit = (start, end)
- trace_edits.add(edit)
- edits.update(trace_edits)
-
- # Update node counts.
+ # Update the edit and leaf/node counters.
+ edits.update(graph_edits(nodes))
n_leaf.update(set([normalized(n.data[2]) for n in nodes if n.data[2] and not n.eout]))
n_all.update(set([normalized(n.data[2]) for n in nodes if n.data[2]]))