#!/usr/bin/python3


class Tree(tuple):
    """Immutable, ordered, labelled tree.

    A ``Tree`` *is* the tuple of its child trees; ``label`` and ``data`` are
    stored as ordinary instance attributes.  Equality, ordering and hashing
    take the class, the label and the children into account; ``data`` is
    ignored by all three.
    """

    def __new__(cls, label, children=None, data=None):
        """Create the underlying tuple from ``children`` (empty if falsy)."""
        # A tuple's contents are fixed at creation time, so the children have
        # to be supplied here rather than in __init__.
        return super(Tree, cls).__new__(
            cls, tuple(children) if children else ())

    def __init__(self, label, children=None, data=None):
        """Store ``label`` and ``data``; ``children`` was consumed by __new__."""
        self.label = label
        self.data = data

    def __eq__(self, other):
        """Equal iff same class, same label and equal children."""
        return (self.__class__ is other.__class__ and
                self.label == other.label and
                tuple.__eq__(self, other))

    def __ne__(self, other):
        return not self == other

    def __lt__(self, other):
        """Order by (label, children); unrelated classes order by class name."""
        if isinstance(other, Tree) and self.__class__ is other.__class__:
            return (self.label, tuple(self)) < (other.label, tuple(other))
        return self.__class__.__name__ < other.__class__.__name__

    def __gt__(self, other):
        return not (self < other or self == other)

    def __le__(self, other):
        return self < other or self == other

    def __ge__(self, other):
        return not self < other

    def __hash__(self):
        # Defined explicitly because overriding __eq__ would otherwise make
        # the class unhashable; uses the same fields as __eq__.
        return hash((self.label, tuple(self)))

    def __repr__(self):
        return self.to_string()

    @staticmethod
    def _push_word(stack, word):
        """Append a finished token ``word`` to the innermost open list.

        The first token inside a parenthesised list is that node's label and
        is kept as a plain string; every other token, including any token at
        top level (outside all parentheses), is a leaf ``Tree``.
        """
        if not word:
            return
        frame = stack[-1]
        if len(stack) > 1 and not frame:
            frame.append(word)
        else:
            frame.append(Tree(word))

    # adapted from https://en.wikipedia.org/wiki/S-expression#Parsing
    @staticmethod
    def from_string(string):
        """Parse an S-expression such as ``(S (NP John) (VP runs))``.

        Tokens are separated by spaces, tabs or newlines.  Double quotes
        toggle a quoted mode in which parentheses and whitespace are taken
        literally; the quote characters themselves are dropped.  A bare atom
        (``"foo"`` without parentheses) parses to a leaf, so the output of
        ``to_string`` on a leaf can be read back.

        Returns the first top-level tree found in ``string``.

        Raises:
            IndexError: if the input holds no tree, contains an empty list
                ``()``, or has unbalanced parentheses.
        """
        stack = [[]]
        word = ''
        in_str = False
        for char in string:
            if char == '(' and not in_str:
                stack.append([])
            elif char == ')' and not in_str:
                Tree._push_word(stack, word)
                word = ''
                items = stack.pop()
                stack[-1].append(Tree(items[0], items[1:]))
            elif char in (' ', '\n', '\t') and not in_str:
                Tree._push_word(stack, word)
                word = ''
            elif char == '\"':
                in_str = not in_str
            else:
                word += char
        # Flush a token that runs up to the end of the input (bare atom).
        Tree._push_word(stack, word)
        return stack[0][0]

    def to_string(self, inline=False, depth=0):
        """Render the tree as an S-expression.

        A leaf renders as ``str(label)``.  An inner node renders as
        ``(label child ...)``.  Unless ``inline`` is true, a node with more
        than one child and more than 8 nodes in its subtree puts each child
        on its own line, indented according to ``depth``.
        """
        label = str(self.label)
        if not self:
            return label
        multiline = (not inline and len(self) > 1 and
                     sum(1 for _ in self.subtrees()) > 8)
        if multiline:
            depth += 1
            prefix = '\n' + ' ' * depth
        else:
            prefix = ' '
        parts = ['(', label]
        for child in self:
            parts.append(prefix)
            parts.append(child.to_string(inline, depth))
        parts.append(')')
        return ''.join(parts)

    def to_graphviz(self):
        """Return a Graphviz ``digraph`` description of the tree.

        Nodes are numbered ``n0, n1, ...`` in pre-order and labelled with
        their ``label``; one edge is emitted from each parent to each child.
        """
        lines = ['digraph {\n', 'ordering=out\n', 'ranksep=0.3\n']
        edges = []
        # Number each *occurrence* of a node rather than keying on id(), so a
        # Tree object that appears several times still gets one graph node
        # and one incoming edge per occurrence.
        pending = [(self, None)]
        index = 0
        while pending:
            node, parent_index = pending.pop()
            lines.append('n{} [label="{}"];\n'.format(index, node.label))
            if parent_index is not None:
                edges.append((parent_index, index))
            # Push children reversed so the leftmost child is visited first.
            for child in reversed(node):
                pending.append((child, index))
            index += 1
        for a, b in edges:
            lines.append('n{} -> n{};\n'.format(a, b))
        lines.append('}\n')
        return ''.join(lines)

    def subtrees(self):
        """Yield this node and all descendants in pre-order."""
        yield self
        for child in self:
            yield from child.subtrees()

    def subtrees_with_parents(self, parent=None):
        """Yield ``(node, parent)`` pairs in pre-order.

        The pair for the node this is called on carries the ``parent``
        argument (``None`` by default).
        """
        yield (self, parent)
        for child in self:
            yield from child.subtrees_with_parents(self)