summaryrefslogtreecommitdiff
path: root/regex/tree.py
blob: 1ff1b909638758dd15b2aa3fbe605fd9d32f39b1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#!/usr/bin/python3

class Tree(tuple):
    """Immutable ordered tree node; the tuple's elements are its children.

    Every node has a ``label`` and an optional ``data`` payload.  Equality,
    ordering and hashing look only at the class, the label and the
    children -- ``data`` is ignored.
    """

    def __new__(cls, label, children=None, data=None):
        # Tuple contents are fixed at allocation time, so the children must be
        # handed to tuple.__new__ here; label/data are set in __init__.
        return super(Tree, cls).__new__(cls, tuple(children) if children else tuple())

    def __init__(self, label, children=None, data=None):
        self.label = label
        self.data = data

    def __getnewargs__(self):
        """Supply constructor arguments for copy/pickle (protocol >= 2).

        The inherited tuple.__getnewargs__ returns only ``(children,)``, which
        __new__ would take as the *label*, so copies silently lost all their
        children.  ``label`` and ``data`` are restored from ``__dict__``
        afterwards as usual.
        """
        return (self.label, tuple(self))

    def __eq__(self, other):
        return (self.__class__ is other.__class__ and
                self.label == other.label and tuple.__eq__(self, other))

    def __ne__(self, other):
        return not self == other

    def __lt__(self, other):
        # Trees of the same class order by (label, children); any other pair
        # of objects is ordered by class name.
        if isinstance(other, Tree) and self.__class__ is other.__class__:
            return (self.label, tuple(self)) < (other.label, tuple(other))
        return self.__class__.__name__ < other.__class__.__name__

    def __le__(self, other):
        return self < other or self == other

    def __gt__(self, other):
        return not (self < other or self == other)

    def __ge__(self, other):
        return not self < other

    def __hash__(self):
        return hash((self.label, tuple(self)))

    def __repr__(self):
        return self.to_string()

    # adapted from https://en.wikipedia.org/wiki/S-expression#Parsing
    @staticmethod
    def from_string(string):
        """Parse an S-expression such as ``(a b (c d))`` into a Tree.

        Tokens are separated by spaces, tabs, newlines and parentheses.  The
        first token inside a pair of parentheses becomes the node's label and
        every later token becomes a leaf child.  Double quotes toggle a quoted
        region in which whitespace and parentheses are literal text; an empty
        token such as ``""`` is skipped.  A bare atom such as ``a`` parses to
        the leaf ``Tree('a')``.  If the input holds several top-level
        expressions, only the first is returned.

        Malformed input (empty input, an empty ``()``, a stray ``)``) raises
        IndexError.
        """
        stack = [[]]
        word = ''
        in_str = False

        def flush():
            # Emit the pending token (if any) into the innermost open list.
            nonlocal word
            if not word:
                return
            frame = stack[-1]
            if frame or len(stack) == 1:
                # A later token inside parens, or a bare top-level atom: a leaf.
                frame.append(Tree(word))
            else:
                # First token inside parens: kept raw, it becomes the label.
                frame.append(word)
            word = ''

        for char in string:
            if in_str:
                if char == '"':
                    in_str = False
                else:
                    word += char
            elif char == '"':
                in_str = True
            elif char == '(':
                # Flush first so a token written directly before '(' stays
                # in the enclosing list instead of merging into the next one.
                flush()
                stack.append([])
            elif char == ')':
                flush()
                frame = stack.pop()
                stack[-1].append(Tree(frame[0], frame[1:]))
            elif char in (' ', '\n', '\t'):
                flush()
            else:
                word += char
        # Input may end in the middle of a token (e.g. a bare atom).
        flush()
        return stack[0][0]

    def to_string(self, inline=False, depth=0):
        """Render the tree as an S-expression.

        A leaf renders as ``str(label)``; an inner node as
        ``(label child ...)``.  Unless *inline* is true, a node with more than
        one child whose subtree holds more than 8 nodes puts each child on its
        own line, indented two spaces per nesting *depth*.

        NOTE: labels are written unquoted, so labels containing whitespace,
        parentheses or double quotes do not round-trip through from_string.
        """
        if not self:
            return str(self.label)
        from itertools import islice

        # Only "more than 8 nodes?" matters, so stop the walk after the 9th
        # node instead of materialising every subtree at every level.
        if (not inline and len(self) > 1
                and sum(1 for _ in islice(self.subtrees(), 9)) > 8):
            depth += 1
            separator = '\n' + '  ' * depth
        else:
            separator = ' '
        body = ''.join(separator + child.to_string(inline, depth) for child in self)
        # str() so non-string labels work here just as in the leaf case.
        return '(' + str(self.label) + body + ')'

    def to_graphviz(self):
        """Return a Graphviz DOT ``digraph`` drawing this tree.

        Nodes are named ``n0, n1, ...`` in pre-order and labelled with
        ``str(label)``; edges point from parent to child.  Every position in
        the tree gets its own node, even when the same subtree object occurs
        more than once.
        """
        lines = ['digraph {', 'ordering=out', 'ranksep=0.3']
        edges = []
        # Explicit pre-order walk that remembers the parent's *index*; keying
        # on id(node) merged repeated (shared) subtree objects into one node.
        pending = [(self, None)]
        index = 0
        while pending:
            node, parent_index = pending.pop()
            if parent_index is not None:
                edges.append('n{} -> n{};'.format(parent_index, index))
            # Escape for a DOT quoted string so labels such as '"' or '\'
            # cannot terminate or corrupt the attribute value.
            label = str(node.label).replace('\\', '\\\\').replace('"', '\\"')
            lines.append('n{} [label="{}"];'.format(index, label))
            # Push children right-to-left so the leftmost is visited first.
            pending.extend((child, index) for child in reversed(node))
            index += 1
        lines.extend(edges)
        lines.append('}')
        return '\n'.join(lines) + '\n'

    def subtrees(self):
        """Yield this node and every descendant, in pre-order."""
        yield self
        for child in self:
            yield from child.subtrees()

    def subtrees_with_parents(self, parent=None):
        """Yield ``(node, parent)`` pairs in pre-order.

        The first pair is ``(self, parent)``; every descendant is paired with
        the node that directly contains it.
        """
        yield (self, parent)
        for child in self:
            yield from child.subtrees_with_parents(self)