import ast
import copy
import uuid

from .transformations import *
from .display import printFunction
from .astTools import deepcopy, getAllImportStatements


def giveIds(a, idCounter=0):
    """Tag every node of the AST rooted at ``a`` with a unique ``global_id``.

    The tree is walked recursively through ``_fields``. Any child that
    already carries a ``global_id`` is treated as an aliased node (the same
    object reachable from more than one place): it is deep-copied and the
    copy is stored back into the parent before being tagged, so that no two
    positions in the tree share a node object.

    Args:
        a: The node to tag. Values that are not ``ast.AST`` instances
            (strings, numbers, ``None``, ...) are ignored.
        idCounter: Unused. Accepted only for backward compatibility with
            existing callers; ids come from ``uuid.uuid1()``.

    Returns:
        None. The tree is modified in place.
    """
    if not isinstance(a, ast.AST):
        return

    # Expression-context markers (Load, Store, Del and the deprecated
    # AugLoad / AugStore / Param) are skipped. Checking against the common
    # base class avoids naming the deprecated classes directly.
    if isinstance(a, ast.expr_context):
        return

    a.global_id = uuid.uuid1()

    for field in a._fields:
        child = getattr(a, field)
        if isinstance(child, list):
            for index, item in enumerate(child):
                # Get rid of aliased items
                if hasattr(item, "global_id"):
                    item = copy.deepcopy(item)
                    child[index] = item
                giveIds(item, idCounter)
        else:
            # Get rid of aliased items
            if hasattr(child, "global_id"):
                child = copy.deepcopy(child)
                setattr(a, field, child)
            giveIds(child, idCounter)


def canonicalize(code, problem_name=None, given_names=None):
    """Parse ``code``, run the transformation pipeline, and print the result.

    Steps, in order: parse with ``ast.parse``; tag nodes via ``giveIds``;
    call ``propogateMetadata``, ``simplify`` and ``anonymizeNames``; then
    repeatedly apply ``helperFolding`` followed by every entry of the
    transformation list until ``compareASTs`` reports that a full pass left
    the tree unchanged.

    Args:
        code: Python source text to canonicalize.
        problem_name: Forwarded unchanged to ``helperFolding``.
        given_names: Forwarded to ``anonymizeNames``; ``None`` is replaced
            by an empty list. (Presumably names that must be preserved --
            confirm against ``anonymizeNames``.)

    Returns:
        Whatever ``printFunction`` returns for the final tree.

    Raises:
        SyntaxError: If ``code`` cannot be parsed by ``ast.parse``.
    """
    if given_names is None:
        given_names = []

    tree = ast.parse(code)
    giveIds(tree)

    transformations = [
        constantFolding,
        cleanupEquals,
        cleanupBoolOps,
        cleanupRanges,
        cleanupSlices,
        cleanupTypes,
        cleanupNegations,
        conditionalRedundancy,
        combineConditionals,
        collapseConditionals,
        copyPropagation,
        deMorganize,
        orderCommutativeOperations,
        deadCodeRemoval,
    ]

    varmap = {}
    # NOTE(review): the metadata argument is hard-coded to a function named
    # 'foo' with two 'int' entries -- confirm whether this should be derived
    # from the input or exposed as a parameter.
    tree = propogateMetadata(tree, {'foo': ['int', 'int']}, varmap, [0])

    tree = simplify(tree)
    imports = getAllImportStatements(tree)
    tree = anonymizeNames(tree, given_names, imports)

    # Iterate until a full pass leaves the tree unchanged. oldTree starts as
    # None, so the body is expected to run at least once (this relies on
    # compareASTs(None, tree, ...) being non-zero -- defined elsewhere).
    oldTree = None
    while compareASTs(oldTree, tree, checkEquality=True) != 0:
        oldTree = deepcopy(tree)
        helperFolding(tree, problem_name, imports)
        for t in transformations:
            tree = t(tree)

    new_code = printFunction(tree)
    return new_code


if __name__ == '__main__':
    # Small demo program run through the pipeline.
    code = '''
def isPositive(x):
    return x > 0

def bar(x, y):
    return x - y

def foo(x, y):
    if isPositive(y):
        return x*y
    if isPositive(x):
        return x
    return -x
'''

    print(canonicalize(code))