summaryrefslogtreecommitdiff
path: root/prolog/util.py
blob: fc243f246af119530740a2f343e8601a3928b936 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#!/usr/bin/python3

import math
import re

from .lexer import lexer

def tokenize(text):
    """Run the module-level lexer over *text* and return its tokens.

    Each token is reduced to a (type, value) tuple; line numbers and
    absolute positions reported by the lexer are discarded.  The shared
    `lexer` object is re-fed with *text* on every call.
    """
    lexer.input(text)
    pairs = ((tok.type, tok.value) for tok in lexer)
    return list(pairs)

# Token types that stringify() surrounds with spaces when printing.
operators = {
    'FROM', 'IMPLIES', 'NOT',
    'EQU', 'NEQU', 'EQ', 'NEQ', 'UNIV', 'IS', 'EQA', 'NEQA',
    'LT', 'LE', 'GT', 'GE', 'LTL', 'LEL', 'GTL', 'GEL',
    'PLUS', 'MINUS', 'STAR', 'DIV', 'IDIV', 'MOD',
    'POW', 'SEMI',
}
def stringify(tokens, indent=''):
    """Render a list of (type, value) tokens as one line of Prolog text.

    Every token type in `operators` gets a space in front of it.  Operators
    other than ':-' also get a space after them, and commas are printed as
    ', '.  ':-' and '.' are emitted with a line break plus *indent*, but the
    result is then stripped and every newline and tab is replaced by a space,
    so the returned string is always a single line.  Because of the strip, a
    leading *indent* never survives.

    NOTE(review): the newline/tab replacement also applies to token values
    themselves (e.g. quoted text containing a tab), not only to the layout
    added here.
    """
    pieces = [indent]
    for tok in tokens:
        kind = tok[0]
        is_op = kind in operators
        if is_op:
            pieces.append(' ')

        if kind == 'FROM':
            pieces.append(':-\n  ' + indent)
        elif kind == 'PERIOD':
            pieces.append('.\n' + indent)
        elif kind == 'COMMA':
            pieces.append(', ')
        elif is_op:
            pieces.append(tok[1] + ' ')
        else:
            pieces.append(tok[1])

    text = ''.join(pieces).strip()
    return text.replace('\n', ' ').replace('\t', ' ')

def decompose(code):
    """Split Prolog source *code* into lines and rules.

    A "line" is a list of (type, value) tokens.  Top-level ':-', ',' and '.'
    tokens end the current line and are themselves dropped; the same tokens
    nested inside (), [] or {} stay inside the enclosing line.  Every ';'
    token becomes a line of its own, splitting the tokens around it even
    inside brackets (which is why bracket tokens may end up on different
    lines).

    Returns (lines, rules): `lines` is the list of token lists, and each
    entry of `rules` is a (start, end) slice of `lines`.  A rule is closed
    at every top-level '.' and once more at end of input, so input ending
    in '.' also yields a final empty (len(lines), len(lines)) rule.

    NOTE(review): if brackets are still open at end of input, the EOF marker
    is appended to the pending line and that line is never flushed to
    `lines`, nor is a closing rule recorded.
    """
    # closing bracket token -> the opening token it must match
    opener_for = {'RPAREN': 'LPAREN', 'RBRACKET': 'LBRACKET', 'RBRACE': 'LBRACE'}
    lines, rules = [], []
    current = []        # tokens of the line being built
    open_brackets = []  # stack of currently open bracket token types
    first = 0           # index into `lines` where the current rule begins

    for tok in tokenize(code) + [('EOF', '')]:
        kind = tok[0]

        if kind == 'SEMI':
            # ';' always gets a line of its own, even inside brackets
            lines.extend((current, [tok]))
            current = []
            continue

        if not open_brackets and kind in ('PERIOD', 'FROM', 'COMMA', 'EOF'):
            if current:
                lines.append(current)
                current = []
            if kind in ('PERIOD', 'EOF'):
                rules.append((first, len(lines)))
                first = len(lines)
            continue

        if kind in ('LPAREN', 'LBRACKET', 'LBRACE'):
            open_brackets.append(kind)
        elif open_brackets and opener_for.get(kind) == open_brackets[-1]:
            # only a correctly matching closer pops; mismatches are ignored
            open_brackets.pop()
        current.append(tok)

    return lines, rules

def compose(lines, rules):
    """Pretty-print rules as Prolog source text.

    lines -- list of token lists (as produced by decompose())
    rules -- iterable of (start, end) index pairs into `lines`; an empty
             range prints nothing

    The first line of a multi-line rule is printed followed by ' :-', and the
    remaining lines are indented by two spaces, one per output line.  Body
    lines are joined with ',' except next to a line ending in a ';' token.
    The last line of each rule is terminated with '.'.
    Returns the whole text with surrounding whitespace stripped.
    """
    def ends_with_semi(tokens):
        # empty lines have no last token, so they never count as ';'
        return bool(tokens) and tokens[-1][0] == 'SEMI'

    parts = []
    for start, end in rules:
        for i in range(start, end):
            line = lines[i]
            if i > start:
                parts.append('  ')
            parts.append(stringify(line).replace('\n', ' '))
            if i == end - 1:
                parts.append('.\n')
            elif i == start:
                parts.append(' :-\n')
            else:
                # here i < end - 1, so lines[i+1] exists; it may be empty,
                # which the old code indexed directly and crashed on
                if line and not ends_with_semi(line) and not ends_with_semi(lines[i + 1]):
                    parts.append(',')
                parts.append('\n')
    return ''.join(parts).strip()

def rename_vars(tokens, names=None):
    """Standardize variable names in *tokens*, in order of appearance.

    VARIABLE tokens are rewritten in place to 'A<n>'.  Each distinct named
    variable gets the next free id the first time it is seen and keeps it
    afterwards.  Every anonymous '_' gets a fresh id of its own.  After each
    PERIOD token, the known names are forgotten and numbering restarts at 0.

    tokens -- list of (type, value) tuples; modified in place
    names  -- optional mapping of original variable name -> integer id to
              continue from (e.g. the result of a previous call); it is
              copied, never modified
    Returns the name -> id mapping in effect after the last token.
    """
    # copy names so the caller's mapping is never modified
    names = dict(names) if names else {}
    # Continue after the highest id already in use.  len(names) is not
    # enough: anonymous '_' variables consume ids without being recorded,
    # so a mapping returned by an earlier call can have gaps, and reusing
    # len(names) would hand out an id that is already taken.
    next_id = max(names.values(), default=-1) + 1
    for i, tok in enumerate(tokens):
        if tok[0] == 'PERIOD':
            names.clear()
            next_id = 0
        elif tok == ('VARIABLE', '_'):
            tokens[i] = ('VARIABLE', 'A' + str(next_id))
            next_id += 1
        elif tok[0] == 'VARIABLE':
            cur_name = tok[1]
            if cur_name not in names:
                names[cur_name] = next_id
                next_id += 1
            tokens[i] = ('VARIABLE', 'A' + str(names[cur_name]))
    return names