Diffstat (limited to 'prolog')
-rw-r--r--  prolog/util.py | 187
1 file changed, 89 insertions(+), 98 deletions(-)
diff --git a/prolog/util.py b/prolog/util.py
index ba61da0..4b316d5 100644
--- a/prolog/util.py
+++ b/prolog/util.py
@@ -20,12 +20,12 @@ from nltk import Tree
 
 # Stores a token's type and value, and optionally the position of the first
 # character in the lexed stream.
-class Token(namedtuple('Token', ['type', 'val', 'pos', 'rule', 'part', 'stop'])):
+class Token(namedtuple('Token', ['type', 'val', 'pos'])):
     __slots__ = ()
 
     # Custom constructor to support default parameters.
-    def __new__(cls, type, val='', pos=None, rule=None, part=None, stop=False):
-        return super(Token, cls).__new__(cls, type, val, pos, rule, part, stop)
+    def __new__(cls, type, val='', pos=None):
+        return super(Token, cls).__new__(cls, type, val, pos)
 
     def __str__(self):
         return self.val
@@ -45,13 +45,10 @@ class Token(namedtuple('Token', ['type', 'val', 'pos', 'rule', 'part', 'stop']))
         return hash(self[1])
 
     # Return a copy of this token, possibly modifying some fields.
-    def clone(self, type=None, val=None, pos=None, rule=None, part=None, stop=None):
+    def clone(self, type=None, val=None, pos=None):
         return Token(self.type if type is None else type,
                      self.val if val is None else val,
-                     self.pos if pos is None else pos,
-                     self.rule if rule is None else rule,
-                     self.part if part is None else part,
-                     self.stop if stop is None else stop)
+                     self.pos if pos is None else pos)
 
 from .lexer import lexer, operators
 from .parser import parser
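
With rule/part/stop gone, Token is a plain (type, val, pos) triple. A minimal sketch of the resulting behavior (hypothetical REPL session; assumes this module is importable as prolog.util):

    >>> from prolog.util import Token
    >>> t = Token('VARIABLE', 'X', pos=7)
    >>> t.clone(val='A0')
    Token(type='VARIABLE', val='A0', pos=7)
    >>> str(t)
    'X'
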
@@ -81,94 +78,6 @@ def stringify(obj):
             return ''.join([stringify(child) for child in obj]) + '\n'
         return ''.join([stringify(child) for child in obj])
 
-# Lex [code] into tokens with rule indexes and stop markers.
-def annotate(code):
-    rule = 0
-    part = 0
-    parens = [] # stack of currently open parens/brackets/braces
-    in_parens = 0 # COMMA means a new part if this is 0
-
-    token = None
-    lexer.input(code)
-    for t in lexer:
-        tok_rule = rule
-        tok_part = part
-        tok_stop = True
-
-        if t.type == 'PERIOD': # .
-            rule += 1
-            part = 0
-            in_parens = 0
-            parens = []
-        elif t.type in ('FROM', 'SEMI'): # :- ;
-            part += 1
-        elif t.type == 'COMMA': # ,
-            if not parens or in_parens == 0:
-                part += 1
-            else:
-                tok_stop = False
-
-        # Handle left parens.
-        elif t.type == 'LPAREN': # (
-            if token and token.type == 'NAME': # name(
-                tok_stop = False
-                parens.append('COMPOUND')
-                in_parens += 1
-            else:
-                parens.append(t.type) # …, (
-        elif t.type == 'LBRACKET': # [
-            tok_stop = False
-            parens.append(t.type)
-            in_parens += 1
-        elif t.type == 'LBRACE': # {
-            parens.append(t.type)
-
-        # Handle right parens.
-        elif t.type == 'RPAREN': # )
-            if parens:
-                if parens[-1] == 'COMPOUND': # name(…)
-                    tok_stop = False
-                    parens.pop()
-                    in_parens -= 1
-                elif parens[-1] == 'LPAREN': # (…)
-                    parens.pop()
-        elif t.type == 'RBRACKET': # ]
-            if parens and parens[-1] == 'LBRACKET': # […]
-                tok_stop = False
-                parens.pop()
-                in_parens -= 1
-        elif t.type == 'RBRACE': # }
-            if parens and parens[-1] == 'LBRACE': # {…}
-                parens.pop()
-
-        # Normal tokens.
-        else:
-            tok_stop = False
-
-        token = Token(t.type, t.value, t.lexpos, tok_rule, tok_part, tok_stop)
-        yield token
-
-# Format a list of annotated [tokens].
-def compose(tokens):
-    code = ''
-    prev = None
-    for t in tokens:
-        if t.type == 'SEMI':
-            code += '\n '
-        if prev and (prev.part != t.part or prev.rule != t.rule):
-            code += '\n'
-            if t.part:
-                code += ' '
-
-        if t.type in ('PERIOD', 'COMMA'):
-            code += t.val + ' '
-        elif t.type in operators.values():
-            code += ' ' + t.val + ' '
-        else:
-            code += t.val
-        prev = t
-    return code.strip()
-
 # Rename variables in [tokens] to A0, A1, A2,… in order of appearance.
 def rename_vars(tokens, names=None):
     if names is None:
@@ -179,8 +88,9 @@ def rename_vars(tokens, names=None):
     tokens = list(tokens)
     for i, t in enumerate(tokens):
         if t.type == 'PERIOD':
-            names.clear()
-            next_id = 0
+            pass
+#            names.clear()
+#            next_id = 0
         elif t.type == 'VARIABLE':
             if t.val.startswith('_'):
                 tokens[i] = t.clone(val='A{}'.format(next_id))
@@ -193,6 +103,87 @@
                 tokens[i] = t.clone(val=names[cur_name])
     return tokens
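
Since names is no longer cleared at each PERIOD, rename_vars now assigns consistent names across the whole token stream. A small sketch of the changed behavior (hypothetical values; assumes prolog.util imports cleanly):

    >>> from prolog.util import Token, rename_vars
    >>> toks = [Token('VARIABLE', 'X'), Token('PERIOD', '.'), Token('VARIABLE', 'Y')]
    >>> [t.val for t in rename_vars(toks)]
    ['A0', '.', 'A1']

Before this change, the counter reset at the PERIOD and Y would also have been renamed to A0.
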
 
+# Rename variables in [tokens] to A0, A1, A2,… in order of appearance.
+def rename_vars_list(tokens, names=None):
+    if names is None:
+        names = {}
+    next_id = len(names)
+
+    # Return a new list.
+    tokens = list(tokens)
+    for i, t in enumerate(tokens):
+        if t.type == 'VARIABLE':
+            if t.val.startswith('_'):
+                tokens[i] = t.clone(val='A{}'.format(next_id))
+                next_id += 1
+            else:
+                cur_name = t.val
+                if cur_name not in names:
+                    names[cur_name] = 'A{}'.format(next_id)
+                    next_id += 1
+                tokens[i] = t.clone(val=names[cur_name])
+    return tokens
+
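
rename_vars_list applies the same renaming to a flat token list, with no special-casing of PERIOD at all: anonymous variables get fresh names, repeated named variables map to the same name. A quick sketch under the same assumptions:

    >>> from prolog.util import Token, rename_vars_list
    >>> toks = [Token('VARIABLE', 'Sum'), Token('VARIABLE', '_'), Token('VARIABLE', 'Sum')]
    >>> [t.val for t in rename_vars_list(toks)]
    ['A0', 'A1', 'A0']
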
+# Rename variables in AST rooted at [root] to A0, A1, A2,… in order of
+# appearance.
+def rename_vars_ast2(root, fixed_names=None):
+    if fixed_names is None:
+        fixed_names = {}
+    names = {}
+    next_id = len(fixed_names) + len(names)
+
+    def rename_aux(node):
+        nonlocal fixed_names, names, next_id
+        if isinstance(node, Tree):
+            if node.label() == 'clause':
+                names = {}
+                next_id = len(fixed_names) + len(names)
+            new_children = [rename_aux(child) for child in node]
+            new_node = Tree(node.label(), new_children)
+        elif isinstance(node, Token):
+            if node.type == 'VARIABLE':
+                token = node
+                if token.val.startswith('_'):
+                    new_node = token.clone(val='A{}'.format(next_id))
+                    next_id += 1
+                else:
+                    cur_name = token.val
+                    if cur_name in fixed_names:
+                        new_name = fixed_names[cur_name]
+                    else:
+                        if cur_name not in names:
+                            names[cur_name] = 'A{}'.format(next_id)
+                            next_id += 1
+                        new_name = names[cur_name]
+                    new_node = token.clone(val=new_name)
+            else:
+                new_node = node
+        return new_node
+    return rename_aux(root)
+
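
rename_vars_ast2 walks the tree, restarting the per-clause names at every 'clause' node while entries in fixed_names stay pinned across clauses. A sketch on a hand-built tree (hypothetical shape; real ASTs come from prolog.parser):

    >>> from nltk import Tree
    >>> from prolog.util import Token, rename_vars_ast2
    >>> ast = Tree('clause', [Tree('head', [Token('VARIABLE', 'X')]),
    ...                       Tree('and', [Token('VARIABLE', 'X'), Token('VARIABLE', '_')])])
    >>> [t.val for t in rename_vars_ast2(ast).leaves()]
    ['A0', 'A0', 'A1']
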
+# Yield "interesting" parts of a Prolog AST as lists of tokens.
+def interesting_ranges(ast, path=()):
+    if ast.label() in {'clause', 'head', 'or', 'if', 'and'}:
+        if ast.label() == 'and':
+            for i in range(0, len(ast), 2):
+                for j in range(i, len(ast), 2):
+                    subs = ast[i:j+1]
+                    terminals = []
+                    for s in subs:
+                        terminals.extend([s] if isinstance(s, Token) else s.leaves())
+                    # We want at least some context.
+                    if len(terminals) > 1:
+                        yield terminals, path + (ast.label(),)
+        else:
+            terminals = ast.leaves()
+            # We want at least some context.
+            if len(terminals) > 1:
+                yield terminals, path + (ast.label(),)
+
+    for subtree in ast:
+        if isinstance(subtree, Tree):
+            yield from interesting_ranges(subtree, path + (ast.label(),))
+
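
interesting_ranges yields odd-length slices of each 'and' node (runs of conjuncts with their separating commas) plus whole clause/head/or/if subtrees, each paired with the path of labels leading to it. A sketch on a hand-built tree (same hypothetical shape as above):

    >>> ast = Tree('clause', [Tree('head', [Token('NAME', 'a')]),
    ...                       Tree('and', [Token('NAME', 'b'), Token('COMMA', ','), Token('NAME', 'c')])])
    >>> for terminals, path in interesting_ranges(ast):
    ...     print(''.join(t.val for t in terminals), path)
    ab,c ('clause',)
    b,c ('clause', 'and')
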
 # Helper function to remove trailing punctuation from lines and rename
 # variables to A1,A2,A3,… (potentially using [var_names]). Return a tuple.
 def normalized(line, var_names=None):