summaryrefslogtreecommitdiff
path: root/monkey/graph.py
blob: 7f02c60059b3169ad5e469aaf2df78287bc5b896 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#!/usr/bin/python3

class Node(object):
    # A tree node holding arbitrary [data] and an ordered list [eout] of
    # child nodes (outgoing edges), left-most child first.
    def __init__(self, data, eout=None):
        self.data = data
        # Test against None rather than truthiness: a caller-supplied empty
        # list must be kept (and stay shared with the caller), not silently
        # replaced by a fresh one.
        self.eout = eout if eout is not None else []

    # (Re-)insert [target] as the right-most child of [self] and return it.
    # If [target] is already a child, it is moved to the end. [idx] is
    # accepted for backward compatibility but currently ignored.
    def add_out(self, target, idx=None):
        if target in self.eout:
            self.eout.remove(target)
        self.eout.append(target)
        return target

    # Return a list of nodes in [self], in preorder (each node before its
    # children, children left-to-right). Uses an explicit stack so deep trees
    # do not hit the recursion limit and the result is built in one list
    # instead of being re-concatenated at every level. Assumes the structure
    # below [self] is acyclic.
    def preorder(self):
        nodes = []
        stack = [self]
        while stack:
            node = stack.pop()
            nodes.append(node)
            # Push children in reverse so the left-most child is popped first.
            stack.extend(reversed(node.eout))
        return nodes

    # Lazily yield every subtree (i.e. every node) of [self], in the same
    # preorder as preorder().
    def subtrees(self):
        yield self
        for child in self.eout:
            yield from child.subtrees()

    # Return the list of leaves' values (left-to-right). A node without
    # children is its own single leaf.
    def terminals(self):
        # In preorder, leaves are encountered exactly in left-to-right order.
        return [node.data for node in self.preorder() if not node.eout]

    # Return a one-line string representation of [self]: inner nodes as
    # "(data child child ...)", leaf children as their quoted data.
    def __str__(self):
        return '(' + str(self.data) + ' ' + \
               ' '.join([str(c) if c.eout else '"'+str(c.data)+'"' for c in self.eout]) + \
               ')'

    def __repr__(self):
        return str(self.data)

    # Order nodes by their data (e.g. for sorted()).
    def __lt__(self, other):
        return self.data < other.data

# Return the edit graph containing [nodes] as a string in graphviz dot format.
# [nodes] may be any iterable of nodes (objects with an [eout] list of
# children); repeated nodes are emitted once. Every child referenced through
# [eout] must itself be in [nodes], otherwise a KeyError is raised. The
# [label] and [pos] functions determine node labels and coordinates (x,y),
# and the [node_attr] and [edge_attr] functions return additional attribute
# text for each node and each (node, child) edge. To actually use the
# coordinates returned by [pos], generate the image using neato -n1.
def graphviz(nodes, label=str, pos=None, node_attr=None, edge_attr=None):
    # Assign each distinct node a consecutive integer id, first occurrence
    # wins. Keyed by id() so nodes need not be hashable. This single pass
    # also materializes [nodes], which is needed because it is traversed
    # twice below (a generator would otherwise yield no edges), and it keeps
    # a repeated node from being given an id that collides with the next one.
    gv_nodes = {}
    unique = []
    for node in nodes:
        if id(node) not in gv_nodes:
            gv_nodes[id(node)] = len(unique)
            unique.append(node)

    # Generate node descriptions.
    node_lines = []
    for node in unique:
        # Escape characters that are special inside a dot quoted string.
        node_label = label(node).replace('\\', '\\\\').replace('"', '\\"')
        attrs = ['label="{}"'.format(node_label)]
        if pos:
            attrs.append('pos="{},{}"'.format(*pos(node)))
        if node_attr:
            attrs.append(node_attr(node))
        node_lines.append('\t{} [{}];\n'.format(gv_nodes[id(node)], ', '.join(attrs)))

    # Generate edge descriptions, in the order of [nodes] and, for each node,
    # of its children.
    edge_lines = []
    for node in unique:
        a = gv_nodes[id(node)]
        for child in node.eout:
            line = '\t{} -> {}'.format(a, gv_nodes[id(child)])
            if edge_attr:
                line += ' [' + edge_attr(node, child) + ']'
            edge_lines.append(line + ';\n')

    output = [
        'digraph G {\n',
        '\tordering="out";\n',
        '\tnode [shape="box", margin="0.05,0", fontname="sans", fontsize=13.0];\n',
        '\n',
    ]
    output.extend(node_lines)
    output.append('\n')
    output.extend(edge_lines)
    output.append('}\n')
    return ''.join(output)