summaryrefslogtreecommitdiff
path: root/prolog/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'prolog/util.py')
-rw-r--r--prolog/util.py179
1 files changed, 179 insertions, 0 deletions
diff --git a/prolog/util.py b/prolog/util.py
new file mode 100644
index 0000000..7fb81e3
--- /dev/null
+++ b/prolog/util.py
@@ -0,0 +1,179 @@
+#!/usr/bin/python3
+
+from collections import namedtuple
+
+from .lexer import lexer, operators
+
+# Stores a token's type and value, and optionally the position of the first
+# character in the lexed stream.
+class Token(namedtuple('Token', ['type', 'val', 'pos'])):
+ __slots__ = ()
+
+ # Custom constructor to support default parameters.
+ def __new__(cls, type, val='', pos=None):
+ return super(Token, cls).__new__(cls, type, val, pos)
+
+ def __str__(self):
+ return self.val
+
+ # Ignore position when comparing tokens. There is probably a cleaner way of
+ # doing these.
+ __eq__ = lambda x, y: x[0] == y[0] and x[1] == y[1]
+ __ne__ = lambda x, y: x[0] != y[0] or x[1] != y[1]
+ __lt__ = lambda x, y: tuple.__lt__(x[0:2], y[0:2])
+ __le__ = lambda x, y: tuple.__le__(x[0:2], y[0:2])
+ __ge__ = lambda x, y: tuple.__ge__(x[0:2], y[0:2])
+ __gt__ = lambda x, y: tuple.__gt__(x[0:2], y[0:2])
+
+ # Only hash token's value (we don't care about position, and types are
+ # determined by values).
+ def __hash__(self):
+ return hash(self[1])
+
+# Return a list of tokens in [text].
+def tokenize(text):
+ lexer.input(text)
+ return [Token(t.type, t.value, t.lexpos) for t in lexer]
+
+# Return a one-line string representation of [tokens].
+def stringify(tokens):
+ def token_str(t):
+ if t.type in ('PERIOD', 'COMMA'):
+ return str(t) + ' '
+ if t.type in operators.values():
+ return ' ' + str(t) + ' '
+ return str(t)
+ return ''.join(map(token_str, tokens))
+
+# Yield the sequence of rules in [code].
+def split(code):
+ tokens = tokenize(code)
+ start = 0
+ for idx, token in enumerate(tokens):
+ if token.type == 'PERIOD' and idx - start > 1:
+ yield stringify(tokens[start:idx])
+ start = idx + 1
+
+# Return a list of lines in [code] and a list of rule ranges.
+def decompose(code):
+ lines = []
+ rules = []
+ tokens = tokenize(code)
+ tokens.append(Token('EOF'))
+
+ line = []
+ parens = []
+ rule_start = 0
+ for t in tokens:
+ if t.type == 'SEMI':
+ if line != []:
+ lines.append(tuple(line))
+ line = []
+ lines.append((t,))
+ continue
+ if not parens:
+ if t.type in ('PERIOD', 'FROM', 'COMMA', 'EOF'):
+ if line != []:
+ lines.append(tuple(line))
+ line = []
+ if t.type in ('PERIOD', 'EOF') and rule_start < len(lines):
+ rules.append((rule_start, len(lines)))
+ rule_start = len(lines)
+ continue
+ if t.type in ('LPAREN', 'LBRACKET', 'LBRACE'):
+ parens.append(t.type)
+ elif parens:
+ if t.type == 'RPAREN' and parens[-1] == 'LPAREN':
+ parens.pop()
+ elif t.type == 'RBRACKET' and parens[-1] == 'LBRACKET':
+ parens.pop()
+ elif t.type == 'RBRACE' and parens[-1] == 'LBRACE':
+ parens.pop()
+ line.append(t)
+ return lines, rules
+
+# Format a list of [lines] according to [rules] (as returned by decompose).
+def compose(lines, rules):
+ code = ''
+ for start, end in rules:
+ for i in range(start, end):
+ line = lines[i]
+ if i > start:
+ code += ' '
+ code += stringify(line)
+ if i == end-1:
+ code += '.\n'
+ elif i == start:
+ code += ' :-\n'
+ else:
+ if line and line[-1].type != 'SEMI' and lines[i+1][-1].type != 'SEMI':
+ code += ','
+ code += '\n'
+ return code.strip()
+
+# Rename variables in [tokens] to A0, A1, A2,… in order of appearance.
+def rename_vars(tokens, names=None):
+ if names is None:
+ names = {}
+ next_id = len(names)
+
+ # Return a new list.
+ tokens = list(tokens)
+ for i in range(len(tokens)):
+ if tokens[i].type == 'PERIOD':
+ names.clear()
+ next_id = 0
+ elif tokens[i] == Token('VARIABLE', '_'):
+ tokens[i] = Token('VARIABLE', 'A{}'.format(next_id))
+ next_id += 1
+ elif tokens[i].type == 'VARIABLE':
+ cur_name = tokens[i].val
+ if cur_name not in names:
+ names[cur_name] = 'A{}'.format(next_id)
+ next_id += 1
+ tokens[i] = Token('VARIABLE', names[cur_name])
+ return tokens
+
+# transformation = before → after; applied on line which is part of rule
+# return mapping from formal vars in before+after to actual vars in rule
+# line and rule should of course not be normalized
+def map_vars(before, after, line, rule):
+ mapping = {}
+ new_index = 0
+ for i in range(len(before)):
+ if line[i].type == 'VARIABLE':
+ formal_name = before[i].val
+ if line[i].val != '_':
+ actual_name = line[i].val
+ else:
+ actual_name = 'New'+str(new_index)
+ new_index += 1
+ mapping[formal_name] = actual_name
+
+ remaining_formal = [t.val for t in after if t.type == 'VARIABLE' and t.val not in mapping.keys()]
+ remaining_actual = [t.val for t in rule if t.type == 'VARIABLE' and t.val != '_' and t.val not in mapping.values()]
+
+ while len(remaining_actual) < len(remaining_formal):
+ remaining_actual.append('New'+str(new_index))
+ new_index += 1
+
+ for i, formal_name in enumerate(remaining_formal):
+ mapping[formal_name] = remaining_actual[i]
+
+ return mapping
+
+# Basic sanity check.
+if __name__ == '__main__':
+ code = 'dup([H|T], [H1|T1]) :- dup(T1, T2). '
+ lines, rules = decompose(code)
+ print(compose(lines, rules))
+
+ var_names = {}
+ before = rename_vars(tokenize("dup([A0|A1], [A2|A3])"), var_names)
+ after = rename_vars(tokenize("dup([A0|A1], [A5, A4|A3])"), var_names)
+
+ line = lines[0]
+ rule = tokenize(code)
+
+ mapping = map_vars(before, after, line, rule)
+ print(mapping)