summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--monkey/patterns.py61
1 files changed, 31 insertions, 30 deletions
diff --git a/monkey/patterns.py b/monkey/patterns.py
index 90709b9..e22c39b 100644
--- a/monkey/patterns.py
+++ b/monkey/patterns.py
@@ -15,7 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import collections
-from itertools import combinations
+from itertools import chain, combinations, product
import pickle
import random
import sys
@@ -110,45 +110,46 @@ def get_patterns(tree):
# get patterns separately for each clause
for clause in tree:
+ # collect variable nodes in this clause
variables = collections.defaultdict(list)
- def walk(node):
- if isinstance(node, Tree):
- if node.label() == 'variable':
- name = node[0].val
- variables[name].append(node)
- else:
- for child in node:
- walk(child)
- walk(clause)
+ for node in clause.subtrees():
+ if isinstance(node, Tree) and node.label() == 'variable':
+ name = node[0].val
+ variables[name].append(node)
- # connecting pairs of nodes with same variable
+ # yield patterns for singleton variables
for var, nodes in variables.items():
- for selected in combinations(nodes, 2):
- pat = pattern(clause, selected)
+ if len(nodes) == 1:
+ pat = pattern(clause, nodes)
if pat:
- yield pat, selected
+ yield pat, nodes
- # add singletons
+ # yield patterns for variable-variable pairs (within a clause)
for var, nodes in variables.items():
- if len(nodes) == 1:
- pat = pattern(clause, nodes)
+ for selected in combinations(nodes, 2):
+ pat = pattern(clause, selected)
if pat:
- yield pat, nodes
+ yield pat, selected
- # add patterns for literal/variable pairs
- literals = [node for node in clause.subtrees() if node.label() == 'literal']
- for literal in literals:
- top = literal
- while top != clause and top.parent().label() in {'compound', 'binop', 'unop', 'args', 'list'}:
- top = top.parent()
- variables = [node for node in top.subtrees() if node.label() == 'variable']
- for var in variables:
- pat = pattern(clause, [var, literal])
- if pat:
- yield pat, [var, literal]
+ # yield patterns for variable-literal / literal-literal pairs
+ # (only within a topmost compound / binop / unop)
+ def patterns_with_literals(node):
+ if not isinstance(node, Tree):
+ return
+ if node.label() in {'compound', 'binop', 'unop'}:
+ vars = [n for n in node.subtrees() if n.label() == 'variable']
+ lits = [n for n in node.subtrees() if n.label() == 'literal']
+ for selected in chain(combinations(lits, 2), product(lits, vars)):
+ pat = pattern(clause, selected)
+ if pat:
+ yield pat, selected
+ else:
+ for child in node:
+ yield from patterns_with_literals(child)
+ yield from patterns_with_literals(clause)
# Extract edits and other data from existing traces for each problem.
-# Run with: python3 -m monkey.patterns <problem ID> <solutions.pickle>
+# Run with: python3 -m monkey.patterns <problem ID> <predicate name> <solutions.pickle>
if __name__ == '__main__':
pid = int(sys.argv[1])
name = sys.argv[2]