summaryrefslogtreecommitdiff
path: root/monkey/graph.py
blob: 8929047134bb3e87c9e7a4a344a7d83ce08a9fe1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# CodeQ: an online programming tutor.
# Copyright (C) 2015 UL FRI
#
# This program is free software: you can redistribute it and/or modify it under
# the terms of the GNU Affero General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
# details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

class Node(object):
    # A tree node: [data] is the node's payload and [eout] is the ordered list
    # of outgoing edges, i.e. the node's children from left to right.
    def __init__(self, data, eout=None):
        self.data = data
        # A non-empty [eout] list is used as-is (not copied); a missing or
        # empty one is replaced by a fresh list so nodes never share it.
        self.eout = eout if eout else []

    # (Re-)insert [target] as the right-most child of [self]. If [target] is
    # already a child it is moved to the end instead of being duplicated.
    # Returns [target]. [idx] is currently ignored; it is accepted only so
    # that callers passing it keep working.
    def add_out(self, target, idx=None):
        if target in self.eout:
            self.eout.remove(target)
        self.eout.append(target)
        return target

    # Return a list of nodes in [self] in pre-order: each node precedes its
    # children, and children are visited left to right.
    def preorder(self):
        # Collect into one shared list so the cost is linear in the number
        # of nodes; concatenating per-subtree lists would copy every node
        # once per ancestor.
        nodes = []

        def visit(node):
            nodes.append(node)
            for child in node.eout:
                visit(child)

        visit(self)
        return nodes

    # Lazily yield every subtree of [self] (one per node), in the same
    # pre-order as preorder().
    def subtrees(self):
        yield self
        for child in self.eout:
            yield from child.subtrees()

    # Return the list of leaves' values (left-to-right). A node without
    # children is itself a leaf, so a lone node yields [its data].
    def terminals(self):
        # Shared accumulator for the same reason as in preorder().
        leaves = []

        def visit(node):
            if node.eout:
                for child in node.eout:
                    visit(child)
            else:
                leaves.append(node.data)

        visit(self)
        return leaves

    # Return a one-line string representation of [self]. Leaf children are
    # rendered as their quoted data and inner children recursively, e.g.
    # '(r "a" (m "b" "c"))'. A node without children renders as '(data )'.
    def __str__(self):
        return '(' + str(self.data) + ' ' + \
               ' '.join([str(c) if c.eout else '"'+str(c.data)+'"' for c in self.eout]) + \
               ')'

    # Short form used in containers and debugging: just the node's data.
    def __repr__(self):
        return str(self.data)

    # Order nodes by their data, e.g. so that lists of nodes can be sorted.
    def __lt__(self, other):
        return self.data < other.data

# Print the edit graph containing [nodes] in graphviz dot format. The [label]
# and [pos] functions determine node labels and coordinates (x,y), and the
# [node_attr] and [edge_attr] functions specify additional attributes for each
# node and edge. To actually use the coordinates returned by [pos], generate
# the image using neato -n1.
def graphviz(nodes, label=str, pos=None, node_attr=None, edge_attr=None):
    # Generate node descriptions.
    node_str = ''
    gv_nodes = {}
    for node in nodes:
        gv_nodes[id(node)] = len(gv_nodes)
        node_label = label(node).replace('\\', '\\\\').replace('"', '\\"')
        node_str += '\t{} [label="{}"'.format(gv_nodes[id(node)], node_label)
        if pos:
            node_str += ', ' + 'pos="{},{}"'.format(*pos(node))
        if node_attr:
            node_str += ', ' + node_attr(node)
        node_str += '];\n'

    # Generate edge descriptions (breadth-first).
    edge_str = ''
    for node in nodes:
        a = gv_nodes[id(node)]
        for child in node.eout:
            b = gv_nodes[id(child)]
            edge_str += '\t{} -> {}'.format(a, b)
            if edge_attr:
                edge_str += ' [' + edge_attr(node, child) + ']'
            edge_str += ';\n'

    output = 'digraph G {\n'
    output += '\tordering="out";\n'
    output += '\tnode [shape="box", margin="0.05,0", fontname="sans", fontsize=13.0];\n'
    output += '\n'
    output += node_str
    output += '\n'
    output += edge_str
    output += '}\n'

    return output