diff options
Diffstat (limited to 'prolog')
-rw-r--r-- | prolog/util.py | 187 |
1 files changed, 89 insertions, 98 deletions
diff --git a/prolog/util.py b/prolog/util.py index ba61da0..4b316d5 100644 --- a/prolog/util.py +++ b/prolog/util.py @@ -20,12 +20,12 @@ from nltk import Tree # Stores a token's type and value, and optionally the position of the first # character in the lexed stream. -class Token(namedtuple('Token', ['type', 'val', 'pos', 'rule', 'part', 'stop'])): +class Token(namedtuple('Token', ['type', 'val', 'pos'])): __slots__ = () # Custom constructor to support default parameters. - def __new__(cls, type, val='', pos=None, rule=None, part=None, stop=False): - return super(Token, cls).__new__(cls, type, val, pos, rule, part, stop) + def __new__(cls, type, val='', pos=None): + return super(Token, cls).__new__(cls, type, val, pos) def __str__(self): return self.val @@ -45,13 +45,10 @@ class Token(namedtuple('Token', ['type', 'val', 'pos', 'rule', 'part', 'stop'])) return hash(self[1]) # Return a copy of this token, possibly modifying some fields. - def clone(self, type=None, val=None, pos=None, rule=None, part=None, stop=None): + def clone(self, type=None, val=None, pos=None): return Token(self.type if type is None else type, self.val if val is None else val, - self.pos if pos is None else pos, - self.rule if rule is None else rule, - self.part if part is None else part, - self.stop if stop is None else stop) + self.pos if pos is None else pos) from .lexer import lexer, operators from .parser import parser @@ -81,94 +78,6 @@ def stringify(obj): return ''.join([stringify(child) for child in obj]) + '\n' return ''.join([stringify(child) for child in obj]) -# Lex [code] into tokens with rule indexes and stop markers. -def annotate(code): - rule = 0 - part = 0 - parens = [] # stack of currently open parens/brackets/braces - in_parens = 0 # COMMA means a new part if this is 0 - - token = None - lexer.input(code) - for t in lexer: - tok_rule = rule - tok_part = part - tok_stop = True - - if t.type == 'PERIOD': # . - rule += 1 - part = 0 - in_parens = 0 - parens = [] - elif t.type in ('FROM', 'SEMI'): # :- ; - part += 1 - elif t.type == 'COMMA': # , - if not parens or in_parens == 0: - part += 1 - else: - tok_stop = False - - # Handle left parens. - elif t.type == 'LPAREN': # ( - if token and token.type == 'NAME': # name( - tok_stop = False - parens.append('COMPOUND') - in_parens += 1 - else: - parens.append(t.type) # …, ( - elif t.type == 'LBRACKET': # [ - tok_stop = False - parens.append(t.type) - in_parens += 1 - elif t.type == 'LBRACE': # { - parens.append(t.type) - - # Handle right parens. - elif t.type == 'RPAREN': # ) - if parens: - if parens[-1] == 'COMPOUND': # name(…) - tok_stop = False - parens.pop() - in_parens -= 1 - elif parens[-1] == 'LPAREN': # (…) - parens.pop() - elif t.type == 'RBRACKET': # ] - if parens and parens[-1] == 'LBRACKET': # […] - tok_stop = False - parens.pop() - in_parens -= 1 - elif t.type == 'RBRACE': # } - if parens and parens[-1] == 'LBRACE': # {…} - parens.pop() - - # Normal tokens. - else: - tok_stop = False - - token = Token(t.type, t.value, t.lexpos, tok_rule, tok_part, tok_stop) - yield token - -# Format a list of annotated [tokens]. -def compose(tokens): - code = '' - prev = None - for t in tokens: - if t.type == 'SEMI': - code += '\n ' - if prev and (prev.part != t.part or prev.rule != t.rule): - code += '\n' - if t.part: - code += ' ' - - if t.type in ('PERIOD', 'COMMA'): - code += t.val + ' ' - elif t.type in operators.values(): - code += ' ' + t.val + ' ' - else: - code += t.val - prev = t - return code.strip() - # Rename variables in [tokens] to A0, A1, A2,… in order of appearance. def rename_vars(tokens, names=None): if names is None: @@ -179,8 +88,9 @@ def rename_vars(tokens, names=None): tokens = list(tokens) for i, t in enumerate(tokens): if t.type == 'PERIOD': - names.clear() - next_id = 0 + pass +# names.clear() +# next_id = 0 elif t.type == 'VARIABLE': if t.val.startswith('_'): tokens[i] = t.clone(val='A{}'.format(next_id)) @@ -193,6 +103,87 @@ def rename_vars(tokens, names=None): tokens[i] = t.clone(val=names[cur_name]) return tokens +# Rename variables in [tokens] to A0, A1, A2,… in order of appearance. +def rename_vars_list(tokens, names=None): + if names is None: + names = {} + next_id = len(names) + + # Return a new list. + tokens = list(tokens) + for i, t in enumerate(tokens): + if t.type == 'VARIABLE': + if t.val.startswith('_'): + tokens[i] = t.clone(val='A{}'.format(next_id)) + next_id += 1 + else: + cur_name = t.val + if cur_name not in names: + names[cur_name] = 'A{}'.format(next_id) + next_id += 1 + tokens[i] = t.clone(val=names[cur_name]) + return tokens + +# Rename variables in AST rooted at [root] to A0, A1, A2,… in order of +# appearance. +def rename_vars_ast2(root, fixed_names=None): + if fixed_names is None: + fixed_names = {} + names = {} + next_id = len(fixed_names) + len(names) + + def rename_aux(node): + nonlocal fixed_names, names, next_id + if isinstance(node, Tree): + if node.label() == 'clause': + names = {} + next_id = len(fixed_names) + len(names) + new_children = [rename_aux(child) for child in node] + new_node = Tree(node.label(), new_children) + elif isinstance(node, Token): + if node.type == 'VARIABLE': + token = node + if token.val.startswith('_'): + new_node = token.clone(val='A{}'.format(next_id)) + next_id += 1 + else: + cur_name = token.val + if cur_name in fixed_names: + new_name = fixed_names[cur_name] + else: + if cur_name not in names: + names[cur_name] = 'A{}'.format(next_id) + next_id += 1 + new_name = names[cur_name] + new_node = token.clone(val=new_name) + else: + new_node = node + return new_node + return rename_aux(root) + +# Yield "interesting" parts of a Prolog AST as lists of tokens. +def interesting_ranges(ast, path=()): + if ast.label() in {'clause', 'head', 'or', 'if', 'and'}: + if ast.label() == 'and': + for i in range(0, len(ast), 2): + for j in range(i, len(ast), 2): + subs = ast[i:j+1] + terminals = [] + for s in subs: + terminals.extend([s] if isinstance(s, Token) else s.leaves()) + # We want at least some context. + if len(terminals) > 1: + yield terminals, path + (ast.label(),) + else: + terminals = ast.leaves() + # We want at least some context. + if len(terminals) > 1: + yield terminals, path + (ast.label(),) + + for subtree in ast: + if isinstance(subtree, Tree): + yield from interesting_ranges(subtree, path + (ast.label(),)) + # Helper function to remove trailing punctuation from lines and rename # variables to A1,A2,A3,… (potentially using [var_names]). Return a tuple. def normalized(line, var_names=None): |