summaryrefslogtreecommitdiff
path: root/monkey/patterns.py
diff options
context:
space:
mode:
Diffstat (limited to 'monkey/patterns.py')
-rw-r--r--monkey/patterns.py149
1 files changed, 149 insertions, 0 deletions
diff --git a/monkey/patterns.py b/monkey/patterns.py
new file mode 100644
index 0000000..7dfa96f
--- /dev/null
+++ b/monkey/patterns.py
@@ -0,0 +1,149 @@
+# CodeQ: an online programming tutor.
+# Copyright (C) 2016,2017 UL FRI
+#
+# This program is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Affero General Public License as published by the Free
+# Software Foundation, either version 3 of the License, or (at your option) any
+# later version.
+#
+# This program is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+import collections
+from itertools import chain, combinations, product
+
+from nltk import Tree
+
+from prolog.util import parse as prolog_parse
+
+# construct pattern to match the structure of nodes given by [include],
+# supports variables and literals
+def pattern(node, include):
+ if not isinstance(node, Tree):
+ return None
+
+ label = node.label()
+ if any(n is node for n in include):
+ if label == 'literal':
+ return '"{}"'.format(node[0].val)
+ if label == 'variable':
+ return '{}'.format(label)
+ return None
+ if label == 'functor':
+ return '({} "{}")'.format(label, node[0].val)
+
+ subpats = [pattern(child, include) for child in node]
+ pat = None
+ if any(subpats):
+ if label == 'and':
+ if subpats[1]:
+ pat = subpats[1]
+ if subpats[0]:
+ if pat:
+ pat = subpats[0] + ' ' + pat
+ else:
+ pat = subpats[0]
+ elif label == 'args':
+ pat = label
+ for i, subpat in enumerate(subpats):
+ if subpat:
+ pat += ' {}'.format(subpat)
+ pat = '(' + pat + ')'
+ elif label == 'unop':
+ pat = '(' + label + ' ' + node[0].val + ' ' + subpats[1] + ')'
+ elif label == 'binop':
+ pat = label
+ if subpats[0]:
+ pat += ' {}'.format(subpats[0])
+ pat += ' "{}"'.format(node[1].val)
+ if subpats[2]:
+ pat += ' {}'.format(subpats[2])
+ pat = '(' + pat + ')'
+ elif label == 'clause':
+ pat = label
+ for i, subpat in enumerate(subpats):
+ if subpat:
+ pat += ' {}'.format(subpats[i])
+ return '(' + pat + ')'
+ elif label == 'compound':
+ if len(subpats) > 1 and subpats[1]:
+ pat = label
+ for i, subpat in enumerate(subpats):
+ pat += ' {}'.format(subpat)
+ pat = '(' + pat + ')'
+ else:
+ return None
+ elif label == 'head':
+ pat = label
+ pat += ' {}'.format(subpats[0])
+ pat = '(' + pat + ')'
+ elif label == 'list':
+ pat = 'list '
+ if subpats[0]:
+ pat += '(h {})'.format(subpats[0])
+ if subpats[0] and subpats[1]:
+ pat += ' '
+ if subpats[1]:
+ pat += '(t {})'.format(subpats[1])
+ pat = '(' + pat + ')'
+ if not pat:
+ for s in subpats:
+ if s:
+ pat = s
+ break
+ return pat
+
+def get_patterns(tree):
+ if isinstance(tree, str):
+ tree = prolog_parse(tree)
+ if tree is None:
+ return
+
+ # get patterns separately for each clause
+ for clause in tree:
+ # collect variable nodes in this clause
+ variables = collections.defaultdict(list)
+ for node in clause.subtrees():
+ if isinstance(node, Tree) and node.label() == 'variable':
+ name = node[0].val
+ variables[name].append(node)
+
+ # yield patterns for singleton variables
+ for var, nodes in variables.items():
+ if len(nodes) == 1:
+ pat = pattern(clause, nodes)
+ if pat:
+ yield pat, nodes
+
+ # yield patterns for variable-variable pairs (within a clause)
+ for var, nodes in variables.items():
+ for selected in combinations(nodes, 2):
+ pat = pattern(clause, selected)
+ if pat:
+ yield pat, selected
+
+ # yield patterns for variable-literal / literal-literal pairs
+ # yield patterns for singleton literals
+ # (only within a topmost compound / binop / unop)
+ def patterns_with_literals(node):
+ if not isinstance(node, Tree):
+ return
+ if node.label() in {'compound', 'binop', 'unop'}:
+ vars = [n for n in node.subtrees() if n.label() == 'variable']
+ lits = [n for n in node.subtrees() if n.label() == 'literal']
+ for selected in chain(
+ combinations(lits, 1),
+ combinations(lits, 2),
+ product(lits, vars)):
+ pat = pattern(clause, selected)
+ if pat:
+ yield pat, selected
+ else:
+ for child in node:
+ yield from patterns_with_literals(child)
+ yield from patterns_with_literals(clause)