summaryrefslogtreecommitdiff
path: root/canonicalize/__init__.py
blob: 7c4ef799a86a27a5674e03f27509647c58b73b17 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import ast
import copy
import uuid

from .transformations import *
from .display import printFunction
from .astTools import deepcopy, getAllImportStatements

def giveIds(a, idCounter=0):
    """Tag every node of an AST with a unique ``global_id``, in place.

    Walks ``a`` depth-first and stores a fresh ``uuid.uuid1()`` in the
    ``global_id`` attribute of each ``ast.AST`` node it reaches.
    Expression-context markers (``Load``, ``Store``, ``Del`` and the legacy
    ``AugLoad`` / ``AugStore`` / ``Param``) are left untagged.

    A child that already carries a ``global_id`` when it is reached is
    treated as aliased (the same object referenced from more than one place):
    it is replaced in its parent by a deep copy, and the copy is tagged, so
    every tagged node ends up being a distinct object with a distinct id.

    Args:
        a: Node to start from. Anything that is not an ``ast.AST`` instance
            is ignored.
        idCounter: Kept for backward compatibility only. It is passed down
            the recursion but never read, so it has no effect on the ids.

    Returns:
        None. The tree is modified in place.
    """
    if not isinstance(a, ast.AST):
        return
    # Every context marker derives from ast.expr_context. Checking the base
    # class avoids naming ast.AugLoad / ast.AugStore / ast.Param directly:
    # those are deprecated and would raise AttributeError once removed.
    if isinstance(a, ast.expr_context):
        return

    a.global_id = uuid.uuid1()
    next_counter = idCounter + 1

    for field in a._fields:
        # Hand-built nodes may lack optional fields; treat those as absent.
        child = getattr(a, field, None)
        if isinstance(child, list):
            for index, item in enumerate(child):
                # Get rid of aliased items
                if hasattr(item, "global_id"):
                    item = copy.deepcopy(item)
                    child[index] = item
                giveIds(item, next_counter)
        else:
            # Get rid of aliased items
            if hasattr(child, "global_id"):
                child = copy.deepcopy(child)
                setattr(a, field, child)
            giveIds(child, next_counter)

def canonicalize(code, problem_name=None, given_names=None):
    """Parse ``code``, rewrite its AST to a fixed point, and print it back.

    Steps, in order:
      1. Parse ``code`` with ``ast.parse`` and tag the nodes via ``giveIds``.
      2. Run ``propogateMetadata``, ``simplify`` and ``anonymizeNames`` once.
      3. Repeatedly run ``helperFolding`` followed by each transformation
         pass, until ``compareASTs`` reports 0 between the tree from before
         a round and the tree after it.
      4. Return whatever ``printFunction`` produces for the final tree.

    Args:
        code: Python source text to parse.
        problem_name: Forwarded unchanged to ``helperFolding``.
        given_names: Forwarded to ``anonymizeNames``; an empty list is used
            when omitted.

    Returns:
        The result of ``printFunction(tree)`` for the rewritten tree.
    """
    given_names = [] if given_names is None else given_names

    tree = ast.parse(code)
    giveIds(tree)

    # Rewrite passes, applied in this order on every round of the loop below.
    passes = (
        constantFolding,

        cleanupEquals,
        cleanupBoolOps,
        cleanupRanges,
        cleanupSlices,
        cleanupTypes,
        cleanupNegations,

        conditionalRedundancy,
        combineConditionals,
        collapseConditionals,

        copyPropagation,

        deMorganize,
        orderCommutativeOperations,

        deadCodeRemoval,
    )

    # NOTE(review): the signature map is hard-coded to a function named
    # 'foo' taking two ints -- presumably a placeholder; confirm with callers.
    varmap = {}
    tree = propogateMetadata(tree, {'foo': ['int', 'int']}, varmap, [0])
    tree = simplify(tree)

    imports = getAllImportStatements(tree)
    tree = anonymizeNames(tree, given_names, imports)

    # Iterate until a full round of passes leaves the tree unchanged.
    previous = None
    while True:
        if compareASTs(previous, tree, checkEquality=True) == 0:
            break
        previous = deepcopy(tree)
        # Return value is discarded; the call is made for its effect on tree.
        helperFolding(tree, problem_name, imports)
        for rewrite in passes:
            tree = rewrite(tree)

    return printFunction(tree)
        

if __name__ == '__main__':
    # Manual smoke run: canonicalize a small three-function sample program
    # and print the resulting source.
    sample_source = '''
def isPositive(x):
    return x > 0

def bar(x, y):
    return x - y

def foo(x, y):
    if isPositive(y):
        return x*y
    if isPositive(x):
        return x
    return -x
'''
    canonical_source = canonicalize(sample_source)
    print(canonical_source)