summaryrefslogtreecommitdiff
path: root/prolog
diff options
context:
space:
mode:
Diffstat (limited to 'prolog')
-rw-r--r--prolog/util.py190
1 files changed, 98 insertions, 92 deletions
diff --git a/prolog/util.py b/prolog/util.py
index 30f12da..e9c1811 100644
--- a/prolog/util.py
+++ b/prolog/util.py
@@ -30,14 +30,14 @@ class Token(namedtuple('Token', ['type', 'val', 'pos', 'rule', 'part', 'stop']))
def __hash__(self):
return hash(self[1])
-# Return a new Token, possibly modifying some fields.
-def clone_token(token, val=None, pos=None, rule=None, part=None):
- return Token(token.type,
- token.val if val is None else val,
- token.pos if pos is None else pos,
- token.rule if rule is None else rule,
- token.part if part is None else part,
- token.stop)
+ # Return a copy of this token, possibly modifying some fields.
+ def clone(self, type=None, val=None, pos=None, rule=None, part=None, stop=None):
+ return Token(self.type if type is None else type,
+ self.val if val is None else val,
+ self.pos if pos is None else pos,
+ self.rule if rule is None else rule,
+ self.part if part is None else part,
+ self.stop if stop is None else stop)
# Return a list of tokens in [text].
def tokenize(text):
@@ -63,83 +63,92 @@ def split(code):
yield stringify(tokens[start:idx])
start = idx + 1
-# Return a list of lines in [code] and a list of rule ranges.
-def decompose(code):
- lines = []
- rules = []
-
- rule_start = 0 # lowest line number in the current rule
- line = [] # tokens in the current line
- break_line = True # for each comma, go to a new line
+# Lex [code] into tokens with rule indexes and stop markers.
+def annotate(code):
+ rule = 0
+ part = 0
parens = [] # stack of currently open parens/brackets/braces
-
- for t in tokenize(code) + [Token('EOF')]:
- # Always break the line on a semicolon, even inside parens.
- if t.type == 'SEMI':
- if line:
- lines.append(tuple(line))
- line = []
- lines.append((t,))
- continue
-
- # Break the line on these tokens if we are not inside parens. Don't
- # append the final token unless it is the :- operator.
- if break_line and t.type in ('PERIOD', 'FROM', 'COMMA', 'EOF'):
- # Only append :- at the end of the line, ignore commas and periods.
- if t.type == 'FROM':
- line.append(t)
-
- # Append nonempty lines to the output list.
- if line:
- lines.append(tuple(line))
- line = []
-
- # Commit a new rule if it contains some lines.
- if t.type in ('PERIOD', 'EOF') and rule_start < len(lines):
- rules.append((rule_start, len(lines)))
- rule_start = len(lines)
- continue
-
- # Handle parens.
- if t.type == 'LPAREN':
- # Disallow breaking lines inside "name( )" (e.g. member(X, L)) but
- # not other ( ).
- if line and line[-1].type == 'NAME':
- parens.append('paren')
- break_line = False
+ in_parens = 0 # COMMA means a new part if this is 0
+
+ token = None
+ lexer.input(code)
+ for t in lexer:
+ tok_rule = rule
+ tok_part = part
+ tok_stop = True
+
+ if t.type == 'PERIOD': # .
+ rule += 1
+ part = 0
+ in_parens = 0
+ parens = []
+ elif t.type in ('FROM', 'SEMI'): # :- ;
+ part += 1
+ elif t.type == 'COMMA': # ,
+ if not parens or in_parens == 0:
+ part += 1
else:
- parens.append('ignore')
- elif t.type in ('LBRACKET', 'LBRACE'):
- # Disallow breaking lines inside "[ ]" and "{ }".
- parens.append('paren')
- break_line = False
- elif parens:
- if t.type in ('RPAREN', 'RBRACE', 'RBRACKET'):
+ tok_stop = False
+
+ # Handle left parens.
+ elif t.type == 'LPAREN': # (
+ if token and token.type == 'NAME': # name(
+ tok_stop = False
+ parens.append('COMPOUND')
+ in_parens += 1
+ else:
+ parens.append(t.type) # …, (
+ elif t.type == 'LBRACKET': # [
+ tok_stop = False
+ parens.append(t.type)
+ in_parens += 1
+ elif t.type == 'LBRACE': # {
+ parens.append(t.type)
+
+ # Handle right parens.
+ elif t.type == 'RPAREN': # )
+ if parens:
+ if parens[-1] == 'COMPOUND': # name(…)
+ tok_stop = False
+ parens.pop()
+ in_parens -= 1
+ elif parens[-1] == 'LPAREN': # (…)
+ parens.pop()
+ elif t.type == 'RBRACKET': # ]
+ if parens and parens[-1] == 'LBRACKET': # […]
+ tok_stop = False
+ parens.pop()
+ in_parens -= 1
+ elif t.type == 'RBRACE': # }
+ if parens and parens[-1] == 'LBRACE': # {…}
parens.pop()
- break_line = 'paren' not in parens
- # Append this token to the current line.
- line.append(t)
+ # Normal tokens.
+ else:
+ tok_stop = False
- return lines, rules
+ token = Token(t.type, t.value, t.lexpos, tok_rule, tok_part, tok_stop)
+ yield token
-# Format a list of [lines] according to [rules] (as returned by decompose).
-def compose(lines, rules):
+# Format a list of annotated [tokens].
+def compose(tokens):
code = ''
- for start, end in rules:
- for i in range(start, end):
- line = lines[i]
- if i > start:
+ prev = None
+ for t in tokens:
+ if t.type == 'SEMI':
+ code += '\n '
+ if prev and (prev.part != t.part or prev.rule != t.rule):
+ code += '\n'
+ if t.part:
code += ' '
- code += stringify(line)
- if i == end-1:
- code += '.\n'
- elif i == start:
- code += '\n'
- else:
- if line and line[-1].type != 'SEMI' and lines[i+1][-1].type != 'SEMI':
- code += ','
- code += '\n'
+
+ if t.type in ('PERIOD', 'COMMA'):
+ code += t.val + ' '
+ elif t.type in operators.values():
+ code += ' ' + t.val + ' '
+ else:
+ code += t.val
+ prev = t
return code.strip()
# Rename variables in [tokens] to A0, A1, A2,… in order of appearance.
@@ -150,19 +159,20 @@ def rename_vars(tokens, names=None):
# Return a new list.
tokens = list(tokens)
- for i in range(len(tokens)):
- if tokens[i].type == 'PERIOD':
+ for i, t in enumerate(tokens):
+ if t.type == 'PERIOD':
names.clear()
next_id = 0
- elif tokens[i] == Token('VARIABLE', '_'):
- tokens[i] = Token('VARIABLE', 'A{}'.format(next_id))
- next_id += 1
- elif tokens[i].type == 'VARIABLE':
- cur_name = tokens[i].val
- if cur_name not in names:
- names[cur_name] = 'A{}'.format(next_id)
+ elif t.type == 'VARIABLE':
+ if t.val.startswith('_'):
+ tokens[i] = t.clone(val='A{}'.format(next_id))
next_id += 1
- tokens[i] = Token('VARIABLE', names[cur_name])
+ else:
+ cur_name = t.val
+ if cur_name not in names:
+ names[cur_name] = 'A{}'.format(next_id)
+ next_id += 1
+ tokens[i] = t.clone(val=names[cur_name])
return tokens
# Helper function to remove trailing punctuation from lines and rename
@@ -203,14 +213,10 @@ def map_vars(a, b, tokens, variables):
for i, formal_name in enumerate(remaining_formal):
mapping[formal_name] = remaining_actual[i]
- return [t if t.type != 'VARIABLE' else clone_token(t, val=mapping[t.val]) for t in b]
+ return [t if t.type != 'VARIABLE' else t.clone(val=mapping[t.val]) for t in b]
# Basic sanity check.
if __name__ == '__main__':
- code = 'dup([H|T], [H1|T1]) :- dup(T1, T2). '
- lines, rules = decompose(code)
- print(compose(lines, rules))
-
var_names = {}
before = rename_vars(tokenize("dup([A0|A1], [A2|A3])"), var_names)
after = rename_vars(tokenize("dup([A0|A1], [A5, A4|A3])"), var_names)