diff options
Diffstat (limited to 'canonicalize/__init__.py')
-rw-r--r-- | canonicalize/__init__.py | 95 |
1 files changed, 95 insertions, 0 deletions
diff --git a/canonicalize/__init__.py b/canonicalize/__init__.py new file mode 100644 index 0000000..481b6e2 --- /dev/null +++ b/canonicalize/__init__.py @@ -0,0 +1,95 @@ +import ast +import copy +import uuid + +from .transformations import * +from .display import printFunction +from .astTools import deepcopy + +def giveIds(a, idCounter=0): + if isinstance(a, ast.AST): + if type(a) in [ast.Load, ast.Store, ast.Del, ast.AugLoad, ast.AugStore, ast.Param]: + return # skip these + a.global_id = uuid.uuid1() + idCounter += 1 + for field in a._fields: + child = getattr(a, field) + if type(child) == list: + for i in range(len(child)): + # Get rid of aliased items + if hasattr(child[i], "global_id"): + child[i] = copy.deepcopy(child[i]) + giveIds(child[i], idCounter) + else: + # Get rid of aliased items + if hasattr(child, "global_id"): + child = copy.deepcopy(child) + setattr(a, field, child) + giveIds(child, idCounter) + +def getCanonicalForm(tree, problem_name=None, given_names=None, argTypes=None, imports=None): + if imports == None: + imports = [] + + transformations = [ + constantFolding, + + cleanupEquals, + cleanupBoolOps, + cleanupRanges, + cleanupSlices, + cleanupTypes, + cleanupNegations, + + conditionalRedundancy, + combineConditionals, + collapseConditionals, + + copyPropagation, + + deMorganize, + orderCommutativeOperations, + + deadCodeRemoval + ] + + varmap = {} + tree = propogateMetadata(tree, {'foo': ['int', 'int']}, varmap, [0]) + + tree = simplify(tree) + tree = anonymizeNames(tree, given_names, imports) + + oldTree = None + while compareASTs(oldTree, tree, checkEquality=True) != 0: + oldTree = deepcopy(tree) + helperFolding(tree, problem_name, imports) + for t in transformations: + tree = t(tree) + return tree + +def canonicalize(code): + tree = ast.parse(code) + giveIds(tree) + + new_tree = getCanonicalForm(tree, problem_name='foo', given_names=['foo']) + new_code = printFunction(new_tree) + + return new_code + + +if __name__ == '__main__': + code = ''' +def isPositive(x): + return x > 0 + +def bar(x, y): + return x - y + +def foo(x, y): + if isPositive(y): + return x*y + if isPositive(x): + return x + return -x +''' + print(canonicalize(code)) |