diff options
Diffstat (limited to 'canonicalize')
-rw-r--r-- | canonicalize/COPYING | 21 | ||||
-rw-r--r-- | canonicalize/__init__.py | 95 | ||||
-rw-r--r-- | canonicalize/astTools.py | 1426 | ||||
-rw-r--r-- | canonicalize/display.py | 570 | ||||
-rw-r--r-- | canonicalize/namesets.py | 463 | ||||
-rw-r--r-- | canonicalize/tools.py | 15 | ||||
-rw-r--r-- | canonicalize/transformations.py | 2775 |
7 files changed, 5365 insertions, 0 deletions
diff --git a/canonicalize/COPYING b/canonicalize/COPYING new file mode 100644 index 0000000..2465b95 --- /dev/null +++ b/canonicalize/COPYING @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2017 Kelly Rivers + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/canonicalize/__init__.py b/canonicalize/__init__.py new file mode 100644 index 0000000..481b6e2 --- /dev/null +++ b/canonicalize/__init__.py @@ -0,0 +1,95 @@ +import ast +import copy +import uuid + +from .transformations import * +from .display import printFunction +from .astTools import deepcopy + +def giveIds(a, idCounter=0): + if isinstance(a, ast.AST): + if type(a) in [ast.Load, ast.Store, ast.Del, ast.AugLoad, ast.AugStore, ast.Param]: + return # skip these + a.global_id = uuid.uuid1() + idCounter += 1 + for field in a._fields: + child = getattr(a, field) + if type(child) == list: + for i in range(len(child)): + # Get rid of aliased items + if hasattr(child[i], "global_id"): + child[i] = copy.deepcopy(child[i]) + giveIds(child[i], idCounter) + else: + # Get rid of aliased items + if hasattr(child, "global_id"): + child = copy.deepcopy(child) + setattr(a, field, child) + giveIds(child, idCounter) + +def getCanonicalForm(tree, problem_name=None, given_names=None, argTypes=None, imports=None): + if imports == None: + imports = [] + + transformations = [ + constantFolding, + + cleanupEquals, + cleanupBoolOps, + cleanupRanges, + cleanupSlices, + cleanupTypes, + cleanupNegations, + + conditionalRedundancy, + combineConditionals, + collapseConditionals, + + copyPropagation, + + deMorganize, + orderCommutativeOperations, + + deadCodeRemoval + ] + + varmap = {} + tree = propogateMetadata(tree, {'foo': ['int', 'int']}, varmap, [0]) + + tree = simplify(tree) + tree = anonymizeNames(tree, given_names, imports) + + oldTree = None + while compareASTs(oldTree, tree, checkEquality=True) != 0: + oldTree = deepcopy(tree) + helperFolding(tree, problem_name, imports) + for t in transformations: + tree = t(tree) + return tree + +def canonicalize(code): + tree = ast.parse(code) + giveIds(tree) + + new_tree = getCanonicalForm(tree, problem_name='foo', given_names=['foo']) + new_code = printFunction(new_tree) + + return new_code + + +if __name__ == '__main__': + code = ''' +def isPositive(x): + return x > 0 + +def bar(x, y): + return x - y + +def foo(x, y): + if isPositive(y): + return x*y + if isPositive(x): + return x + return -x +''' + print(canonicalize(code)) diff --git a/canonicalize/astTools.py b/canonicalize/astTools.py new file mode 100644 index 0000000..b84b2c5 --- /dev/null +++ b/canonicalize/astTools.py @@ -0,0 +1,1426 @@ +import ast, copy, pickle +from .tools import log +from .namesets import * +from .display import printFunction + +def cmp(a, b): + if type(a) == type(b) == complex: + return (a.real > b.real) - (a.real < b.real) + return (a > b) - (a < b) + +def builtInName(id): + """Determines whether the given id is a built-in name""" + if id in builtInNames + exceptionClasses: + return True + elif id in builtInFunctions.keys(): + return True + elif id in list(allPythonFunctions.keys()) + supportedLibraries: + return False + +def importedName(id, importList): + for imp in importList: + if type(imp) == ast.Import: + for name in imp.names: + if hasattr(name, "asname") and name.asname != None: + if id == name.asname: + return True + else: + if id == name.name: + return True + elif type(imp) == ast.ImportFrom: + if hasattr(imp, "module"): + if imp.module in supportedLibraries: + libMap = libraryMap[imp.module] + for name in imp.names: + if hasattr(name, "asname") and name.asname != None: + if id == name.asname: + return True + else: + if id == name.name: + return True + else: + log("astTools\timportedName\tUnsupported library: " + printFunction(imp), "bug") + + else: + log("astTools\timportedName\tWhy no module? " + printFunction(imp), "bug") + return False + +def isConstant(x): + """Determine whether the provided AST is a constant""" + return (type(x) in [ast.Num, ast.Str, ast.Bytes, ast.NameConstant]) + +def isIterableType(t): + """Can the given type be iterated over""" + return t in [ dict, list, set, str, bytes, tuple ] + +def isStatement(a): + """Determine whether the given node is a statement (vs an expression)""" + return type(a) in [ ast.Module, ast.Interactive, ast.Expression, ast.Suite, + ast.FunctionDef, ast.ClassDef, ast.Return, ast.Delete, + ast.Assign, ast.AugAssign, ast.For, ast.While, + ast.If, ast.With, ast.Raise, ast.Try, + ast.Assert, ast.Import, ast.ImportFrom, ast.Global, + ast.Expr, ast.Pass, ast.Break, ast.Continue ] + +def codeLength(a): + """Returns the number of characters in this AST""" + if type(a) == list: + return sum([codeLength(x) for x in a]) + return len(printFunction(a)) + +def applyToChildren(a, f): + """Apply the given function to all the children of a""" + if a == None: + return a + for field in a._fields: + child = getattr(a, field) + if type(child) == list: + i = 0 + while i < len(child): + temp = f(child[i]) + if type(temp) == list: + child = child[:i] + temp + child[i+1:] + i += len(temp) + else: + child[i] = temp + i += 1 + else: + child = f(child) + setattr(a, field, child) + return a + +def occursIn(sub, super): + """Does the first AST occur as a subtree of the second?""" + superStatementTypes = [ ast.Module, ast.Interactive, ast.Suite, + ast.FunctionDef, ast.ClassDef, ast.For, + ast.While, ast.If, ast.With, ast.Try, + ast.ExceptHandler ] + if (not isinstance(super, ast.AST)): + return False + if type(sub) == type(super) and compareASTs(sub, super, checkEquality=True) == 0: + return True + # we know that a statement can never occur in an expression + # (or in a non-statement-holding statement), so cut the search off now to save time. + if isStatement(sub) and type(super) not in superStatementTypes: + return False + for child in ast.iter_child_nodes(super): + if occursIn(sub, child): + return True + return False + +def countOccurances(a, value): + """How many instances of this node type appear in the AST?""" + if type(a) == list: + return sum([countOccurances(x, value) for x in a]) + if not isinstance(a, ast.AST): + return 0 + + count = 0 + for node in ast.walk(a): + if isinstance(node, value): + count += 1 + return count + +def countVariables(a, id): + """Count the number of times the given variable appears in the AST""" + if type(a) == list: + return sum([countVariables(x, id) for x in a]) + if not isinstance(a, ast.AST): + return 0 + + count = 0 + for node in ast.walk(a): + if type(node) == ast.Name and node.id == id: + count += 1 + return count + +def gatherAllNames(a, keep_orig=True): + """Gather all names in the tree (variable or otherwise). + Names are returned along with their original names + (which are used in variable mapping)""" + if type(a) == list: + allIds = set() + for line in a: + allIds |= gatherAllNames(line) + return allIds + if not isinstance(a, ast.AST): + return set() + + allIds = set() + for node in ast.walk(a): + if type(node) == ast.Name: + origName = node.originalId if (keep_orig and hasattr(node, "originalId")) else None + allIds |= set([(node.id, origName)]) + return allIds + +def gatherAllVariables(a, keep_orig=True): + """Gather all variable names in the tree. Names are returned along + with their original names (which are used in variable mapping)""" + if type(a) == list: + allIds = set() + for line in a: + allIds |= gatherAllVariables(line) + return allIds + if not isinstance(a, ast.AST): + return set() + + allIds = set() + for node in ast.walk(a): + if type(node) == ast.Name or type(node) == ast.arg: + currentId = node.id if type(node) == ast.Name else node.arg + # Only take variables + if not (builtInName(currentId) or hasattr(node, "dontChangeName")): + origName = node.originalId if (keep_orig and hasattr(node, "originalId")) else None + if (currentId, origName) not in allIds: + for pair in allIds: + if pair[0] == currentId: + if pair[1] == None: + allIds -= {pair} + allIds |= {(currentId, origName)} + elif origName == None: + pass + else: + log("astTools\tgatherAllVariables\tConflicting originalIds? " + pair[0] + " : " + pair[1] + " , " + origName + "\n" + printFunction(a), "bug") + break + else: + allIds |= {(currentId, origName)} + return allIds + +def gatherAllParameters(a, keep_orig=True): + """Gather all parameters in the tree. Names are returned along + with their original names (which are used in variable mapping)""" + if type(a) == list: + allIds = set() + for line in a: + allIds |= gatherAllVariables(line) + return allIds + if not isinstance(a, ast.AST): + return set() + + allIds = set() + for node in ast.walk(a): + if type(node) == ast.arg: + origName = node.originalId if (keep_orig and hasattr(node, "originalId")) else None + allIds |= set([(node.arg, origName)]) + return allIds + +def gatherAllHelpers(a, restricted_names): + """Gather all helper function names in the tree that have been anonymized""" + if type(a) != ast.Module: + return set() + helpers = set() + for item in a.body: + if type(item) == ast.FunctionDef: + if not hasattr(item, "dontChangeName") and item.name not in restricted_names: # this got anonymized + origName = item.originalId if hasattr(item, "originalId") else None + helpers |= set([(item.name, origName)]) + return helpers + +def gatherAllFunctionNames(a): + """Gather all helper function names in the tree that have been anonymized""" + if type(a) != ast.Module: + return set() + helpers = set() + for item in a.body: + if type(item) == ast.FunctionDef: + origName = item.originalId if hasattr(item, "originalId") else None + helpers |= set([(item.name, origName)]) + return helpers + +def gatherAssignedVars(targets): + """Take a list of assigned variables and extract the names/subscripts/attributes""" + if type(targets) != list: + targets = [targets] + newTargets = [] + for target in targets: + if type(target) in [ast.Tuple, ast.List]: + newTargets += gatherAssignedVars(target.elts) + elif type(target) in [ast.Name, ast.Subscript, ast.Attribute]: + newTargets.append(target) + else: + log("astTools\tgatherAssignedVars\tWeird Assign Type: " + str(type(target)),"bug") + return newTargets + +def gatherAssignedVarIds(targets): + """Just get the ids of Names""" + vars = gatherAssignedVars(targets) + return [y.id for y in filter(lambda x : type(x) == ast.Name, vars)] + +def getAllAssignedVarIds(a): + if not isinstance(a, ast.AST): + return [] + ids = [] + for child in ast.walk(a): + if type(child) == ast.Assign: + ids += gatherAssignedVarIds(child.targets) + elif type(child) == ast.AugAssign: + ids += gatherAssignedVarIds([child.target]) + elif type(child) == ast.For: + ids += gatherAssignedVarIds([child.target]) + return ids + +def getAllAssignedVars(a): + if not isinstance(a, ast.AST): + return [] + vars = [] + for child in ast.walk(a): + if type(child) == ast.Assign: + vars += gatherAssignedVars(child.targets) + elif type(child) == ast.AugAssign: + vars += gatherAssignedVars([child.target]) + elif type(child) == ast.For: + vars += gatherAssignedVars([child.target]) + return vars + +def getAllFunctions(a): + """Collects all the functions in the given module""" + if not isinstance(a, ast.AST): + return [] + functions = [] + for child in ast.walk(a): + if type(child) == ast.FunctionDef: + functions.append(child.name) + return functions + +def getAllImports(a): + """Gather all imported module names""" + if not isinstance(a, ast.AST): + return [] + imports = [] + for child in ast.walk(a): + if type(child) == ast.Import: + for alias in child.names: + if alias.name in supportedLibraries: + imports.append(alias.asname if alias.asname != None else alias.name) + else: + log("astTools\tgetAllImports\tUnknown library: " + alias.name, "bug") + elif type(child) == ast.ImportFrom: + if child.module in supportedLibraries: + for alias in child.names: # these are all functions + if alias.name in libraryMap[child.module]: + imports.append(alias.asname if alias.asname != None else alias.name) + else: + log("astTools\tgetAllImports\tUnknown import from name: " + \ + child.module + "," + alias.name, "bug") + else: + log("astTools\tgetAllImports\tUnknown library: " + child.module, "bug") + return imports + +def getAllImportStatements(a): + if not isinstance(a, ast.AST): + return [] + imports = [] + for child in ast.walk(a): + if type(child) == ast.Import: + imports.append(child) + elif type(child) == ast.ImportFrom: + imports.append(child) + return imports + +def getAllGlobalNames(a): + # Finds all names that can be accessed at the global level in the AST + if type(a) != ast.Module: + return [] + names = [] + for obj in a.body: + if type(obj) in [ast.FunctionDef, ast.ClassDef]: + names.append(obj.name) + elif type(obj) in [ast.Assign, ast.AugAssign]: + targets = obj.targets if type(obj) == ast.Assign else [obj.target] + for target in obj.targets: + if type(target) == ast.Name: + names.append(target.id) + elif type(target) in [ast.Tuple, ast.List]: + for elt in target.elts: + if type(elt) == ast.Name: + names.append(elt.id) + elif type(obj) in [ast.Import, ast.ImportFrom]: + for module in obj.names: + names.append(module.asname if module.asname != None else module.name) + return names + +def doBinaryOp(op, l, r): + """Perform the given AST binary operation on the values""" + top = type(op) + if top == ast.Add: + return l + r + elif top == ast.Sub: + return l - r + elif top == ast.Mult: + return l * r + elif top == ast.Div: + # Don't bother if this will be a really long float- it won't work properly! + # Also, in Python 3 this is floating division, so perform it accordingly. + val = 1.0 * l / r + if (val * 1e10 % 1.0) != 0: + raise Exception("Repeating Float") + return val + elif top == ast.Mod: + return l % r + elif top == ast.Pow: + return l ** r + elif top == ast.LShift: + return l << r + elif top == ast.RShift: + return l >> r + elif top == ast.BitOr: + return l | r + elif top == ast.BitXor: + return l ^ r + elif top == ast.BitAnd: + return l & r + elif top == ast.FloorDiv: + return l // r + +def doUnaryOp(op, val): + """Perform the given AST unary operation on the value""" + top = type(op) + if top == ast.Invert: + return ~ val + elif top == ast.Not: + return not val + elif top == ast.UAdd: + return val + elif top == ast.USub: + return -val + +def doCompare(op, left, right): + """Perform the given AST comparison on the values""" + top = type(op) + if top == ast.Eq: + return left == right + elif top == ast.NotEq: + return left != right + elif top == ast.Lt: + return left < right + elif top == ast.LtE: + return left <= right + elif top == ast.Gt: + return left > right + elif top == ast.GtE: + return left >= right + elif top == ast.Is: + return left is right + elif top == ast.IsNot: + return left is not right + elif top == ast.In: + return left in right + elif top == ast.NotIn: + return left not in right + +def num_negate(op): + top = type(op) + neg = not op.num_negated if hasattr(op, "num_negated") else True + if top == ast.Add: + newOp = ast.Sub() + elif top == ast.Sub: + newOp = ast.Add() + elif top in [ast.Mult, ast.Div, ast.Mod, ast.Pow, ast.LShift, + ast.RShift, ast.BitOr, ast.BitXor, ast.BitAnd, ast.FloorDiv]: + return None # can't negate this + elif top in [ast.Num, ast.Name]: + # this is a normal value, so put a - in front of it + newOp = ast.UnaryOp(ast.USub(addedNeg=True), op) + else: + log("astTools\tnum_negate\tUnusual type: " + str(top), "bug") + transferMetaData(op, newOp) + newOp.num_negated = neg + return newOp + +def negate(op): + """Return the negation of the provided operator""" + if op == None: + return None + top = type(op) + neg = not op.negated if hasattr(op, "negated") else True + if top == ast.And: + newOp = ast.Or() + elif top == ast.Or: + newOp = ast.And() + elif top == ast.Eq: + newOp = ast.NotEq() + elif top == ast.NotEq: + newOp = ast.Eq() + elif top == ast.Lt: + newOp = ast.GtE() + elif top == ast.GtE: + newOp = ast.Lt() + elif top == ast.Gt: + newOp = ast.LtE() + elif top == ast.LtE: + newOp = ast.Gt() + elif top == ast.Is: + newOp = ast.IsNot() + elif top == ast.IsNot: + newOp = ast.Is() + elif top == ast.In: + newOp = ast.NotIn() + elif top == ast.NotIn: + newOp = ast.In() + elif top == ast.NameConstant and op.value in [True, False]: + op.value = not op.value + op.negated = neg + return op + elif top == ast.Compare: + if len(op.ops) == 1: + op.ops[0] = negate(op.ops[0]) + op.negated = neg + return op + else: + values = [] + allOperands = [op.left] + op.comparators + for i in range(len(op.ops)): + values.append(ast.Compare(allOperands[i], [negate(op.ops[i])], + [allOperands[i+1]], multiCompPart=True)) + newOp = ast.BoolOp(ast.Or(multiCompOp=True), values, multiComp=True) + elif top == ast.UnaryOp and type(op.op) == ast.Not and \ + eventualType(op.operand) == bool: # this can mess things up type-wise + return op.operand + else: + # this is a normal value, so put a not around it + newOp = ast.UnaryOp(ast.Not(addedNot=True), op) + transferMetaData(op, newOp) + newOp.negated = neg + return newOp + +def couldCrash(a): + """Determines whether the given AST could possibly crash""" + typeCrashes = True # toggle based on whether you care about potential crashes caused by types + if not isinstance(a, ast.AST): + return False + + if type(a) == ast.Try: + for handler in a.handlers: + for child in ast.iter_child_nodes(handler): + if couldCrash(child): + return True + for other in a.orelse: + for child in ast.iter_child_nodes(other): + if couldCrash(child): + return True + for line in a.finalbody: + for child in ast.iter_child_nodes(line): + if couldCrash(child): + return True + return False + + # If any child could crash, this can crash + for child in ast.iter_child_nodes(a): + if couldCrash(child): + return True + + if type(a) == ast.FunctionDef: + argNames = [] + for arg in a.args.args: + if arg.arg in argNames: # conflicting arg names! + return True + else: + argNames.append(arg.arg) + if type(a) == ast.Assign: + for target in a.targets: + if type(target) != ast.Name: # can crash if it's a tuple and we can't unpack the value + return True + elif type(a) in [ast.For, ast.comprehension]: # check if the target or iter will break things + if type(a.target) not in [ast.Name, ast.Tuple, ast.List]: + return True + elif type(a.target) in [ast.Tuple, ast.List]: + for x in a.target.elts: + if type(x) != ast.Name: + return True + elif isIterableType(eventualType(a.iter)): + return True + elif type(a) == ast.Import: + for name in a.names: + if name not in supportedLibraries: + return True + elif type(a) == ast.ImportFrom: + if a.module not in supportedLibraries: + return True + if a.level != None: + return True + for name in a.names: + if name not in libraryMap[a.module]: + return True + elif type(a) == ast.BinOp: + l = eventualType(a.left) + r = eventualType(a.right) + if type(a.op) == ast.Add: + if not ((l == r == str) or (l in [int, float] and r in [int, float])): + return typeCrashes + elif type(a.op) == ast.Mult: + if not ((l == str and r == int) or (l == int and r == str) or \ + (l in [int, float] and r in [int, float])): + return typeCrashes + elif type(a.op) in [ast.Sub, ast.LShift, ast.RShift, ast.BitOr, ast.BitXor, ast.BitAnd]: + if not (l in [int, float] and r in [int, float]): + return typeCrashes + elif type(a.op) == ast.Pow: + if not ((l in [int, float] and r == int) or \ + (l in [int, float] and type(a.right) == ast.Num and \ + type(a.right.n) != complex and \ + (a.right.n >= 1 or a.right.n == 0 or a.right.n <= -1))): + return True + else: # ast.Div, ast.FloorDiv, ast.Mod + if type(a.right) == ast.Num and a.right.n != 0: + if l not in [int, float]: + return typeCrashes + else: + return True # Divide by zero error + elif type(a) == ast.UnaryOp: + if type(a.op) in [ast.UAdd, ast.USub]: + if eventualType(a.operand) not in [int, float]: + return typeCrashes + elif type(a.op) == ast.Invert: + if eventualType(a.operand) != int: + return typeCrashes + elif type(a) == ast.Compare: + if len(a.ops) != len(a.comparators): + return True + elif type(a.ops[0]) in [ast.In, ast.NotIn]: + if not isIterableType(eventualType(a.comparators[0])): + return True + elif eventualType(a.comparators[0]) in [str, bytes] and eventualType(a.left) not in [str, bytes]: + return True + elif type(a.ops[0]) in [ast.Lt, ast.LtE, ast.Gt, ast.GtE]: + # In Python3, you can't compare different types. BOOOOOO!! + firstType = eventualType(a.left) + if firstType == None: + return True + for comp in a.comparators: + if eventualType(comp) != firstType: + return True + elif type(a) == ast.Call: + env = [] # TODO: what if the environments aren't imported? + # First, gather up the needed variables + if type(a.func) == ast.Name: + funName = a.func.id + if funName not in builtInSafeFunctions: + return True + funDict = builtInFunctions + elif type(a.func) == ast.Attribute: + if type(a.func.value) == a.Name and \ + (not hasattr(a.func.value, "varID")) and \ + a.func.value.id in supportedLibraries: + funName = a.func.attr + if funName not in safeLibraryMap(a.func.value.id): + return True + funDict = libraryMap[a.func.value.id] + elif eventualType(a.func.value) == str: + funName = a.func.attr + if funName not in safeStringFunctions: + return True + funDict = builtInStringFunctions + else: # list and dict are definitely crashable + return True + else: + return True + + if funName in ["max", "min"]: + return False # Special functions that have infinite args + + # First, load up the arg types + argTypes = [] + for i in range(len(a.args)): + eventual = eventualType(a.args[i]) + if (eventual == None and typeCrashes): + return True + argTypes.append(eventual) + + if funDict[funName] != None: + for argSet in funDict[funName]: # the given possibilities of arg types + if len(argSet) != len(argTypes): + continue + if not typeCrashes: # If we don't care about types, stop now + return False + + for i in range(len(argSet)): + if not (argSet[i] == argTypes[i] or issubclass(argTypes[i], argSet[i])): + break + else: # if all types matched + return False + return True # Didn't fit any of the options + elif type(a) == ast.Subscript: # can only get an index from a string or list + return eventualType(a.value) not in [str, list, tuple] + elif type(a) == ast.Name: + # If it's an undefined variable, it might crash + if hasattr(a, "randomVar"): + return True + elif type(a) == ast.Slice: + if a.lower != None and eventualType(a.lower) != int: + return True + if a.upper != None and eventualType(a.upper) != int: + return True + if a.step != None and eventualType(a.step) != int: + return True + elif type(a) in [ast.Raise, ast.Assert, ast.Pass, ast.Break, \ + ast.Continue, ast.Yield, ast.Attribute, ast.ExtSlice, ast.Index, \ + ast.Starred]: + # All of these cases can definitely crash. + return True + return False + +def eventualType(a): + """Get the type the expression will eventually be, if possible + The expression might also crash! But we don't care about that here, + we'll deal with it elsewhere. + Returning 'None' means that we cannot say at the moment""" + if type(a) in builtInTypes: + return type(a) + if not isinstance(a, ast.AST): + return None + + elif type(a) == ast.BoolOp: + # In Python, it's the type of all the values in it + # this may work differently in other languages + t = eventualType(a.values[0]) + for i in range(1, len(a.values)): + if eventualType(a.values[i]) != t: + return None + return t + elif type(a) == ast.BinOp: + l = eventualType(a.left) + r = eventualType(a.right) + # It is possible to add/multiply sequences + if type(a.op) in [ast.Add, ast.Mult]: + if isIterableType(l): + return l + elif isIterableType(r): + return r + elif l == float or r == float: + return float + elif l == int and r == int: + return int + return None + elif type(a.op) == ast.Div: + return float # always a float now + # For others, check if we know whether it's a float or an int + elif type(a.op) in [ast.FloorDiv, ast.LShift, ast.RShift, ast.BitOr, + ast.BitAnd, ast.BitXor]: + return int + elif float in [l, r]: + return float + elif l == int and r == int: + return int + else: + return None # Otherwise, it could be a float- we don't know + elif type(a) == ast.UnaryOp: + if type(a.op) == ast.Invert: + return int + elif type(a.op) in [ast.UAdd, ast.USub]: + return eventualType(a.operand) + else: # Not op + return bool + elif type(a) == ast.Lambda: + return function + elif type(a) == ast.IfExp: + l = eventualType(a.body) + r = eventualType(a.orelse) + if l == r: + return l + else: + return None + elif type(a) in [ast.Dict, ast.DictComp]: + return dict + elif type(a) in [ast.Set, ast.SetComp]: + return set + elif type(a) in [ast.List, ast.ListComp]: + return list + elif type(a) == ast.GeneratorExp: + return None # can't represent a generator + elif type(a) == ast.Yield: + return None # we don't know + elif type(a) == ast.Compare: + return bool + elif type(a) == ast.Call: + # Go through our different sets of known functions to see if we know the type + argTypes = [eventualType(x) for x in a.args] + if type(a.func) == ast.Name: + funDict = builtInFunctions + funName = a.func.id + elif type(a.func) == ast.Attribute: + # TODO: get a better solution than this + funName = a.func.attr + if type(a.func.value) == ast.Name and \ + (not hasattr(a.func.value, "varID")) and \ + a.func.value.id in supportedLibraries: + funDict = libraryDictMap[a.func.value.id] + if a.func.value.id in ["string", "str", "list", "dict"] and len(argTypes) > 0: + argTypes.pop(0) # get rid of the first string arg + elif eventualType(a.func.value) == str: + funDict = builtInStringFunctions + elif eventualType(a.func.value) == list: + funDict = builtInListFunctions + elif eventualType(a.func.value) == dict: + funDict = builtInDictFunctions + else: + return None + else: + return None + + if funName in ["max", "min"]: + # If all args are the same type, that's our type + uniqueTypes = set(argTypes) + if len(uniqueTypes) == 1: + return uniqueTypes.pop() + return None + + if funName in funDict and funDict[funName] != None: + possibleTypes = [] + for argSet in funDict[funName]: + if len(argSet) == len(argTypes): + # All types must match! + for i in range(len(argSet)): + if argSet[i] == None or argTypes[i] == None: # We don't know, but that's okay + continue + if not (argSet[i] == argTypes[i] or (issubclass(argTypes[i], argSet[i]))): + break + else: + possibleTypes.append(funDict[funName][argSet]) + possibleTypes = set(possibleTypes) + if len(possibleTypes) == 1: # If there's only one possibility, that's our type! + return possibleTypes.pop() + return None + elif type(a) in [ast.Str, ast.Bytes]: + if containsTokenStepString(a): + return None + return str + elif type(a) == ast.Num: + return type(a.n) + elif type(a) == ast.Attribute: + return None # we have no way of knowing + elif type(a) == ast.Subscript: + # We're slicing the object, so the type will stay the same + t = eventualType(a.value) + if t == None: + return None + elif t == str: + return str # indexing a string + elif t in [list, tuple]: + if type(a.slice) == ast.Slice: + return t + # Otherwise, we need the types of the elements + if type(a.value) in [ast.List, ast.Tuple]: + if len(a.value.elts) == 0: + return None # We don't know + else: + eltType = eventualType(a.value.elts[0]) + for elt in a.value.elts: + if eventualType(elt) != eltType: + return None # Disagreement! + return eltType + elif t in [dict, int]: + return None + else: + log("astTools\teventualType\tUnknown type in subscript: " + str(t), "bug") + return None # We can't know for now... + elif type(a) == ast.NameConstant: + if a.value == True or a.value == False: + return bool + elif a.value == None: + return type(None) + return None + elif type(a) == ast.Name: + if hasattr(a, "type"): # If it's a variable we categorized + return a.type + return None + elif type(a) == ast.Tuple: + return tuple + elif type(a) == ast.Starred: + return None # too complicated + else: + log("astTools\teventualType\tUnimplemented type " + str(type(a)), "bug") + return None + +def depthOfAST(a): + """Determine the depth of the AST""" + if not isinstance(a, ast.AST): + return 0 + m = 0 + for child in ast.iter_child_nodes(a): + tmp = depthOfAST(child) + if tmp > m: + m = tmp + return m + 1 + +def compareASTs(a, b, checkEquality=False): + """A comparison function for ASTs""" + # None before others + if a == b == None: + return 0 + elif a == None or b == None: + return -1 if a == None else 1 + + if type(a) == type(b) == list: + if len(a) != len(b): + return len(a) - len(b) + for i in range(len(a)): + r = compareASTs(a[i], b[i], checkEquality=checkEquality) + if r != 0: + return r + return 0 + + # AST before primitive + if (not isinstance(a, ast.AST)) and (not isinstance(b, ast.AST)): + if type(a) != type(b): + builtins = [bool, int, float, str, bytes, complex] + if type(a) not in builtins or type(b) not in builtins: + log("MISSING BUILT-IN TYPE: " + str(type(a)) + "," + str(type(b)), "bug") + return builtins.index(type(a)) - builtins.index(type(b)) + return cmp(a, b) + elif (not isinstance(a, ast.AST)) or (not isinstance(b, ast.AST)): + return -1 if isinstance(a, ast.AST) else 1 + + # Order by differing types + if type(a) != type(b): + # Here is a brief ordering of types that we care about + blehTypes = [ ast.Load, ast.Store, ast.Del, ast.AugLoad, ast.AugStore, ast.Param ] + if type(a) in blehTypes and type(b) in blehTypes: + return 0 + elif type(a) in blehTypes or type(b) in blehTypes: + return -1 if type(a) in blehTypes else 1 + + types = [ ast.Module, ast.Interactive, ast.Expression, ast.Suite, + + ast.Break, ast.Continue, ast.Pass, ast.Global, + ast.Expr, ast.Assign, ast.AugAssign, ast.Return, + ast.Assert, ast.Delete, ast.If, ast.For, ast.While, + ast.With, ast.Import, ast.ImportFrom, ast.Raise, + ast.Try, ast.FunctionDef, + ast.ClassDef, + + ast.BinOp, ast.BoolOp, ast.Compare, ast.UnaryOp, + ast.DictComp, ast.ListComp, ast.SetComp, ast.GeneratorExp, + ast.Yield, ast.Lambda, ast.IfExp, ast.Call, ast.Subscript, + ast.Attribute, ast.Dict, ast.List, ast.Tuple, + ast.Set, ast.Name, ast.Str, ast.Bytes, ast.Num, + ast.NameConstant, ast.Starred, + + ast.Ellipsis, ast.Index, ast.Slice, ast.ExtSlice, + + ast.And, ast.Or, ast.Add, ast.Sub, ast.Mult, ast.Div, + ast.Mod, ast.Pow, ast.LShift, ast.RShift, ast.BitOr, + ast.BitXor, ast.BitAnd, ast.FloorDiv, ast.Invert, ast.Not, + ast.UAdd, ast.USub, ast.Eq, ast.NotEq, ast.Lt, ast.LtE, + ast.Gt, ast.GtE, ast.Is, ast.IsNot, ast.In, ast.NotIn, + + ast.alias, ast.keyword, ast.arguments, ast.arg, ast.comprehension, + ast.ExceptHandler, ast.withitem + ] + if (type(a) not in types) or (type(b) not in types): + log("astTools\tcompareASTs\tmissing type:" + str(type(a)) + "," + str(type(b)), "bug") + return 0 + return types.index(type(a)) - types.index(type(b)) + + # Then, more complex expressions- but don't bother with this if we're just checking equality + if not checkEquality: + ad = depthOfAST(a) + bd = depthOfAST(b) + if ad != bd: + return bd - ad + + # NameConstants are special + if type(a) == ast.NameConstant: + if a.value == None or b.value == None: + return 1 if a.value != None else (0 if b.value == None else -1) # short and works + + if a.value in [True, False] or b.value in [True, False]: + return 1 if a.value not in [True, False] else (cmp(a.value, b.value) if b.value in [True, False] else -1) + + if type(a) == ast.Name: + return cmp(a.id, b.id) + + # Operations and attributes are all ok + elif type(a) in [ ast.And, ast.Or, ast.Add, ast.Sub, ast.Mult, ast.Div, + ast.Mod, ast.Pow, ast.LShift, ast.RShift, ast.BitOr, + ast.BitXor, ast.BitAnd, ast.FloorDiv, ast.Invert, + ast.Not, ast.UAdd, ast.USub, ast.Eq, ast.NotEq, ast.Lt, + ast.LtE, ast.Gt, ast.GtE, ast.Is, ast.IsNot, ast.In, + ast.NotIn, ast.Load, ast.Store, ast.Del, ast.AugLoad, + ast.AugStore, ast.Param, ast.Ellipsis, ast.Pass, + ast.Break, ast.Continue + ]: + return 0 + + # Now compare based on the attributes in the identical types + for attr in a._fields: + r = compareASTs(getattr(a, attr), getattr(b, attr), checkEquality=checkEquality) + if r != 0: + return r + # If all attributes are identical, they're equal + return 0 + +def deepcopyList(l): + """Deepcopy of a list""" + if l == None: + return None + if isinstance(l, ast.AST): + return deepcopy(l) + if type(l) != list: + log("astTools\tdeepcopyList\tNot a list: " + str(type(l)), "bug") + return copy.deepcopy(l) + + newList = [] + for line in l: + newList.append(deepcopy(line)) + return newList + +def deepcopy(a): + """Let's try to keep this as quick as possible""" + if a == None: + return None + if type(a) == list: + return deepcopyList(a) + elif type(a) in [int, float, str, bool]: + return a + if not isinstance(a, ast.AST): + log("astTools\tdeepcopy\tNot an AST: " + str(type(a)), "bug") + return copy.deepcopy(a) + + g = a.global_id if hasattr(a, "global_id") else None + cp = None + # Objects without lineno, col_offset + if type(a) in [ ast.And, ast.Or, ast.Add, ast.Sub, ast.Mult, ast.Div, + ast.Mod, ast.Pow, ast.LShift, ast.RShift, ast.BitOr, + ast.BitXor, ast.BitAnd, ast.FloorDiv, ast.Invert, + ast.Not, ast.UAdd, ast.USub, ast.Eq, ast.NotEq, ast.Lt, + ast.LtE, ast.Gt, ast.GtE, ast.Is, ast.IsNot, ast.In, + ast.NotIn, ast.Load, ast.Store, ast.Del, ast.AugLoad, + ast.AugStore, ast.Param + ]: + return a + elif type(a) == ast.Module: + cp = ast.Module(deepcopyList(a.body)) + elif type(a) == ast.Interactive: + cp = ast.Interactive(deepcopyList(a.body)) + elif type(a) == ast.Expression: + cp = ast.Expression(deepcopy(a.body)) + elif type(a) == ast.Suite: + cp = ast.Suite(deepcopyList(a.body)) + + elif type(a) == ast.FunctionDef: + cp = ast.FunctionDef(a.name, deepcopy(a.args), deepcopyList(a.body), + deepcopyList(a.decorator_list), deepcopy(a.returns)) + elif type(a) == ast.ClassDef: + cp = ast.ClassDef(a.name, deepcopyList(a.bases), deepcopyList(a.keywords), deepcopyList(a.body), + deepcopyList(a.decorator_list)) + elif type(a) == ast.Return: + cp = ast.Return(deepcopy(a.value)) + elif type(a) == ast.Delete: + cp = ast.Delete(deepcopyList(a.targets)) + elif type(a) == ast.Assign: + cp = ast.Assign(deepcopyList(a.targets), deepcopy(a.value)) + elif type(a) == ast.AugAssign: + cp = ast.AugAssign(deepcopy(a.target), deepcopy(a.op), + deepcopy(a.value)) + elif type(a) == ast.For: + cp = ast.For(deepcopy(a.target), deepcopy(a.iter), + deepcopyList(a.body), deepcopyList(a.orelse)) + elif type(a) == ast.While: + cp = ast.While(deepcopy(a.test), deepcopyList(a.body), + deepcopyList(a.orelse)) + elif type(a) == ast.If: + cp = ast.If(deepcopy(a.test), deepcopyList(a.body), + deepcopyList(a.orelse)) + elif type(a) == ast.With: + cp = ast.With(deepcopyList(a.items),deepcopyList(a.body)) + elif type(a) == ast.Raise: + cp = ast.Raise(deepcopy(a.exc), deepcopy(a.cause)) + elif type(a) == ast.Try: + cp = ast.Try(deepcopyList(a.body), deepcopyList(a.handlers), + deepcopyList(a.orelse), deepcopyList(a.finalbody)) + elif type(a) == ast.Assert: + cp = ast.Assert(deepcopy(a.test), deepcopy(a.msg)) + elif type(a) == ast.Import: + cp = ast.Import(deepcopyList(a.names)) + elif type(a) == ast.ImportFrom: + cp = ast.ImportFrom(a.module, deepcopyList(a.names), a.level) + elif type(a) == ast.Global: + cp = ast.Global(a.names[:]) + elif type(a) == ast.Expr: + cp = ast.Expr(deepcopy(a.value)) + elif type(a) == ast.Pass: + cp = ast.Pass() + elif type(a) == ast.Break: + cp = ast.Break() + elif type(a) == ast.Continue: + cp = ast.Continue() + + elif type(a) == ast.BoolOp: + cp = ast.BoolOp(a.op, deepcopyList(a.values)) + elif type(a) == ast.BinOp: + cp = ast.BinOp(deepcopy(a.left), a.op, deepcopy(a.right)) + elif type(a) == ast.UnaryOp: + cp = ast.UnaryOp(a.op, deepcopy(a.operand)) + elif type(a) == ast.Lambda: + cp = ast.Lambda(deepcopy(a.args), deepcopy(a.body)) + elif type(a) == ast.IfExp: + cp = ast.IfExp(deepcopy(a.test), deepcopy(a.body), deepcopy(a.orelse)) + elif type(a) == ast.Dict: + cp = ast.Dict(deepcopyList(a.keys), deepcopyList(a.values)) + elif type(a) == ast.Set: + cp = ast.Set(deepcopyList(a.elts)) + elif type(a) == ast.ListComp: + cp = ast.ListComp(deepcopy(a.elt), deepcopyList(a.generators)) + elif type(a) == ast.SetComp: + cp = ast.SetComp(deepcopy(a.elt), deepcopyList(a.generators)) + elif type(a) == ast.DictComp: + cp = ast.DictComp(deepcopy(a.key), deepcopy(a.value), + deepcopyList(a.generators)) + elif type(a) == ast.GeneratorExp: + cp = ast.GeneratorExp(deepcopy(a.elt), deepcopyList(a.generators)) + elif type(a) == ast.Yield: + cp = ast.Yield(deepcopy(a.value)) + elif type(a) == ast.Compare: + cp = ast.Compare(deepcopy(a.left), a.ops[:], + deepcopyList(a.comparators)) + elif type(a) == ast.Call: + cp = ast.Call(deepcopy(a.func), deepcopyList(a.args), deepcopyList(a.keywords)) + elif type(a) == ast.Num: + cp = ast.Num(a.n) + elif type(a) == ast.Str: + cp = ast.Str(a.s) + elif type(a) == ast.Bytes: + cp = ast.Bytes(a.s) + elif type(a) == ast.NameConstant: + cp = ast.NameConstant(a.value) + elif type(a) == ast.Attribute: + cp = ast.Attribute(deepcopy(a.value), a.attr, a.ctx) + elif type(a) == ast.Subscript: + cp = ast.Subscript(deepcopy(a.value), deepcopy(a.slice), a.ctx) + elif type(a) == ast.Name: + cp = ast.Name(a.id, a.ctx) + elif type(a) == ast.List: + cp = ast.List(deepcopyList(a.elts), a.ctx) + elif type(a) == ast.Tuple: + cp = ast.Tuple(deepcopyList(a.elts), a.ctx) + elif type(a) == ast.Starred: + cp = ast.Starred(deepcopy(a.value), a.ctx) + + elif type(a) == ast.Slice: + cp = ast.Slice(deepcopy(a.lower), deepcopy(a.upper), deepcopy(a.step)) + elif type(a) == ast.ExtSlice: + cp = ast.ExtSlice(deepcopyList(a.dims)) + elif type(a) == ast.Index: + cp = ast.Index(deepcopy(a.value)) + + elif type(a) == ast.comprehension: + cp = ast.comprehension(deepcopy(a.target), deepcopy(a.iter), + deepcopyList(a.ifs)) + elif type(a) == ast.ExceptHandler: + cp = ast.ExceptHandler(deepcopy(a.type), a.name, deepcopyList(a.body)) + elif type(a) == ast.arguments: + cp = ast.arguments(deepcopyList(a.args), deepcopy(a.vararg), + deepcopyList(a.kwonlyargs), deepcopyList(a.kw_defaults), + deepcopy(a.kwarg), deepcopyList(a.defaults)) + elif type(a) == ast.arg: + cp = ast.arg(a.arg, deepcopy(a.annotation)) + elif type(a) == ast.keyword: + cp = ast.keyword(a.arg, deepcopy(a.value)) + elif type(a) == ast.alias: + cp = ast.alias(a.name, a.asname) + elif type(a) == ast.withitem: + cp = ast.withitem(deepcopy(a.context_expr), deepcopy(a.optional_vars)) + else: + log("astTools\tdeepcopy\tNot implemented: " + str(type(a)), "bug") + cp = copy.deepcopy(a) + + transferMetaData(a, cp) + return cp + +def exportToJson(a): + """Export the ast to json format""" + if a == None: + return "null" + elif type(a) in [int, float]: + return str(a) + elif type(a) == str: + return '"' + a + '"' + elif not isinstance(a, ast.AST): + log("astTools\texportToJson\tMissing type: " + str(type(a)), "bug") + + s = "{\n" + if type(a) in astNames: + s += '"' + astNames[type(a)] + '": {\n' + for field in a._fields: + s += '"' + field + '": ' + value = getattr(a, field) + if type(value) == list: + s += "[" + for item in value: + s += exportToJson(item) + ", " + if len(value) > 0: + s = s[:-2] + s += "]" + else: + s += exportToJson(value) + s += ", " + if len(a._fields) > 0: + s = s[:-2] + s += "}" + else: + log("astTools\texportToJson\tMissing AST type: " + str(type(a)), "bug") + s += "}" + return s + +### ITAP/Canonicalization Functions ### + +def isTokenStepString(s): + """Determine whether this is a placeholder string""" + if len(s) < 2: + return False + return s[0] == "~" and s[-1] == "~" + +def getParentFunction(s): + underscoreSep = s.split("_") + if len(underscoreSep) == 1: + return None + result = "_".join(underscoreSep[1:]) + if result == "newvar" or result == "global": + return None + return result + +def isAnonVariable(s): + """Specificies whether the given string is an anonymized variable name""" + preUnderscore = s.split("_")[0] # the part before the function name + return len(preUnderscore) > 1 and \ + preUnderscore[0] in ["g", "p", "v", "r", "n", "z"] and \ + preUnderscore[1:].isdigit() + +def isDefault(a): + """Our programs have a default setting of return 42, so we should detect that""" + if type(a) == ast.Module and len(a.body) == 1: + a = a.body[0] + else: + return False + + if type(a) != ast.FunctionDef: + return False + + if len(a.body) == 0: + return True + elif len(a.body) == 1: + if type(a.body[0]) == ast.Return: + if a.body[0].value == None or \ + type(a.body[0].value) == ast.Num and a.body[0].value.n == 42: + return True + return False + +def transferMetaData(a, b): + """Transfer the metadata of a onto b""" + properties = [ "global_id", "second_global_id", "lineno", "col_offset", + "originalId", "varID", "variableGlobalId", + "randomVar", "propagatedVariable", "loadedVariable", "dontChangeName", + "reversed", "negated", "inverted", + "augAssignVal", "augAssignBinOp", + "combinedConditional", "combinedConditionalOp", + "multiComp", "multiCompPart", "multiCompMiddle", "multiCompOp", + "addedNot", "addedNotOp", "addedOther", "addedOtherOp", "addedNeg", + "collapsedExpr", "removedLines", + "helperVar", "helperReturnVal", "helperParamAssign", "helperReturnAssign", + "orderedBinOp", "typeCastFunction", "moved_line" ] + for prop in properties: + if hasattr(a, prop): + setattr(b, prop, getattr(a, prop)) + +def assignPropertyToAll(a, prop): + """Assign the provided property to all children""" + if type(a) == list: + for child in a: + assignPropertyToAll(child, prop) + elif isinstance(a, ast.AST): + for node in ast.walk(a): + setattr(node, prop, True) + +def removePropertyFromAll(a, prop): + if type(a) == list: + for child in a: + removePropertyFromAll(child, prop) + elif isinstance(a, ast.AST): + for node in ast.walk(a): + if hasattr(node, prop): + delattr(node, prop) + +def containsTokenStepString(a): + """This is used to keep token-level hint chaining from breaking.""" + if not isinstance(a, ast.AST): + return False + + for node in ast.walk(a): + if type(node) == ast.Str and isTokenStepString(node.s): + return True + return False + +def applyVariableMap(a, variableMap): + if not isinstance(a, ast.AST): + return a + if type(a) == ast.Name: + if a.id in variableMap: + a.id = variableMap[a.id] + elif type(a) in [ast.FunctionDef, ast.ClassDef]: + if a.name in variableMap: + a.name = variableMap[a.name] + return applyToChildren(a, lambda x : applyVariableMap(x, variableMap)) + +def applyHelperMap(a, helperMap): + if not isinstance(a, ast.AST): + return a + if type(a) == ast.Name: + if a.id in helperMap: + a.id = helperMap[a.id] + elif type(a) == ast.FunctionDef: + if a.name in helperMap: + a.name = helperMap[a.name] + return applyToChildren(a, lambda x : applyHelperMap(x, helperMap)) + + +def astFormat(x, gid=None): + """Given a value, turn it into an AST if it's a constant; otherwise, leave it alone.""" + if type(x) in [int, float, complex]: + return ast.Num(x) + elif type(x) == bool or x == None: + return ast.NameConstant(x) + elif type(x) == type: + types = { bool : "bool", int : "int", float : "float", + complex : "complex", str : "str", bytes : "bytes", unicode : "unicode", + list : "list", tuple : "tuple", dict : "dict" } + return ast.Name(types[x], ast.Load()) + elif type(x) == str: # str or unicode + return ast.Str(x) + elif type(x) == bytes: + return ast.Bytes(x) + elif type(x) == list: + elts = [astFormat(val) for val in x] + return ast.List(elts, ast.Load()) + elif type(x) == dict: + keys = [] + vals = [] + for key in x: + keys.append(astFormat(key)) + vals.append(astFormat(x[key])) + return ast.Dict(keys, vals) + elif type(x) == tuple: + elts = [astFormat(val) for val in x] + return ast.Tuple(elts, ast.Load()) + elif type(x) == set: + elts = [astFormat(val) for val in x] + if len(elts) == 0: # needs to be a call instead + return ast.Call(ast.Name("set", ast.Load()), [], []) + else: + return ast.Set(elts) + elif type(x) == slice: + return ast.Slice(astFormat(x.start), astFormat(x.stop), astFormat(x.step)) + elif isinstance(x, ast.AST): + return x # Do not change if it's not constant! + else: + log("astTools\tastFormat\t" + str(type(x)) + "," + str(x),"bug") + return None + +def basicFormat(x): + """Given an AST, turn it into its value if it's constant; otherwise, leave it alone""" + if type(x) == ast.Num: + return x.n + elif type(x) == ast.NameConstant: + return x.value + elif type(x) == ast.Str: + return x.s + elif type(x) == ast.Bytes: + return x.s + return x # Do not change if it's not a constant! + +def structureTree(a): + if type(a) == list: + for i in range(len(a)): + a[i] = structureTree(a[i]) + return a + elif not isinstance(a, ast.AST): + return a + else: + if type(a) == ast.FunctionDef: + a.name = "~name~" + a.args = structureTree(a.args) + a.body = structureTree(a.body) + a.decorator_list = structureTree(a.decorator_list) + a.returns = structureTree(a.returns) + elif type(a) == ast.ClassDef: + a.name = "~name~" + a.bases = structureTree(a.bases) + a.keywords = structureTree(a.keywords) + a.body = structureTree(a.body) + a.decorator_list = structureTree(a.decorator_list) + elif type(a) == ast.AugAssign: + a.target = structureTree(a.target) + a.op = ast.Str("~op~") + a.value = structureTree(a.value) + elif type(a) == ast.Import: + a.names = [ast.Str("~module~")] + elif type(a) == ast.ImportFrom: + a.module = "~module~" + a.names = [ast.Str("~names~")] + elif type(a) == ast.Global: + a.names = ast.Str("~var~") + elif type(a) == ast.BoolOp: + a.op = ast.Str("~op~") + a.values = structureTree(a.values) + elif type(a) == ast.BinOp: + a.op = ast.Str("~op~") + a.left = structureTree(a.left) + a.right = structureTree(a.right) + elif type(a) == ast.UnaryOp: + a.op = ast.Str("~op~") + a.operand = structureTree(a.operand) + elif type(a) == ast.Dict: + return ast.Str("~dictionary~") + elif type(a) == ast.Set: + return ast.Str("~set~") + elif type(a) == ast.Compare: + a.ops = [ast.Str("~op~")]*len(a.ops) + a.left = structureTree(a.left) + a.comparators = structureTree(a.comparators) + elif type(a) == ast.Call: + # leave the function alone + a.args = structureTree(a.args) + a.keywords = structureTree(a.keywords) + elif type(a) == ast.Num: + return ast.Str("~number~") + elif type(a) == ast.Str: + return ast.Str("~string~") + elif type(a) == ast.Bytes: + return ast.Str("~bytes~") + elif type(a) == ast.Attribute: + a.value = structureTree(a.value) + elif type(a) == ast.Name: + a.id = "~var~" + elif type(a) == ast.List: + return ast.Str("~list~") + elif type(a) == ast.Tuple: + return ast.Str("~tuple~") + elif type(a) in [ast.And, ast.Or, ast.Add, ast.Sub, ast.Mult, ast.Div, + ast.Mod, ast.Pow, ast.LShift, ast.RShift, ast.BitOr, + ast.BitXor, ast.BitAnd, ast.FloorDiv, ast.Invert, + ast.Not, ast.UAdd, ast.USub, ast.Eq, ast.NotEq, + ast.Lt, ast.LtE, ast.Gt, ast.GtE, ast.Is, ast.IsNot, + ast.In, ast.NotIn ]: + return ast.Str("~op~") + elif type(a) == ast.arguments: + a.args = structureTree(a.args) + a.vararg = ast.Str("~arg~") if a.vararg != None else None + a.kwonlyargs = structureTree(a.kwonlyargs) + a.kw_defaults = structureTree(a.kw_defaults) + a.kwarg = ast.Str("~keyword~") if a.kwarg != None else None + a.defaults = structureTree(a.defaults) + elif type(a) == ast.arg: + a.arg = "~arg~" + a.annotation = structureTree(a.annotation) + elif type(a) == ast.keyword: + a.arg = "~keyword~" + a.value = structureTree(a.value) + elif type(a) == ast.alias: + a.name = "~name~" + a.asname = "~asname~" if a.asname != None else None + else: + for field in a._fields: + setattr(a, field, structureTree(getattr(a, field))) + return a + + + diff --git a/canonicalize/display.py b/canonicalize/display.py new file mode 100644 index 0000000..174304f --- /dev/null +++ b/canonicalize/display.py @@ -0,0 +1,570 @@ +import ast +from .tools import log + +#=============================================================================== +# These functions are used for displaying ASTs. printAst displays the tree, +# while printFunction displays the syntax +#=============================================================================== + +# TODO: add AsyncFunctionDef, AsyncFor, AsyncWith, AnnAssign, Nonlocal, Await, YieldFrom, FormattedValue, JoinedStr, Starred + +def printFunction(a, indent=0): + s = "" + if a == None: + return "" + if not isinstance(a, ast.AST): + log("display\tprintFunction\tNot AST: " + str(type(a)) + "," + str(a), "bug") + return str(a) + + t = type(a) + if t in [ast.Module, ast.Interactive, ast.Suite]: + for line in a.body: + s += printFunction(line, indent) + elif t == ast.Expression: + s += printFunction(a.body, indent) + elif t == ast.FunctionDef: + for dec in a.decorator_list: + s += (indent * 4 * " ") + "@" + printFunction(dec, indent) + "\n" + s += (indent * 4 * " ") + "def " + a.name + "(" + \ + printFunction(a.args, indent) + "):\n" + for stmt in a.body: + s += printFunction(stmt, indent+1) + # TODO: returns + elif t == ast.ClassDef: + for dec in a.decorator_list: + s += (indent * 4 * " ") + "@" + printFunction(dec, indent) + "\n" + s += (indent * 4 * " ") + "class " + a.name + if len(a.bases) > 0 or len(a.keywords) > 0: + s += "(" + for base in a.bases: + s += printFunction(base, indent) + ", " + for keyword in a.keywords: + s += printFunction(keyword, indent) + ", " + s += s[:-2] + ")" + s += ":\n" + for stmt in a.body: + s += printFunction(stmt, indent+1) + elif t == ast.Return: + s += (indent * 4 * " ") + "return " + \ + printFunction(a.value, indent) + "\n" + elif t == ast.Delete: + s += (indent * 4 * " ") + "del " + for target in a.targets: + s += printFunction(target, indent) + ", " + if len(a.targets) >= 1: + s = s[:-2] + s += "\n" + elif t == ast.Assign: + s += (indent * 4 * " ") + for target in a.targets: + s += printFunction(target, indent) + " = " + s += printFunction(a.value, indent) + "\n" + elif t == ast.AugAssign: + s += (indent * 4 * " ") + s += printFunction(a.target, indent) + " " + \ + printFunction(a.op, indent) + "= " + \ + printFunction(a.value, indent) + "\n" + elif t == ast.For: + s += (indent * 4 * " ") + s += "for " + \ + printFunction(a.target, indent) + " in " + \ + printFunction(a.iter, indent) + ":\n" + for line in a.body: + s += printFunction(line, indent + 1) + if len(a.orelse) > 0: + s += (indent * 4 * " ") + s += "else:\n" + for line in a.orelse: + s += printFunction(line, indent + 1) + elif t == ast.While: + s += (indent * 4 * " ") + s += "while " + printFunction(a.test, indent) + ":\n" + for line in a.body: + s += printFunction(line, indent + 1) + if len(a.orelse) > 0: + s += (indent * 4 * " ") + s += "else:\n" + for line in a.orelse: + s += printFunction(line, indent + 1) + elif t == ast.If: + s += (indent * 4 * " ") + s += "if " + printFunction(a.test, indent) + ":\n" + for line in a.body: + s += printFunction(line, indent + 1) + branch = a.orelse + # elifs + while len(branch) == 1 and type(branch[0]) == ast.If: + s += (indent * 4 * " ") + s += "elif " + printFunction(branch[0].test, indent) + ":\n" + for line in branch[0].body: + s += printFunction(line, indent + 1) + branch = branch[0].orelse + if len(branch) > 0: + s += (indent * 4 * " ") + s += "else:\n" + for line in branch: + s += printFunction(line, indent + 1) + elif t == ast.With: + s += (indent * 4 * " ") + s += "with " + for item in a.items: + s += printFunction(item, indent) + ", " + if len(a.items) > 0: + s = s[:-2] + s += ":\n" + for line in a.body: + s += printFunction(line, indent + 1) + elif t == ast.Raise: + s += (indent * 4 * " ") + s += "raise" + if a.exc != None: + s += " " + printFunction(a.exc, indent) + # TODO: what is cause?!? + s += "\n" + elif type(a) == ast.Try: + s += (indent * 4 * " ") + "try:\n" + for line in a.body: + s += printFunction(line, indent + 1) + for handler in a.handlers: + s += printFunction(handler, indent) + if len(a.orelse) > 0: + s += (indent * 4 * " ") + "else:\n" + for line in a.orelse: + s += printFunction(line, indent + 1) + if len(a.finalbody) > 0: + s += (indent * 4 * " ") + "finally:\n" + for line in a.finalbody: + s += printFunction(line, indent + 1) + elif t == ast.Assert: + s += (indent * 4 * " ") + s += "assert " + printFunction(a.test, indent) + if a.msg != None: + s += ", " + printFunction(a.msg, indent) + s += "\n" + elif t == ast.Import: + s += (indent * 4 * " ") + "import " + for n in a.names: + s += printFunction(n, indent) + ", " + if len(a.names) > 0: + s = s[:-2] + s += "\n" + elif t == ast.ImportFrom: + s += (indent * 4 * " ") + "from " + s += ("." * a.level if a.level != None else "") + a.module + " import " + for name in a.names: + s += printFunction(name, indent) + ", " + if len(a.names) > 0: + s = s[:-2] + s += "\n" + elif t == ast.Global: + s += (indent * 4 * " ") + "global " + for name in a.names: + s += name + ", " + s = s[:-2] + "\n" + elif t == ast.Expr: + s += (indent * 4 * " ") + printFunction(a.value, indent) + "\n" + elif t == ast.Pass: + s += (indent * 4 * " ") + "pass\n" + elif t == ast.Break: + s += (indent * 4 * " ") + "break\n" + elif t == ast.Continue: + s += (indent * 4 * " ") + "continue\n" + + elif t == ast.BoolOp: + s += "(" + printFunction(a.values[0], indent) + for i in range(1, len(a.values)): + s += " " + printFunction(a.op, indent) + " " + \ + printFunction(a.values[i], indent) + s += ")" + elif t == ast.BinOp: + s += "(" + printFunction(a.left, indent) + s += " " + printFunction(a.op, indent) + " " + s += printFunction(a.right, indent) + ")" + elif t == ast.UnaryOp: + s += "(" + printFunction(a.op, indent) + " " + s += printFunction(a.operand, indent) + ")" + elif t == ast.Lambda: + s += "lambda " + s += printFunction(a.arguments, indent) + ": " + s += printFunction(a.body, indent) + elif t == ast.IfExp: + s += "(" + printFunction(a.body, indent) + s += " if " + printFunction(a.test, indent) + s += " else " + printFunction(a.orelse, indent) + ")" + elif t == ast.Dict: + s += "{ " + for i in range(len(a.keys)): + s += printFunction(a.keys[i], indent) + s += " : " + s += printFunction(a.values[i], indent) + s += ", " + if len(a.keys) >= 1: + s = s[:-2] + s += " }" + elif t == ast.Set: + # Empty sets must be initialized in a special way + if len(a.elts) == 0: + s += "set()" + else: + s += "{" + for elt in a.elts: + s += printFunction(elt, indent) + ", " + s = s[:-2] + s += "}" + elif t == ast.ListComp: + s += "[" + s += printFunction(a.elt, indent) + " " + for gen in a.generators: + s += printFunction(gen, indent) + " " + s = s[:-1] + s += "]" + elif t == ast.SetComp: + s += "{" + s += printFunction(a.elt, indent) + " " + for gen in a.generators: + s += printFunction(gen, indent) + " " + s = s[:-1] + s += "}" + elif t == ast.DictComp: + s += "{" + s += printFunction(a.key, indent) + " : " + \ + printFunction(a.value, indent) + " " + for gen in a.generators: + s += printFunction(gen, indent) + " " + s = s[:-1] + s += "}" + elif t == ast.GeneratorExp: + s += "(" + s += printFunction(a.elt, indent) + " " + for gen in a.generators: + s += printFunction(gen, indent) + " " + s = s[:-1] + s += ")" + elif t == ast.Yield: + s += "yield " + printFunction(a.value, indent) + elif t == ast.Compare: + s += "(" + printFunction(a.left, indent) + for i in range(len(a.ops)): + s += " " + printFunction(a.ops[i], indent) + if i < len(a.comparators): + s += " " + printFunction(a.comparators[i], indent) + if len(a.comparators) > len(a.ops): + for i in range(len(a.ops), len(a.comparators)): + s += " " + printFunction(a.comparators[i], indent) + s += ")" + elif t == ast.Call: + s += printFunction(a.func, indent) + "(" + for arg in a.args: + s += printFunction(arg, indent) + ", " + for key in a.keywords: + s += printFunction(key, indent) + ", " + if len(a.args) + len(a.keywords) >= 1: + s = s[:-2] + s += ")" + elif t == ast.Num: + if a.n != None: + if (type(a.n) == complex) or (type(a.n) != complex and a.n < 0): + s += '(' + str(a.n) + ')' + else: + s += str(a.n) + elif t == ast.Str: + if a.s != None: + val = repr(a.s) + if val[0] == '"': # There must be a single quote in there... + val = "'''" + val[1:len(val)-1] + "'''" + s += val + #s += "'" + a.s.replace("'", "\\'").replace('"', "\\'").replace("\n","\\n") + "'" + elif t == ast.Bytes: + s += str(a.s) + elif t == ast.NameConstant: + s += str(a.value) + elif t == ast.Attribute: + s += printFunction(a.value, indent) + "." + str(a.attr) + elif t == ast.Subscript: + s += printFunction(a.value, indent) + "[" + printFunction(a.slice, indent) + "]" + elif t == ast.Name: + s += a.id + elif t == ast.List: + s += "[" + for elt in a.elts: + s += printFunction(elt, indent) + ", " + if len(a.elts) >= 1: + s = s[:-2] + s += "]" + elif t == ast.Tuple: + s += "(" + for elt in a.elts: + s += printFunction(elt, indent) + ", " + if len(a.elts) > 1: + s = s[:-2] + elif len(a.elts) == 1: + s = s[:-1] # don't get rid of the comma! It clarifies that this is a tuple + s += ")" + elif t == ast.Starred: + s += "*" + printFunction(a.value, indent) + elif t == ast.Ellipsis: + s += "..." + elif t == ast.Slice: + if a.lower != None: + s += printFunction(a.lower, indent) + s += ":" + if a.upper != None: + s += printFunction(a.upper, indent) + if a.step != None: + s += ":" + printFunction(a.step, indent) + elif t == ast.ExtSlice: + for dim in a.dims: + s += printFunction(dim, indent) + ", " + if len(a.dims) > 0: + s = s[:-2] + elif t == ast.Index: + s += printFunction(a.value, indent) + + elif t == ast.comprehension: + s += "for " + s += printFunction(a.target, indent) + " " + s += "in " + s += printFunction(a.iter, indent) + " " + for cond in a.ifs: + s += "if " + s += printFunction(cond, indent) + " " + s = s[:-1] + elif t == ast.ExceptHandler: + s += (indent * 4 * " ") + "except" + if a.type != None: + s += " " + printFunction(a.type, indent) + if a.name != None: + s += " as " + a.name + s += ":\n" + for line in a.body: + s += printFunction(line, indent + 1) + elif t == ast.arguments: + # Defaults are only applied AFTER non-defaults + defaultStart = len(a.args) - len(a.defaults) + for i in range(len(a.args)): + s += printFunction(a.args[i], indent) + if i >= defaultStart: + s += "=" + printFunction(a.defaults[i - defaultStart], indent) + s += ", " + if a.vararg != None: + s += "*" + printFunction(a.vararg, indent) + ", " + if a.kwarg != None: + s += "**" + printFunction(a.kwarg, indent) + ", " + if a.vararg == None and a.kwarg == None and len(a.kwonlyargs) > 0: + s += "*, " + if len(a.kwonlyargs) > 0: + for i in range(len(a.kwonlyargs)): + s += printFunction(a.kwonlyargs[i], indent) + s += "=" + printFunction(a.kw_defaults, indent) + ", " + if (len(a.args) > 0 or a.vararg != None or a.kwarg != None or len(a.kwonlyargs) > 0): + s = s[:-2] + elif t == ast.arg: + s += a.arg + if a.annotation != None: + s += ": " + printFunction(a.annotation, indent) + elif t == ast.keyword: + s += a.arg + "=" + printFunction(a.value, indent) + elif t == ast.alias: + s += a.name + if a.asname != None: + s += " as " + a.asname + elif t == ast.withitem: + s += printFunction(a.context_expr, indent) + if a.optional_vars != None: + s += " as " + printFunction(a.optional_vars, indent) + else: + ops = { ast.And : "and", ast.Or : "or", + ast.Add : "+", ast.Sub : "-", ast.Mult : "*", ast.Div : "/", ast.Mod : "%", + ast.Pow : "**", ast.LShift : "<<", ast.RShift : ">>", ast.BitOr : "|", + ast.BitXor : "^", ast.BitAnd : "&", ast.FloorDiv : "//", + ast.Invert : "~", ast.Not : "not", ast.UAdd : "+", ast.USub : "-", + ast.Eq : "==", ast.NotEq : "!=", ast.Lt : "<", ast.LtE : "<=", + ast.Gt : ">", ast.GtE : ">=", ast.Is : "is", ast.IsNot : "is not", + ast.In : "in", ast.NotIn : "not in"} + if type(a) in ops: + return ops[type(a)] + if type(a) in [ast.Load, ast.Store, ast.Del, ast.AugLoad, ast.AugStore, ast.Param]: + return "" + log("display\tMissing type: " + str(t), "bug") + return s + +def formatContext(trace, verb): + traceD = { + "value" : { "Return" : ("return statement"), + "Assign" : ("right side of the assignment"), + "AugAssign" : ("right side of the assignment"), + "Expression" : ("expression"), + "Dict Comprehension" : ("left value of the dict comprehension"), + "Yield" : ("yield expression"), + "Repr" : ("repr expression"), + "Attribute" : ("attribute value"), + "Subscript" : ("outer part of the subscript"), + "Index" : ("inner part of the subscript"), + "Keyword" : ("right side of the keyword"), + "Starred" : ("value of the starred expression"), + "Name Constant" : ("constant value") }, + "values" : { "Print" : ("print statement"), + "Boolean Operation" : ("boolean operation"), + "Dict" : ("values of the dictionary") }, + "name" : { "Function Definition" : ("function name"), + "Class Definition" : ("class name"), + "Except Handler" : ("name of the except statement"), + "Alias" : ("alias") }, + "names" : { "Import" : ("import"), + "ImportFrom" : ("import"), + "Global" : ("global variables") }, + "elt" : { "List Comprehension" : ("left element of the list comprehension"), + "Set Comprehension" : ("left element of the set comprehension"), + "Generator" : ("left element of the generator") }, + "elts" : { "Set" : ("set"), + "List" : ("list"), + "Tuple" : ("tuple") }, + "target" : { "AugAssign" : ("left side of the assignment"), + "For" : ("target of the for loop"), + "Comprehension" : ("target of the comprehension") }, + "targets" : { "Delete" : ("delete statement"), + "Assign" : ("left side of the assignment") }, + "op" : { "AugAssign" : ("assignment"), + "Boolean Operation" : ("boolean operation"), + "Binary Operation" : ("binary operation"), + "Unary Operation" : ("unary operation") }, + "ops" : { "Compare" : ("comparison operation") }, + "arg" : { "Keyword" : ("left side of the keyword"), + "Argument" : ("argument") }, + "args" : { "Function Definition" : ("function arguments"), # single item + "Lambda" : ("lambda arguments"), # single item + "Call" : ("arguments of the function call"), + "Arguments" : ("function arguments") }, + "key" : { "Dict Comprehension" : ("left key of the dict comprehension") }, + "keys" : { "Dict" : ("keys of the dictionary") }, + "kwarg" : { "Arguments" : ("keyword arg") }, + "kwargs" : { "Call" : ("keyword args of the function call") }, # single item + "body" : { "Module" : ("main codebase"), # list + "Interactive" : ("main codebase"), # list + "Expression" : ("main codebase"), + "Suite" : ("main codebase"), # list + "Function Definition" : ("function body"), # list + "Class Definition" : ("class body"), # list + "For" : ("lines of the for loop"), # list + "While" : ("lines of the while loop"), # list + "If" : ("main lines of the if statement"), # list + "With" : ("lines of the with block"), # list + "Try" : ("lines of the try block"), # list + "Execute" : ("exec expression"), + "Lambda" : ("lambda body"), + "Ternary" : ("ternary body"), + "Except Handler" : ("lines of the except block") }, # list + "orelse" : { "For" : ("else part of the for loop"), # list + "While" : ("else part of the while loop"), # list + "If" : ("lines of the else statement"), # list + "Try" : ("lines of the else statement"), # list + "Ternary" : ("ternary else value") }, + "test" : { "While" : ("test case of the while statement"), + "If" : ("test case of the if statement"), + "Assert" : ("assert expression"), + "Ternary" : ("test case of the ternary expression") }, + "generators" : { "List Comprehension" : ("list comprehension"), + "Set Comprehension" : ("set comprehension"), + "Dict Comprehension" : ("dict comprehension"), + "Generator" : ("generator") }, + "decorator_list" : { "Function Definition" : ("function decorators"), # list + "Class Definition" : ("class decorators") }, # list + "iter" : { "For" : ("iterator of the for loop"), + "Comprehension" : ("iterator of the comprehension") }, + "type" : { "Raise" : ("raised type"), + "Except Handler" : ("type of the except statement") }, + "left" : { "Binary Operation" : ("left side of the binary operation"), + "Compare" : ("left side of the comparison") }, + "bases" : { "Class Definition" : ("class bases") }, + "dest" : { "Print" : ("print destination") }, + "nl" : { "Print" : ("comma at the end of the print statement") }, + "context_expr" : { "With item" : ("context of the with statement") }, + "optional_vars" : { "With item" : ("context of the with statement") }, # single item + "inst" : { "Raise" : ("raise expression") }, + "tback" : { "Raise" : ("raise expression") }, + "handlers" : { "Try" : ("except block") }, + "finalbody" : { "Try" : ("finally block") }, # list + "msg" : { "Assert" : ("assert message") }, + "module" : { "Import From" : ("import module") }, + "level" : { "Import From" : ("import module") }, + "globals" : { "Execute" : ("exec global value") }, # single item + "locals" : { "Execute" : ("exec local value") }, # single item + "right" : { "Binary Operation" : ("right side of the binary operation") }, + "operand" : { "Unary Operation" : ("value of the unary operation") }, + "comparators" : { "Compare" : ("right side of the comparison") }, + "func" : { "Call" : ("function call") }, + "keywords" : { "Call" : ("keywords of the function call") }, + "starargs" : { "Call" : ("star args of the function call") }, # single item + "attr" : { "Attribute" : ("attribute of the value") }, + "slice" : { "Subscript" : ("inner part of the subscript") }, + "lower" : { "Slice" : ("left side of the subscript slice") }, + "upper" : { "Slice" : ("right side of the subscript slice") }, + "step" : { "Step" : ("rightmost side of the subscript slice") }, + "dims" : { "ExtSlice" : ("slice") }, + "ifs" : { "Comprehension" : ("if part of the comprehension") }, + "vararg" : { "Arguments" : ("vararg") }, + "defaults" : { "Arguments" : ("default values of the arguments") }, + "asname" : { "Alias" : ("new name") }, + "items" : { "With" : ("context of the with statement") } + } + + # Find what type this is by trying to find the closest container in the path + i = 0 + while i < len(trace): + if type(trace[i]) == tuple: + if trace[i][0] == "value" and trace[i][1] == "Attribute": + pass + elif trace[i][0] in traceD: + break + elif trace[i][0] in ["id", "n", "s"]: + pass + else: + log("display\tformatContext\tSkipped field: " + str(trace[i]), "bug") + i += 1 + else: + return "" # this is probably covered by the line number + + field,typ = trace[i] + if field in traceD and typ in traceD[field]: + context = traceD[field][typ] + return verb + "the " + context + else: + log("display\tformatContext\tMissing field: " + str(field) + "," + str(typ), "bug") + return "" + +def formatList(node, field): + if type(node) != list: + return None + s = "" + nameMap = { "body" : "line", "targets" : "value", "values" : "value", "orelse" : "line", + "names" : "name", "keys" : "key", "elts" : "value", "ops" : "operator", + "comparators" : "value", "args" : "argument", "keywords" : "keyword" } + + # Find what type this is + itemType = nameMap[field] if field in nameMap else "line" + + if len(node) > 1: + s = "the " + itemType + "s: " + for line in node: + s += formatNode(line) + ", " + elif len(node) == 1: + s = "the " + itemType + " " + f = formatNode(node[0]) + if itemType == "line": + f = "[" + f + "]" + s += f + return s + +def formatNode(node): + """Create a string version of the given node""" + if node == None: + return "" + t = type(node) + if t == str: + return "'" + node + "'" + elif t == int or t == float: + return str(node) + elif t == list: + return formatList(node, None) + else: + return printFunction(node, 0)
\ No newline at end of file diff --git a/canonicalize/namesets.py b/canonicalize/namesets.py new file mode 100644 index 0000000..5523f44 --- /dev/null +++ b/canonicalize/namesets.py @@ -0,0 +1,463 @@ +import ast, collections + +supportedLibraries = [ "string", "math", "random", "__future__", "copy" ] + +builtInTypes = [ bool, bytes, complex, dict, float, int, list, + set, str, tuple, type ] + +staticTypeCastBuiltInFunctions = { + "bool" : { (object,) : bool }, + "bytes" : bytes, + "complex" : { (str, int) : complex, (str, float) : complex, (int, int) : complex, + (int, float) : complex, (float, int) : complex, (float, float) : complex }, + "dict" : { (collections.Iterable,) : dict }, + "enumerate" : { (collections.Iterable,) : enumerate }, + "float" : { (str,) : float, (int,) : float, (float,) : float }, + "frozenset" : frozenset, + "int" : { (str,) : int, (float,) : int, (int,) : int, (str, int) : int }, + "list" : { (collections.Iterable,) : list }, + "memoryview" : None, #TODO + "object" : { () : object }, + "property" : property, #TODO + "reversed" : { (str,) : reversed, (list,) : reversed }, + "set" : { () : None, (collections.Iterable,) : None }, #TODO + "slice" : { (int,) : slice, (int, int) : slice, (int, int, int) : slice }, + "str" : { (object,) : str }, + "tuple" : { () : tuple, (collections.Iterable,) : tuple }, + "type" : { (object,) : type }, + } + +mutatingTypeCastBuiltInFunctions = { + "bytearray" : { () : list }, + "classmethod" : None, + "file" : None, + "staticmethod" : None, #TOOD + "super" : None + } + +builtInNames = [ "None", "True", "False", "NotImplemented", "Ellipsis" ] + +staticBuiltInFunctions = { + "abs" : { (int,) : int, (float,) : float }, + "all" : { (collections.Iterable,) : bool }, + "any" : { (collections.Iterable,) : bool }, + "bin" : { (int,) : str }, + "callable" : { (object,) : bool }, + "chr" : { (int,) : str }, + "cmp" : { (object, object) : int }, + "coerce" : tuple, #TODO + "compile" : { (str, str, str) : ast, (ast, str, str) : ast }, + "dir" : { () : list }, + "divmod" : { (int, int) : tuple, (int, float) : tuple, + (float, int) : tuple, (float, float) : tuple }, + "filter" : { (type(lambda x : x), collections.Iterable) : list }, + "getattr" : None, + "globals" : dict, + "hasattr" : bool, #TODO + "hash" : int, + "hex" : str, + "id" : int, #TODO + "isinstance" : { (None, None) : bool }, + "issubclass" : bool, #TODO + "iter" : { (collections.Iterable,) : None, (None, object) : None }, + "len" : { (str,) : int, (tuple,) : int, (list,) : int, (dict,) : int }, + "locals" : dict, + "map" : { (None, collections.Iterable) : list }, #TODO + "max" : { (collections.Iterable,) : None }, + "min" : { (collections.Iterable,) : None }, + "oct" : { (int,) : str }, + "ord" : { (str,) : int }, + "pow" : { (int, int) : int, (int, float) : float, + (float, int) : float, (float, float) : float }, + "print" : None, + "range" : { (int,) : list, (int, int) : list, (int, int, int) : list }, + "repr" : {(object,) : str }, + "round" : { (int,) : float, (float,) : float, (int, int) : float, (float, int) : float }, + "sorted" : { (collections.Iterable,) : list }, + "sum" : { (collections.Iterable,) : None }, + "vars" : dict, #TODO + "zip" : { () : list, (collections.Iterable,) : list} + } + +mutatingBuiltInFunctions = { + "__import__" : None, + "apply" : None, + "delattr" : { (object, str) : None }, + "eval" : { (str,) : None }, + "execfile" : None, + "format" : None, + "input" : { () : None, (object,) : None }, + "intern" : str, #TOOD + "next" : { () : None, (None,) : None }, + "open" : None, + "raw_input" : { () : str, (object,) : str }, + "reduce" : None, + "reload" : None, + "setattr" : None + } + +builtInSafeFunctions = [ + "abs", "all", "any", "bin", "bool", "cmp", "len", + "list", "max", "min", "pow", "repr", "round", "slice", "str", "type" + ] + + +exceptionClasses = [ + "ArithmeticError", + "AssertionError", + "AttributeError", + "BaseException", + "BufferError", + "BytesWarning", + "DeprecationWarning", + "EOFError", + "EnvironmentError", + "Exception", + "FloatingPointError", + "FutureWarning", + "GeneratorExit", + "IOError", + "ImportError", + "ImportWarning", + "IndentationError", + "IndexError", + "KeyError", + "KeyboardInterrupt", + "LookupError", + "MemoryError", + "NameError", + "NotImplementedError", + "OSError", + "OverflowError", + "PendingDeprecationWarning", + "ReferenceError", + "RuntimeError", + "RuntimeWarning", + "StandardError", + "StopIteration", + "SyntaxError", + "SyntaxWarning", + "SystemError", + "SystemExit", + "TabError", + "TypeError", + "UnboundLocalError", + "UnicodeDecodeError", + "UnicodeEncodeError", + "UnicodeError", + "UnicodeTranslateError", + "UnicodeWarning", + "UserWarning", + "ValueError", + "Warning", + "ZeroDivisionError", + + "WindowsError", "BlockingIOError", "ChildProcessError", + "ConnectionError", "BrokenPipeError", "ConnectionAbortedError", + "ConnectionRefusedError", "ConnectionResetError", "FileExistsError", + "FileNotFoundError", "InterruptedError", "IsADirectoryError", "NotADirectoryError", + "PermissionError", "ProcessLookupError", "TimeoutError", + "ResourceWarning", "RecursionError", "StopAsyncIteration" ] + +builtInFunctions = dict(list(staticBuiltInFunctions.items()) + \ + list(mutatingBuiltInFunctions.items()) + \ + list(staticTypeCastBuiltInFunctions.items()) + \ + list(mutatingTypeCastBuiltInFunctions.items())) + +# All string functions do not mutate the caller, they return copies instead +builtInStringFunctions = { + "capitalize" : { () : str }, + "center" : { (int,) : str, (int, str) : str }, + "count" : { (str,) : int, (str, int) : int, (str, int, int) : int }, + "decode" : { () : str }, + "encode" : { () : str }, + "endswith" : { (str,) : bool }, + "expandtabs" : { () : str, (int,) : str }, + "find" : { (str,) : int, (str, int) : int, (str, int, int) : int }, + "format" : { (list, list) : str }, + "index" : { (str,) : int, (str,int) : int, (str,int,int) : int }, + "isalnum" : { () : bool }, + "isalpha" : { () : bool }, + "isdecimal" : { () : bool }, + "isdigit" : { () : bool }, + "islower" : { () : bool }, + "isnumeric" : { () : bool }, + "isspace" : { () : bool }, + "istitle" : { () : bool }, + "isupper" : { () : bool }, + "join" : { (collections.Iterable,) : str, (collections.Iterable,str) : str }, + "ljust" : { (int,) : str }, + "lower" : { () : str }, + "lstrip" : { () : str, (str,) : str }, + "partition" : { (str,) : tuple }, + "replace" : { (str, str) : str, (str, str, int) : str }, + "rfind" : { (str,) : int, (str,int) : int, (str,int,int) : int }, + "rindex" : { (str,) : int }, + "rjust" : { (int,) : str }, + "rpartition" : { (str,) : tuple }, + "rsplit" : { () : list }, + "rstrip" : { () : str }, + "split" : { () : list, (str,) : list, (str, int) : list }, + "splitlines" : { () : list }, + "startswith" : { (str,) : bool }, + "strip" : { () : str, (str,) : str }, + "swapcase" : { () : str }, + "title" : { () : str }, + "translate" : { (str,) : str }, + "upper" : { () : str }, + "zfill" : { (int,) : str } + } + +safeStringFunctions = [ + "capitalize", "center", "count", "endswith", "expandtabs", "find", + "isalnum", "isalpha", "isdigit", "islower", "isspace", "istitle", + "isupper", "join", "ljust", "lower", "lstrip", "partition", "replace", + "rfind", "rjust", "rpartition", "rsplit", "rstrip", "split", "splitlines", + "startswith", "strip", "swapcase", "title", "translate", "upper", "zfill", + "isdecimal", "isnumeric"] + +mutatingListFunctions = { + "append" : { (object,) : None }, + "extend" : { (list,) : None }, + "insert" : { (int, object) : None }, + "remove" : { (object,) : None }, + "pop" : { () : None, (int,) : None }, + "sort" : { () : None }, + "reverse" : { () : None } + } + +staticListFunctions = { + "index" : { (object,) : int, (object,int) : int, (object,int,int) : int }, + "count" : { (object,) : int, (object,int) : int, (object,int,int) : int } + } + +safeListFunctions = [ "append", "extend", "insert", "count", "sort", "reverse"] + +builtInListFunctions = dict(list(mutatingListFunctions.items()) + list(staticListFunctions.items())) + +staticDictFunctions = { + "get" : { (object,) : object, (object, object) : object }, + "items" : { () : list } + } + +builtInDictFunctions = staticDictFunctions + +mathFunctions = { + "ceil" : { (int,) : float, (float,) : float }, + "copysign" : { (int, int) : float, (int, float) : float, + (float, int) : float, (float, float) : float }, + "fabs" : { (int,) : float, (float,) : float }, + "factorial" : { (int,) : int, (float,) : int }, + "floor" : { (int,) : float, (float,) : float }, + "fmod" : { (int, int) : float, (int, float) : float, + (float, int) : float, (float, float) : float }, + "frexp" : int, + "fsum" : int, #TODO + "isinf" : { (int,) : bool, (float,) : bool }, + "isnan" : { (int,) : bool, (float,) : bool }, + "ldexp" : int, + "modf" : tuple, + "trunc" : None, #TODO + "exp" : { (int,) : float, (float,) : float }, + "expm1" : { (int,) : float, (float,) : float }, + "log" : { (int,) : float, (float,) : float, + (int,int) : float, (int,float) : float, + (float, int) : float, (float, float) : float }, + "log1p" : { (int,) : float, (float,) : float }, + "log10" : { (int,) : float, (float,) : float }, + "pow" : { (int, int) : float, (int, float) : float, + (float, int) : float, (float, float) : float }, + "sqrt" : { (int,) : float, (float,) : float }, + "acos" : { (int,) : float, (float,) : float }, + "asin" : { (int,) : float, (float,) : float }, + "atan" : { (int,) : float, (float,) : float }, + "atan2" : { (int,) : float, (float,) : float }, + "cos" : { (int,) : float, (float,) : float }, + "hypot" : { (int, int) : float, (int, float) : float, + (float, int) : float, (float, float) : float }, + "sin" : { (int,) : float, (float,) : float }, + "tan" : { (int,) : float, (float,) : float }, + "degrees" : { (int,) : float, (float,) : float }, + "radians" : { (int,) : float, (float,) : float }, + "acosh" : int, + "asinh" : int, + "atanh" : int, + "cosh" : int, + "sinh" : int, + "tanh" : int,#TODO + "erf" : int, + "erfc" : int, + "gamma" : int, + "lgamma" : int #TODO + } + +safeMathFunctions = [ + "ceil", "copysign", "fabs", "floor", "fmod", "isinf", + "isnan", "exp", "expm1", "cos", "hypot", "sin", "tan", + "degrees", "radians" ] + +randomFunctions = { + "seed" : { () : None, (collections.Hashable,) : None }, + "getstate" : { () : object }, + "setstate" : { (object,) : None }, + "jumpahead" : { (int,) : None }, + "getrandbits" : { (int,) : int }, + "randrange" : { (int,) : int, (int, int) : int, (int, int, int) : int }, + "randint" : { (int, int) : int }, + "choice" : { (collections.Iterable,) : object }, + "shuffle" : { (collections.Iterable,) : None, + (collections.Iterable, type(lambda x : x)) : None }, + "sample" : { (collections.Iterable, int) : list }, + "random" : { () : float }, + "uniform" : { (float, float) : float } + } + +futureFunctions = { + "nested_scopes" : None, + "generators" : None, + "division" : None, + "absolute_import" : None, + "with_statement" : None, + "print_function" : None, + "unicode_literals" : None + } + +copyFunctions = { + "copy" : None, + "deepcopy" : None +} + +timeFunctions = { + "clock" : { () : float }, + "time" : { () : float } + } + +errorFunctions = { + "AssertionError" : { (str,) : object } + } + +allStaticFunctions = dict(list(staticBuiltInFunctions.items()) + list(staticTypeCastBuiltInFunctions.items()) + \ + list(builtInStringFunctions.items()) + list(staticListFunctions.items()) + \ + list(staticDictFunctions.items()) + list(mathFunctions.items())) + +allMutatingFunctions = dict(list(mutatingBuiltInFunctions.items()) + list(mutatingTypeCastBuiltInFunctions.items()) + \ + list(mutatingListFunctions.items()) + list(randomFunctions.items()) + list(timeFunctions.items())) + +allPythonFunctions = dict(list(allStaticFunctions.items()) + list(allMutatingFunctions.items())) + +safeLibraryMap = { "string" : [ "ascii_letters", "ascii_lowercase", "ascii_uppercase", + "digits", "hexdigits", "letters", "lowercase", "octdigits", + "punctuation", "printable", "uppercase", "whitespace", + "capitalize", "expandtabs", "find", "rfind", "count", + "lower", "split", "rsplit", "splitfields", "join", + "joinfields", "lstrip", "rstrip", "strip", "swapcase", + "upper", "ljust", "rjust", "center", "zfill", "replace"], + "math" : [ "ceil", "copysign", "fabs", "floor", "fmod", + "frexp", "fsum", "isinf", "isnan", "ldexp", "modf", "trunc", "exp", + "expm1", "log", "log1p", "log10", "sqrt", "acos", "asin", + "atan", "atan2", "cos", "hypot", "sin", "tan", "degrees", "radians", + "acosh", "asinh", "atanh", "cosh", "sinh", "tanh", "erf", "erfc", + "gamma", "lgamma", "pi", "e" ], + "random" : [ ], + "__future__" : ["nested_scopes", "generators", "division", "absolute_import", + "with_statement", "print_function", "unicode_literals"] } + +libraryMap = { "string" : [ "ascii_letters", "ascii_lowercase", "ascii_uppercase", + "digits", "hexdigits", "letters", "lowercase", "octdigits", + "punctuation", "printable", "uppercase", "whitespace", + "capwords", "maketrans", "atof", "atoi", "atol", "capitalize", + "expandtabs", "find", "rfind", "index", "rindex", "count", + "lower", "split", "rsplit", "splitfields", "join", + "joinfields", "lstrip", "rstrip", "strip", "swapcase", + "translate", "upper", "ljust", "rjust", "center", "zfill", + "replace", "Template", "Formatter" ], + "math" : [ "ceil", "copysign", "fabs", "factorial", "floor", "fmod", + "frexp", "fsum", "isinf", "isnan", "ldexp", "modf", "trunc", "exp", + "expm1", "log", "log1p", "log10", "pow", "sqrt", "acos", "asin", + "atan", "atan2", "cos", "hypot", "sin", "tan", "degrees", "radians", + "acosh", "asinh", "atanh", "cosh", "sinh", "tanh", "erf", "erfc", + "gamma", "lgamma", "pi", "e" ], + "random" : ["seed", "getstate", "setstate", "jumpahead", "getrandbits", + "randrange", "randrange", "randint", "choice", "shuffle", "sample", + "random", "uniform", "triangular", "betavariate", "expovariate", + "gammavariate", "gauss", "lognormvariate", "normalvariate", + "vonmisesvariate", "paretovariate", "weibullvariate", "WichmannHill", + "whseed", "SystemRandom" ], + "__future__" : ["nested_scopes", "generators", "division", "absolute_import", + "with_statement", "print_function", "unicode_literals"], + "copy" : ["copy", "deepcopy"] } + +libraryDictMap = { "string" : builtInStringFunctions, + "math" : mathFunctions, + "random" : randomFunctions, + "__future__" : futureFunctions, + "copy" : copyFunctions } + +typeMethodMap = {"string" : ["capitalize", "center", "count", "decode", "encode", "endswith", + "expandtabs", "find", "format", "index", "isalnum", "isalpha", + "isdigit", "islower", "isspace", "istitle", "isupper", "join", + "ljust", "lower", "lstrip", "partition", "replace", "rfind", + "rindex", "rjust", "rpartition", "rsplit", "rstrip", "split", + "splitlines", "startswith", "strip", "swapcase", "title", + "translate", "upper", "zfill"], + "list" : [ "append", "extend", "count", "index", "insert", "pop", "remove", + "reverse", "sort"], + "set" : [ "isdisjoint", "issubset", "issuperset", "union", "intersection", + "difference", "symmetric_difference", "update", "intersection_update", + "difference_update", "symmetric_difference_update", "add", + "remove", "discard", "pop", "clear"], + "dict" : [ "iter", "clear", "copy", "fromkeys", "get", "has_key", "items", + "iteritems", "iterkeys", "itervalues", "keys", "pop", "popitem", + "setdefault", "update", "values", "viewitems", "viewkeys", + "viewvalues"] } + +astNames = { + ast.Module : "Module", ast.Interactive : "Interactive Module", + ast.Expression : "Expression Module", ast.Suite : "Suite", + + ast.FunctionDef : "Function Definition", + ast.ClassDef : "Class Definition", ast.Return : "Return", + ast.Delete : "Delete", ast.Assign : "Assign", + ast.AugAssign : "AugAssign", ast.For : "For", + ast.While : "While", ast.If : "If", ast.With : "With", + ast.Raise : "Raise", + ast.Try : "Try", ast.Assert : "Assert", + ast.Import : "Import", ast.ImportFrom : "Import From", + ast.Global : "Global", ast.Expr : "Expression", + ast.Pass : "Pass", ast.Break : "Break", ast.Continue : "Continue", + + ast.BoolOp : "Boolean Operation", ast.BinOp : "Binary Operation", + ast.UnaryOp : "Unary Operation", ast.Lambda : "Lambda", + ast.IfExp : "Ternary", ast.Dict : "Dictionary", ast.Set : "Set", + ast.ListComp : "List Comprehension", ast.SetComp : "Set Comprehension", + ast.DictComp : "Dict Comprehension", + ast.GeneratorExp : "Generator", ast.Yield : "Yield", + ast.Compare : "Compare", ast.Call : "Call", + ast.Num : "Number", ast.Str : "String", ast.Bytes : "Bytes", + ast.NameConstant : "Name Constant", + ast.Attribute : "Attribute", + ast.Subscript : "Subscript", ast.Name : "Name", ast.List : "List", + ast.Tuple : "Tuple", ast.Starred : "Starred", + + ast.Load : "Load", ast.Store : "Store", ast.Del : "Delete", + ast.AugLoad : "AugLoad", ast.AugStore : "AugStore", + ast.Param : "Parameter", + + ast.Ellipsis : "Ellipsis", ast.Slice : "Slice", + ast.ExtSlice : "ExtSlice", ast.Index : "Index", + + ast.And : "And", ast.Or : "Or", ast.Add : "Add", ast.Sub : "Subtract", + ast.Mult : "Multiply", ast.Div : "Divide", ast.Mod : "Modulo", + ast.Pow : "Power", ast.LShift : "Left Shift", + ast.RShift : "Right Shift", ast.BitOr : "|", ast.BitXor : "^", + ast.BitAnd : "&", ast.FloorDiv : "Integer Divide", + ast.Invert : "Invert", ast.Not : "Not", ast.UAdd : "Unsigned Add", + ast.USub : "Unsigned Subtract", ast.Eq : "==", ast.NotEq : "!=", + ast.Lt : "<", ast.LtE : "<=", ast.Gt : ">", ast.GtE : ">=", + ast.Is : "Is", ast.IsNot : "Is Not", ast.In : "In", + ast.NotIn : "Not In", + + ast.comprehension: "Comprehension", + ast.ExceptHandler : "Except Handler", ast.arguments : "Arguments", ast.arg : "Argument", + ast.keyword : "Keyword", ast.alias : "Alias", ast.withitem : "With item" + }
\ No newline at end of file diff --git a/canonicalize/tools.py b/canonicalize/tools.py new file mode 100644 index 0000000..1df653a --- /dev/null +++ b/canonicalize/tools.py @@ -0,0 +1,15 @@ +"""This is a file of useful functions used throughout the hint generation program""" +import time, os.path, ast, json + +def log(msg, filename="main", newline=True): + return + txt = "" + if newline: + t = time.strftime("%d %b %Y %H:%M:%S") + txt += t + "\t" + txt += msg + if newline: + txt += "\n" + f = open('log/' + filename + ".log", "a") + f.write(txt) + f.close() diff --git a/canonicalize/transformations.py b/canonicalize/transformations.py new file mode 100644 index 0000000..7ae1e40 --- /dev/null +++ b/canonicalize/transformations.py @@ -0,0 +1,2775 @@ +import ast, copy, functools + +from .tools import log +from .namesets import * +from .astTools import * + +### VARIABLE ANONYMIZATION ### + +def updateVariableNames(a, varMap, scopeName, randomCounter, imports): + if not isinstance(a, ast.AST): + return + + if type(a) in [ast.FunctionDef, ast.ClassDef]: + if a.name in varMap: + if not hasattr(a, "originalId"): + a.originalId = a.name + a.name = varMap[a.name] + anonymizeStatementNames(a, varMap, "_" + a.name, imports) + elif type(a) == ast.arg: + if a.arg not in varMap and not (builtInName(a.arg) or importedName(a.arg, imports)): + log("Can't assign to arg?", "bug") + if a.arg in varMap: + if not hasattr(a, "originalId"): + a.originalId = a.arg + if varMap[a.arg][0] == "r": + a.randomVar = True # so we know it can crash + if a.arg == varMap[a.arg]: + # Check whether this is a given name + if not isAnonVariable(varMap[a.arg]): + a.dontChangeName = True + a.arg = varMap[a.arg] + elif type(a) == ast.Name: + if a.id not in varMap and not (builtInName(a.id) or importedName(a.id, imports)): + varMap[a.id] = "r" + str(randomCounter[0]) + scopeName + randomCounter[0] += 1 + if a.id in varMap: + if not hasattr(a, "originalId"): + a.originalId = a.id + if varMap[a.id][0] == "r": + a.randomVar = True # so we know it can crash + if a.id == varMap[a.id]: + # Check whether this is a given name + if not isAnonVariable(varMap[a.id]): + a.dontChangeName = True + a.id = varMap[a.id] + else: + for child in ast.iter_child_nodes(a): + updateVariableNames(child, varMap, scopeName, randomCounter, imports) + +def gatherLocalScope(a, globalMap, scopeName, imports, goBackwards=False): + localMap = { } + varLetter = "g" if type(a) == ast.Module else "v" + paramCounter = 0 + localCounter = 0 + if type(a) == ast.FunctionDef: + for param in a.args.args: + if type(param) == ast.arg: + if not (builtInName(param.arg) or importedName(param.arg, imports)) and param.arg not in localMap and param.arg not in globalMap: + localMap[param.arg] = "p" + str(paramCounter) + scopeName + paramCounter += 1 + else: + log("transformations\tanonymizeFunctionNames\tWeird parameter type: " + str(type(param)), "bug") + if goBackwards: + items = a.body[::-1] # go through this backwards to get the used function names + else: + items = a.body[:] + # First, go through and create the globalMap + while len(items) > 0: + item = items[0] + if type(item) in [ast.FunctionDef, ast.ClassDef]: + if not (builtInName(item.name) or importedName(item.name, imports)): + if item.name not in localMap and item.name not in globalMap: + localMap[item.name] = "helper_" + varLetter + str(localCounter) + scopeName + localCounter += 1 + else: + item.dontChangeName = True + + # If there are any variables in this node, find and label them + if type(item) in [ ast.Assign, ast.AugAssign, ast.For, ast.With, + ast.Lambda, ast.comprehension, ast.ExceptHandler]: + assns = [] + if type(item) == ast.Assign: + assns = item.targets + elif type(item) == ast.AugAssign: + assns = [item.target] + elif type(item) == ast.For: + assns = [item.target] + elif type(item) == ast.With: + if item.optional_vars != None: + assns = [item.optional_vars] + elif type(item) == ast.Lambda: + assns = item.args.args + elif type(item) == ast.comprehension: + assns = [item.target] + elif type(item) == ast.ExceptHandler: + if item.name != None: + assns = [item.name] + + for assn in assns: + if type(assn) == ast.Name: + if not (builtInName(assn.id) or importedName(assn.id, imports)): + if assn.id not in localMap and assn.id not in globalMap: + localMap[assn.id] = varLetter + str(localCounter) + scopeName + localCounter += 1 + + if type(item) in [ast.For, ast.While, ast.If]: + items += item.body + item.orelse + elif type(item) in [ast.With, ast.ExceptHandler]: + items += item.body + elif type(item) == ast.Try: + items += item.body + item.handlers + item.orelse + item.finalbody + items = items[1:] + return localMap + +def anonymizeStatementNames(a, globalMap, scopeName, imports, goBackwards=False): + """Gather the local variables, then update variable names in each line""" + localMap = gatherLocalScope(a, globalMap, scopeName, imports, goBackwards=goBackwards) + varMap = { } + varMap.update(globalMap) + varMap.update(localMap) + randomCounter = [0] + functionsSeen = [] + if type(a) == ast.FunctionDef: + for arg in a.args.args: + updateVariableNames(arg, varMap, scopeName, randomCounter, imports) + for line in a.body: + updateVariableNames(line, varMap, scopeName, randomCounter, imports) + +def anonymizeNames(a, namesToKeep, imports): + """Anonymize all of variables/names that occur in the given AST""" + """If we run this on an anonymized AST, it will fix the names again to get rid of any gaps!""" + if type(a) != ast.Module: + return a + globalMap = { } + for var in namesToKeep: + globalMap[var] = var + anonymizeStatementNames(a, globalMap, "", imports, goBackwards=True) + return a + +def propogateNameMetadata(a, namesToKeep, imports): + """Propogates name metadata through a state. We assume that the names are all properly formatted""" + if type(a) == list: + for child in a: + child = propogateNameMetadata(child, namesToKeep, imports) + return a + elif not isinstance(a, ast.AST): + return a + if type(a) == ast.Name: + if (builtInName(a.id) or importedName(a.id, imports)): + pass + elif a.id in namesToKeep: + a.dontChangeName = True + else: + if not hasattr(a, "originalId"): + a.originalId = a.id + if not isAnonVariable(a.id): + a.dontChangeName = True # it's a name we shouldn't mess with + elif type(a) == ast.arg: + if (builtInName(a.arg) or importedName(a.arg, imports)): + pass + elif a.arg in namesToKeep: + a.dontChangeName = True + else: + if not hasattr(a, "originalId"): + a.originalId = a.arg + if not isAnonVariable(a.arg): + a.dontChangeName = True # it's a name we shouldn't mess with + for child in ast.iter_child_nodes(a): + child = propogateNameMetadata(child, namesToKeep, imports) + return a + +### HELPER FOLDING ### + +def findHelperFunction(a, helperId, helperCount): + """Finds the first helper function used in the ast""" + if not isinstance(a, ast.AST): + return None + + # Check all the children, so that we don't end up with a recursive problem + for child in ast.iter_child_nodes(a): + f = findHelperFunction(child, helperId, helperCount) + if f != None: + return f + # Then check if this is the right call + if type(a) == ast.Call: + if type(a.func) == ast.Name and a.func.id == helperId: + if helperCount[0] > 0: + helperCount[0] -= 1 + else: + return a + return None + +def individualizeVariables(a, variablePairs, idNum, imports): + """Replace variable names with new individualized ones (for inlining methods)""" + if not isinstance(a, ast.AST): + return + + if type(a) == ast.Name: + if a.id not in variablePairs and not (builtInName(a.id) or importedName(a.id, imports)): + name = "_var_" + a.id + "_" + str(idNum[0]) + variablePairs[a.id] = name + if a.id in variablePairs: + a.id = variablePairs[a.id] + # Override built-in names when they're assigned to. + elif type(a) == ast.Assign and type(a.targets[0]) == ast.Name: + if a.targets[0].id not in variablePairs: + name = "_var_" + a.targets[0].id + "_" + str(idNum[0]) + variablePairs[a.targets[0].id] = name + elif type(a) == ast.arguments: + for arg in a.args: + if type(arg) == ast.arg: + name = "_arg_" + arg.arg + "_" + str(idNum[0]) + variablePairs[arg.arg] = name + arg.arg = variablePairs[arg.arg] + return + elif type(a) == ast.Call: + if type(a.func) == ast.Name: + variablePairs[a.func.id] = a.func.id # save the function name! + + for child in ast.iter_child_nodes(a): + individualizeVariables(child, variablePairs, idNum, imports) + +def replaceAst(a, old, new, shouldReplace): + """Replace the old value with the new one, but only once! That's what the shouldReplace variable is for.""" + if not isinstance(a, ast.AST) or not shouldReplace[0]: + return a + elif compareASTs(a, old, checkEquality=True) == 0: + shouldReplace[0] = False + return new + return applyToChildren(a, lambda x: replaceAst(x, old, new, shouldReplace)) + +def mapHelper(a, helper, idNum, imports): + """Map the helper function into this function, if it's used. idNum gives us the current id""" + if type(a) in [ast.FunctionDef, ast.ClassDef]: + a.body = mapHelper(a.body, helper, idNum, imports) + elif type(a) in [ast.For, ast.While, ast.If]: + a.body = mapHelper(a.body, helper, idNum, imports) + a.orelse = mapHelper(a.orelse, helper, idNum, imports) + if type(a) != list: + return a # we only deal with lists + + i = 0 + body = a + while i < len(body): + if type(body[i]) in [ast.FunctionDef, ast.ClassDef, ast.For, ast.While, ast.If]: + body[i] = mapHelper(body[i], helper, idNum, imports) # deal with blocks first + # While we still need to replace a function being called + helperCount = 0 + while countVariables(body[i], helper.name) > helperCount: + callExpression = findHelperFunction(body[i], helper.name, [helperCount]) + if callExpression == None: + break + idNum[0] += 1 + variablePairs = {} + + # First, update the method's variables + methodArgs = deepcopy(helper.args) + individualizeVariables(methodArgs, variablePairs, idNum, imports) + methodLines = deepcopyList(helper.body) + for j in range(len(methodLines)): + individualizeVariables(methodLines[j], variablePairs, idNum, imports) + + # Then, move each of the parameters into an assignment + argLines = [] + badArgs = False + for j in range(len(callExpression.args)): + arg = methodArgs.args[j] + value = callExpression.args[j] + # If it only occurs once, and in the return statement.. + if countVariables(methodLines, arg.arg) == 1 and countVariables(methodLines[-1], arg.arg) == 1: + var = ast.Name(arg.arg, ast.Load()) + transferMetaData(arg, var) + methodLines[-1] = replaceAst(methodLines[-1], var, value, [True]) + elif couldCrash(value): + badArgs = True + break + else: + if type(arg) == ast.arg: + var = ast.Name(arg.arg, ast.Store(), helperVar=True) + transferMetaData(arg, var) + varAssign = ast.Assign([var], value, global_id=arg.global_id, helperParamAssign=True) + argLines.append(varAssign) + else: # what other type would it be? + log("transformations\tmapHelper\tWeird arg: " + str(type(arg)), "bug") + if badArgs: + helperCount += 1 + continue + + # Finally, carry over the final value into the correct position + if countOccurances(helper, ast.Return) == 1: + returnLine = [replaceAst(body[i], callExpression, methodLines[-1].value, [True])] + methodLines.pop(-1) + else: + comp = compareASTs(body[i].value, callExpression, checkEquality=True) == 0 + if type(body[i]) == ast.Return and comp: + # In the case of a return, we can return None, it's cool + callVal = ast.NameConstant(None, helperReturnVal=True) + transferMetaData(callExpression, callVal) + returnVal = ast.Return(callVal, global_id=body[i].global_id, helperReturnAssign=True) + returnLine = [returnVal] + elif type(body[i]) == ast.Assign and comp: + # Assigns can be assigned to None + callVal = ast.NameConstant(None, helperReturnVal=True) + transferMetaData(callExpression, callVal) + assignVal = ast.Assign(body[i].targets, callVal, global_id=body[i].global_id, helperReturnAssign=True) + returnLine = [assignVal] + elif type(body[i]) == ast.Expr and comp: + # Delete the line if it's by itself + returnLine = [] + else: + # Otherwise, what is going on?!? + returnLine = [] + log("transformations\tmapHelper\tWeird removeCall: " + str(type(body[i])), "bug") + #log("transformations\tmapHelper\tMapped helper function: " + helper.name, "bug") + body[i:i+1] = argLines + methodLines + returnLine + i += 1 + return a + +def mapVariable(a, varId, assn): + """Map the variable assignment into the function, if it's needed""" + if type(a) != ast.FunctionDef: + return a + + for arg in a.args.args: + if arg.arg == varId: + return a # overriden by local variable + for i in range(len(a.body)): + line = a.body[i] + if type(line) == ast.Assign: + for target in line.targets: + if type(target) == ast.Name and target.id == varId: + break + elif type(target) in [ast.Tuple, ast.List]: + for elt in target.elts: + if type(elt) == ast.Name and elt.id == varId: + break + if countVariables(line, varId) > 0: + a.body[i:i+1] = [deepcopy(assn), line] + break + return a + +def helperFolding(a, mainFun, imports): + """When possible, fold the functions used in this module into their callers""" + if type(a) != ast.Module: + return a + + globalCounter = [0] + body = a.body + i = 0 + while i < len(body): + item = body[i] + if type(item) == ast.FunctionDef: + if item.name != mainFun: + # We only want non-recursive functions that have a single return which occurs at the end + # Also, we don't want any functions with parameters that are changed during the function + returnOccurs = countOccurances(item, ast.Return) + if countVariables(item, item.name) == 0 and returnOccurs <= 1 and type(item.body[-1]) == ast.Return: + allArgs = [] + for arg in item.args.args: + allArgs.append(arg.arg) + for tmpA in ast.walk(item): + if type(tmpA) in [ast.Assign, ast.AugAssign]: + allGood = True + assignedIds = gatherAssignedVarIds(tmpA.targets if type(tmpA) == ast.Assign else [tmpA.target]) + for arg in allArgs: + if arg in assignedIds: + allGood = False + break + if not allGood: + break + else: + for j in range(len(item.body)-1): + if couldCrash(item.body[j]): + break + else: + # If we satisfy these requirements, translate the body of the function into all functions that call it + gone = True + used = False + for j in range(len(body)): + if i != j and type(body[j]) == ast.FunctionDef: + if countVariables(body[j], item.name) > 0: + used = True + mapHelper(body[j], item, globalCounter, imports) + if countVariables(body[j], item.name) > 0: + gone = False + if used and gone: + body.pop(i) + continue + elif type(item) == ast.Assign: + # Is it ever changed in the global area? + if len(item.targets) == 1 and type(item.targets[0]) == ast.Name: + if eventualType(item.value) in [int, float, bool, str]: # if it isn't mutable + for j in range(i+1, len(body)): + line = body[j] + if type(line) == ast.FunctionDef: + if countOccurances(line, ast.Global) > 0: # TODO: improve this + break + if item.targets[0].id in getAllAssignedVarIds(line): + break + else: # if in scope + if countVariables(line, item.targets[0].id) > 0: + break + else: + # Variable never appears again- we can map it! + for j in range(i+1, len(body)): + line = body[j] + if type(line) == ast.FunctionDef: + mapVariable(line, item.targets[0].id, item) + body.pop(i) + continue + i += 1 + return a + +### AST PREPARATION ### + +def listNotEmpty(a): + """Determines that the iterable is NOT empty, if we can know that""" + """Used for For objects""" + if not isinstance(a, ast.AST): + return False + if type(a) == ast.Call: + if type(a.func) == ast.Name and a.func.id in ["range"]: + if len(a.args) == 1: # range(x) + return type(a.args[0]) == ast.Num and type(a.args[0].n) != complex and a.args[0].n > 0 + elif len(a.args) == 2: # range(start, x) + if type(a.args[0]) == ast.Num and type(a.args[1]) == ast.Num and \ + type(a.args[0].n) != complex and type(a.args[1].n) != complex and \ + a.args[0].n < a.args[1].n: + return True + elif type(a.args[1]) == ast.BinOp and type(a.args[1].op) == ast.Add: + if type(a.args[1].right) == ast.Num and type(a.args[1].right) != complex and a.args[1].right.n > 0 and \ + compareASTs(a.args[0], a.args[1].left, checkEquality=True) == 0: + return True + elif type(a.args[1].left) == ast.Num and type(a.args[1].left) != complex and a.args[1].left.n > 0 and \ + compareASTs(a.args[0], a.args[1].right, checkEquality=True) == 0: + return True + elif type(a) in [ast.List, ast.Tuple]: + return len(a.elts) > 0 + elif type(a) == ast.Str: + return len(a.s) > 0 + return False + +def simplifyUpdateId(var, variableMap, idNum): + """Update the varID of a new variable""" + if type(var) not in [ast.Name, ast.arg]: + return var + idVar = var.id if type(var) == ast.Name else var.arg + if not hasattr(var, "varID"): + if idVar in variableMap: + var.varID = variableMap[idVar][1] + else: + var.varID = idNum[0] + idNum[0] += 1 + +def simplify_multicomp(a): + if type(a) == ast.Compare and len(a.ops) > 1: + # Only do one comparator at a time. If we don't do this, things get messy! + comps = [a.left] + a.comparators + values = [ ] + # Compare each of the pairs + for i in range(len(a.ops)): + if i > 0: + # Label all nodes as middle parts so we can recognize them later + assignPropertyToAll(comps[i], "multiCompMiddle") + values.append(ast.Compare(comps[i], [a.ops[i]], [deepcopy(comps[i+1])], multiCompPart=True)) + # Combine comparisons with and operators + boolOp = ast.And(multiCompOp=True) + boolopVal = ast.BoolOp(boolOp, values, multiComp=True, global_id=a.global_id) + return boolopVal + return a + +def simplify(a): + """This function simplifies the usual Python AST to make it usable by our functions.""" + if not isinstance(a, ast.AST): + return a + elif type(a) == ast.Assign: + if len(a.targets) > 1: + # Go through all targets and assign them on separate lines + lines = [ast.Assign([a.targets[-1]], a.value, global_id=a.global_id)] + for i in range(len(a.targets)-1, 0, -1): + t = a.targets[i] + if type(t) == ast.Name: + loadedTarget = ast.Name(t.id, ast.Load()) + elif type(t) == ast.Subscript: + loadedTarget = ast.Subscript(deepcopy(t.value), deepcopy(t.slice), ast.Load()) + elif type(t) == ast.Attribute: + loadedTarget = ast.Attribute(deepcopy(t.value), t.attr, ast.Load()) + elif type(t) == ast.Tuple: + loadedTarget = ast.Tuple(deepcopy(t.elts), ast.Load()) + elif type(t) == ast.List: + loadedTarget = ast.List(deepcopy(t.elts), ast.Load()) + else: + log("transformations\tsimplify\tOdd loadedTarget: " + str(type(t)), "bug") + transferMetaData(t, loadedTarget) + loadedTarget.global_id = t.global_id + + lines.append(ast.Assign([a.targets[i-1]], loadedTarget, global_id=a.global_id)) + else: + lines = [a] + + i = 0 + while i < len(lines): + # For each line, figure out type and varID + lines[i].value = simplify(lines[i].value) + t = lines[i].targets[0] + if type(t) in [ast.Tuple, ast.List]: + val = lines[i].value + # If the items are being assigned separately, with no dependance on each other, + # separate out the elements of the tuple + if type(val) in [ast.Tuple, ast.List] and len(t.elts) == len(val.elts) and \ + allVariableNamesUsed(val) == []: + listLines = [] + for j in range(len(t.elts)): + assignVal = ast.Assign([t.elts[j]], val.elts[j], global_id = lines[i].global_id) + listLines += simplify(assignVal) + lines[i:i+1] = listLines + i += len(listLines) - 1 + i += 1 + return lines + elif type(a) == ast.AugAssign: + # Turn all AugAssigns into Assigns + a.target = simplify(a.target) + if eventualType(a.target) not in [bool, int, str, float]: + # Can't get rid of AugAssign, in case the += is different + a.value = simplify(a.value) + return a + if type(a.target) == ast.Name: + loadedTarget = ast.Name(a.target.id, ast.Load()) + elif type(a.target) == ast.Subscript: + loadedTarget = ast.Subscript(deepcopy(a.target.value), deepcopy(a.target.slice), ast.Load()) + elif type(a.target) == ast.Attribute: + loadedTarget = ast.Attribute(deepcopy(a.target.value), a.target.attr, ast.Load()) + elif type(a.target) == ast.Tuple: + loadedTarget = ast.Tuple(deepcopy(a.target.elts), ast.Load()) + elif type(a.target) == ast.List: + loadedTarget = ast.List(deepcopy(a.target.elts), ast.Load()) + else: + log("transformations\tsimplify\tOdd AugAssign target: " + str(type(a.target)), "bug") + transferMetaData(a.target, loadedTarget) + loadedTarget.global_id = a.target.global_id + a.target.augAssignVal = True # for later recognition + loadedTarget.augAssignVal = True + assignVal = ast.Assign([a.target], ast.BinOp(loadedTarget, a.op, a.value, augAssignBinOp=True), global_id=a.global_id) + return simplify(assignVal) + elif type(a) == ast.Compare and len(a.ops) > 1: + return simplify(simplify_multicomp(a)) + return applyToChildren(a, lambda x : simplify(x)) + +def propogateMetadata(a, argTypes, variableMap, idNum): + """This function propogates metadata about type throughout the AST. + argTypes lets us institute types for each function's args + variableMap maps variable ids to their types and id numbers + idNum gives us a global number to use for variable ids""" + if not isinstance(a, ast.AST): + return a + elif type(a) == ast.FunctionDef: + if a.name in argTypes: + theseTypes = argTypes[a.name] + else: + theseTypes = [] + variableMap = copy.deepcopy(variableMap) # variables shouldn't affect variables outside + idNum = copy.deepcopy(idNum) + for i in range(len(a.args.args)): + arg = a.args.args[i] + if type(arg) == ast.arg: + simplifyUpdateId(arg, variableMap, idNum) + # Match the args if possible + if len(a.args.args) == len(theseTypes) and \ + theseTypes[i] in ["int", "str", "float", "bool", "list"]: + arg.type = eval(theseTypes[i]) + variableMap[arg.arg] = (arg.type, arg.varID) + else: + arg.type = None + variableMap[arg.arg] = (None, arg.varID) + else: + log("transformations\tpropogateMetadata\tWeird type in args: " + str(type(arg)), "bug") + + newBody = [] + for line in a.body: + newLine = propogateMetadata(line, argTypes, variableMap, idNum) + if type(newLine) == list: + newBody += newLine + else: + newBody.append(newLine) + a.body = newBody + return a + elif type(a) == ast.Assign: + val = a.value = propogateMetadata(a.value, argTypes, variableMap, idNum) + if len(a.targets) == 1: + t = a.targets[0] + if type(t) == ast.Name: + simplifyUpdateId(t, variableMap, idNum) + varType = eventualType(val) + t.type = varType + variableMap[t.id] = (varType, t.varID) # update type + elif type(t) in [ast.Tuple, ast.List]: + # If the items are being assigned separately, with no dependance on each other, + # assign the appropriate types + getTypes = type(val) in [ast.Tuple, ast.List] and len(t.elts) == len(val.elts) and len(allVariableNamesUsed(val)) == 0 + for j in range(len(t.elts)): + t2 = t.elts[j] + givenType = eventualType(val.elts[j]) if getTypes else None + if type(t2) == ast.Name: + simplifyUpdateId(t2, variableMap, idNum) + t2.type = givenType + variableMap[t2.id] = (givenType, t2.varID) + elif type(t.elts[j]) in [ast.Subscript, ast.Attribute, ast.Tuple, ast.List]: + pass + else: + log("transformations\tpropogateMetadata\tOdd listTarget: " + str(type(t.elts[j])), "bug") + return a + elif type(a) == ast.AugAssign: + # Turn all AugAssigns into Assigns + t = a.target = propogateMetadata(a.target, argTypes, variableMap, idNum) + a.value = propogateMetadata(a.value, argTypes, variableMap, idNum) + if type(t) == ast.Name: + if eventualType(a.target) not in [bool, int, str, float]: + finalType = None + else: + if type(a.target) == ast.Name: + loadedTarget = ast.Name(a.target.id, ast.Load()) + elif type(a.target) == ast.Subscript: + loadedTarget = ast.Subscript(deepcopy(a.target.value), deepcopy(a.target.slice), ast.Load()) + elif type(a.target) == ast.Attribute: + loadedTarget = ast.Attribute(deepcopy(a.target.value), a.target.attr, ast.Load()) + elif type(a.target) == ast.Tuple: + loadedTarget = ast.Tuple(deepcopy(a.target.elts), ast.Load()) + elif type(a.target) == ast.List: + loadedTarget = ast.List(deepcopy(a.target.elts), ast.Load()) + else: + log("transformations\tsimplify\tOdd AugAssign target: " + str(type(a.target)), "bug") + transferMetaData(a.target, loadedTarget) + actualValue = ast.BinOp(loadedTarget, a.op, a.value) + finalType = eventualType(actualValue) + + simplifyUpdateId(t, variableMap, idNum) + varType = finalType + t.type = varType + variableMap[t.id] = (varType, t.varID) # update type + return a + elif type(a) == ast.For: # START HERE + a.iter = propogateMetadata(a.iter, argTypes, variableMap, idNum) + if type(a.target) == ast.Name: + simplifyUpdateId(a.target, variableMap, idNum) + ev = eventualType(a.iter) + if ev == str: + a.target.type = str + # we know ranges are made of ints + elif type(a.iter) == ast.Call and type(a.iter.func) == ast.Name and (a.iter.func.id) == "range": + a.target.type = int + else: + a.target.type = None + variableMap[a.target.id] = (a.target.type, a.target.varID) + a.target = propogateMetadata(a.target, argTypes, variableMap, idNum) + + # Go through the body and orelse to map out more variables + body = [] + bodyMap = copy.deepcopy(variableMap) + for i in range(len(a.body)): + result = propogateMetadata(a.body[i], argTypes, bodyMap, idNum) + if type(result) == list: + body += result + else: + body.append(result) + a.body = body + + orelse = [] + orelseMap = copy.deepcopy(variableMap) + for var in bodyMap: + if var not in orelseMap: + orelseMap[var] = (42, bodyMap[var][1]) + for i in range(len(a.orelse)): + result = propogateMetadata(a.orelse[i], argTypes, orelseMap, idNum) + if type(result) == list: + orelse += result + else: + orelse.append(result) + a.orelse = orelse + + keys = list(variableMap.keys()) + for key in keys: # reset types of changed keys + if key not in bodyMap or bodyMap[key] != variableMap[key] or \ + key not in orelseMap or orelseMap[key] != variableMap[key]: + variableMap[key] = (None, variableMap[key][1]) + if countOccurances(a.body, ast.Break) == 0: # We will definitely enter the else! + for key in orelseMap: + # If we KNOW it will be this type + if key not in bodyMap or bodyMap[key] == orelseMap[key]: + variableMap[key] = orelseMap[key] + + # If we KNOW it will run at least once + if listNotEmpty(a.iter): + for key in bodyMap: + if key in variableMap: + continue + # type might be changed + elif key in orelseMap and orelseMap[key] != bodyMap[key]: + continue + variableMap[key] = bodyMap[key] + return a + elif type(a) == ast.While: + body = [] + bodyMap = copy.deepcopy(variableMap) + for i in range(len(a.body)): + result = propogateMetadata(a.body[i], argTypes, bodyMap, idNum) + if type(result) == list: + body += result + else: + body.append(result) + a.body = body + + orelse = [] + orelseMap = copy.deepcopy(variableMap) + for var in bodyMap: + if var not in orelseMap: + orelseMap[var] = (None, bodyMap[var][1]) + for i in range(len(a.orelse)): + result = propogateMetadata(a.orelse[i], argTypes, orelseMap, idNum) + if type(result) == list: + orelse += result + else: + orelse.append(result) + a.orelse = orelse + + keys = list(variableMap.keys()) + for key in keys: + if key not in bodyMap or bodyMap[key] != variableMap[key] or \ + key not in orelseMap or orelseMap[key] != variableMap[key]: + variableMap[key] = (None, variableMap[key][1]) + + a.test = propogateMetadata(a.test, argTypes, variableMap, idNum) + return a + elif type(a) == ast.If: + a.test = propogateMetadata(a.test, argTypes, variableMap, idNum) + variableMap1 = copy.deepcopy(variableMap) + variableMap2 = copy.deepcopy(variableMap) + + body = [] + for i in range(len(a.body)): + result = propogateMetadata(a.body[i], argTypes, variableMap1, idNum) + if type(result) == list: + body += result + else: + body.append(result) + a.body = body + + for var in variableMap1: + if var not in variableMap2: + variableMap2[var] = (42, variableMap1[var][1]) + + orelse = [] + for i in range(len(a.orelse)): + result = propogateMetadata(a.orelse[i], argTypes, variableMap2, idNum) + if type(result) == list: + orelse += result + else: + orelse.append(result) + a.orelse = orelse + + variableMap.clear() + for key in variableMap1: + if key in variableMap2: + if variableMap1[key] == variableMap2[key]: + variableMap[key] = variableMap1[key] + elif variableMap1[key][1] != variableMap2[key][1]: + log("transformations\tsimplify\tvarId mismatch", "bug") + else: # type mismatch + if variableMap2[key][0] == 42: + variableMap[key] = variableMap1[key] + else: + variableMap[key] = (None, variableMap1[key][1]) + elif len(a.orelse) > 0 and type(a.orelse[-1]) == ast.Return: # if the else exits, it doesn't matter + variableMap[key] = variableMap1[key] + for key in variableMap2: + if key not in variableMap1 and len(a.body) > 0 and type(a.body[-1]) == ast.Return: # if the if exits, it doesn't matter + variableMap[key] = variableMap2[key] + return a + elif type(a) == ast.Name: + if a.id in variableMap and variableMap[a.id][0] != 42: + a.type = variableMap[a.id][0] + a.varID = variableMap[a.id][1] + else: + if a.id in variableMap: + a.varID = variableMap[a.id][1] + a.ctx = propogateMetadata(a.ctx, argTypes, variableMap, idNum) + return a + elif type(a) == ast.AugLoad: + return ast.Load() + elif type(a) == ast.AugStore: + return ast.Store() + return applyToChildren(a, lambda x : propogateMetadata(x, argTypes, variableMap, idNum)) + +### SIMPLIFYING FUNCTIONS ### + +def applyTransferLambda(x): + """Simplify an expression by applying constant folding, re-formatting to an AST, and then tranferring the metadata appropriately.""" + if x == None: + return x + tmp = astFormat(constantFolding(x)) + if hasattr(tmp, "global_id") and hasattr(x, "global_id") and tmp.global_id != x.global_id: + return tmp # don't do the transfer, this already has its own metadata + else: + transferMetaData(x, tmp) + return tmp + +def constantFolding(a): + """In constant folding, we evaluate all constant expressions instead of doing operations at runtime""" + if not isinstance(a, ast.AST): + return a + t = type(a) + if t in [ast.FunctionDef, ast.ClassDef]: + for i in range(len(a.body)): + a.body[i] = applyTransferLambda(a.body[i]) + return a + elif t in [ast.Import, ast.ImportFrom, ast.Global]: + return a + elif t == ast.BoolOp: + # Condense the boolean's values + newValues = [] + ranks = [] + count = 0 + for val in a.values: + # Condense the boolean operations into one line, if possible + c = constantFolding(val) + if type(c) == ast.BoolOp and type(c.op) == type(a.op) and not hasattr(c, "multiComp"): + newValues += c.values + ranks.append(range(count,count+len(c.values))) + count += len(c.values) + else: + newValues.append(c) + ranks.append(count) + count += 1 + + # Or breaks with True, And breaks with False + breaks = (type(a.op) == ast.Or) + + # Remove the opposite values IF removing them won't mess up the type. + i = len(newValues) - 1 + while i > 0: + if (newValues[i] == (not breaks)) and eventualType(newValues[i-1]) == bool: + newValues.pop(i) + i -= 1 + + if len(newValues) == 0: + # There's nothing to evaluate + return (not breaks) + elif len(newValues) == 1: + # If we're down to one value, just return it! + return newValues[0] + elif newValues[0] == breaks: + # If the first value breaks it, done! + return breaks + elif newValues.count(breaks) >= 1: + # We don't need any values that occur after a break + i = newValues.index(breaks) + newValues = newValues[:i+1] + for i in range(len(newValues)): + newValues[i] = astFormat(newValues[i]) + # get the corresponding value + if i in ranks: + transferMetaData(a.values[ranks.index(i)], newValues[i]) + else: # it's in a list + for j in range(len(ranks)): + if type(ranks[j]) == list and i in ranks[j]: + transferMetaData(a.values[j].values[ranks[j].index(i)], newValues[i]) + break + a.values = newValues + return a + elif t == ast.BinOp: + l = constantFolding(a.left) + r = constantFolding(a.right) + # Hack to make hint chaining work- don't constant-fold filler strings! + if containsTokenStepString(l) or containsTokenStepString(r): + a.left = applyTransferLambda(a.left) + a.right = applyTransferLambda(a.right) + return a + if type(l) in builtInTypes and type(r) in builtInTypes: + try: + val = doBinaryOp(a.op, l, r) + if type(val) == float and val % 0.0001 != 0: # don't deal with trailing floats + pass + else: + tmp = astFormat(val) + transferMetaData(a, tmp) + return tmp + except: + # We have some kind of divide-by-zero issue. + # Therefore, don't calculate it! + pass + if type(l) in builtInTypes: + if type(r) == bool: + r = int(r) + # Commutative operations + elif type(r) == ast.BinOp and type(r.op) == type(a.op) and type(a.op) in [ast.Add, ast.Mult, ast.BitOr, ast.BitAnd, ast.BitXor]: + rLeft = constantFolding(r.left) + if type(rLeft) in builtInTypes: + try: + newLeft = astFormat(doBinaryOp(a.op, l, rLeft)) + transferMetaData(r.left, newLeft) + return ast.BinOp(newLeft, a.op, r.right) + except Exception as e: + pass + + # Empty string is often unneccessary + if type(l) == str and l == '': + if type(a.op) == ast.Add and eventualType(r) == str: + return r + elif type(a.op) == ast.Mult and eventualType(r) == int: + return '' + elif type(l) == bool: + l = int(l) + # 0 is often unneccessary + if l == 0 and eventualType(r) in [int, float]: + if type(a.op) in [ast.Add, ast.BitOr]: + # If it won't change the type + if type(l) == int or eventualType(r) == float: + return r + elif type(l) == float: # Cast it + return ast.Call(ast.Name("float", ast.Load(), typeCastFunction=True), [r], []) + elif type(a.op) == ast.Sub: + tmpR = astFormat(r) + transferMetaData(a.right, tmpR) + newR = ast.UnaryOp(ast.USub(addedOtherOp=True), tmpR, addedOther=True) + if type(l) == int or eventualType(r) == float: + return newR + elif type(l) == float: + return ast.Call(ast.Name("float", ast.Load(), typeCastFunction=True), [newR], []) + elif type(a.op) in [ast.Mult, ast.LShift, ast.RShift]: + # If either is a float, it's 0 + return 0.0 if float in [eventualType(r), type(l)] else 0 + elif type(a.op) in [ast.Div, ast.FloorDiv, ast.Mod]: + # Check if the right might be zero + if type(r) in builtInTypes and r != 0: + return 0.0 if float in [eventualType(r), type(l)] else 0 + # Same for 1 + elif l == 1: + if type(a.op) == ast.Mult and eventualType(r) in [int, float]: + if type(l) == int or eventualType(r) == float: + return r + elif type(l) == float: + return ast.Call(ast.Name("float", ast.Load(), typeCastFunction=True), [r], []) + # No reason to make this a float if the other value has already been cast + elif type(l) == float and l == int(l): + if type(a.op) in [ast.Add, ast.Sub, ast.Mult, ast.Div] and eventualType(r) == float: + l = int(l) + # Some of the same operations are done with the right, but not all of them + if type(r) in builtInTypes: + if type(r) == str and r == '': + if type(a.op) == ast.Add and eventualType(l) == str: + return l + elif type(a.op) == ast.Mult and eventualType(l) == int: + return '' + elif type(r) == bool: + r = int(r) + else: + if r == 0 and eventualType(l) in [int, float]: + if type(a.op) in [ast.Add, ast.Sub, ast.LShift, ast.RShift, ast.BitOr]: + if type(r) == int or eventualType(l) == float: + return l + elif type(r) == float: + return ast.Call(ast.Name("float", ast.Load(), typeCastFunction=True), [l], []) + elif type(a.op) == ast.Mult: + return 0.0 if float in [eventualType(l), type(r)] else 0 + elif r == 1: + if type(a.op) in [ast.Mult, ast.Div, ast.Pow] and eventualType(l) in [int, float]: + if type(r) == int or eventualType(l) == float: + return l + elif type(r) == float: + return ast.Call(ast.Name("float", ast.Load(), typeCastFunction=True), [l], []) + elif type(a.op) == ast.FloorDiv and eventualType(l) == int: + if eventualType( r ) == int: + return l + elif eventualType( r ) == float: + return ast.Call(ast.Name("float", ast.Load(), typeCastFunction=True), [l], []) + elif type(r) == float and r == int(r): + if type(a.op) in [ast.Add, ast.Sub, ast.Mult, ast.Div] and eventualType(l) == float: + r = int(r) + a.left = applyTransferLambda(a.left) + a.right = applyTransferLambda(a.right) + return a + elif t == ast.IfExp: + # Sometimes, we can simplify the statement + test = constantFolding(a.test) + b = constantFolding(a.body) + o = constantFolding(a.orelse) + + aTest = astFormat(test) + transferMetaData(a.test, aTest) + aB = astFormat(b) + transferMetaData(a.body, aB) + aO = astFormat(o) + transferMetaData(a.orelse, aO) + + if type(test) == bool: + return aB if test else aO # evaluate the if expression now + elif compareASTs(b, o, checkEquality=True) == 0: + return aB # if they're the same, no reason for the expression + a.test = aTest + a.body = aB + a.orelse = aO + return a + elif t == ast.Compare: + if len(a.ops) == 0 or len(a.comparators) == 0: + return True # No ops? Okay, empty case is true! + op = a.ops[0] + l = constantFolding(a.left) + r = constantFolding(a.comparators[0]) + # Hack to make hint chaining work- don't constant-fold filler strings! + if containsTokenStepString(l) or containsTokenStepString(r): + tmpLeft = astFormat(l) + transferMetaData(a.left, tmpLeft) + a.left = tmpLeft + tmpRight = astFormat(r) + transferMetaData(a.comparators[0], tmpRight) + a.comparators = [tmpRight] + return a + # Check whether the two sides are the same + comp = compareASTs(l, r, checkEquality=True) == 0 + if comp and (not couldCrash(l)) and type(op) in [ast.Lt, ast.Gt, ast.NotEq]: + tmp = ast.NameConstant(False) + transferMetaData(a, tmp) + return tmp + elif comp and (not couldCrash(l)) and type(op) in [ast.Eq, ast.LtE, ast.GtE]: + tmp = ast.NameConstant(True) + transferMetaData(a, tmp) + return tmp + if (type(l) in builtInTypes) and (type(r) in builtInTypes): + try: + result = astFormat(doCompare(op, l, r)) + transferMetaData(a, result) + return result + except: + pass + # Reduce the expressions when possible! + if type(l) == type(r) == ast.BinOp and type(l.op) == type(r.op) and not couldCrash(l) and not couldCrash(r): + if type(l.op) == ast.Add: + # Remove repeated values + unchanged = False + if compareASTs(l.left, r.left, checkEquality=True) == 0: + l = l.right + r = r.right + elif compareASTs(l.right, r.right, checkEquality=True) == 0: + l = l.left + r = r.left + elif compareASTs(l.left, r.right, checkEquality=True) == 0 and eventualType(l) in [int, float]: + l = l.right + r = r.left + elif compareASTs(l.right, r.left, checkEquality=True) == 0 and eventualType(l) in [int, float]: + l = l.left + r = r.right + else: + unchanged = True + if not unchanged: + tmpLeft = astFormat(l) + transferMetaData(a.left, tmpLeft) + a.left = tmpLeft + tmpRight = astFormat(r) + transferMetaData(a.comparators[0], tmpRight) + a.comparators = [tmpRight] + return constantFolding(a) # Repeat this check to see if we can keep reducing it + elif type(l.op) == ast.Sub: + unchanged = False + if compareASTs(l.left, r.left, checkEquality=True) == 0: + l = l.right + r = r.right + elif compareASTs(l.right, r.right, checkEquality=True) == 0: + l = l.left + r = r.left + else: + unchanged = True + if not unchanged: + tmpLeft = astFormat(l) + transferMetaData(a.left, tmpLeft) + a.left = tmpLeft + tmpRight = astFormat(r) + transferMetaData(a.comparators[0], tmpRight) + a.comparators = [tmpRight] + return constantFolding(a) + tmpLeft = astFormat(l) + transferMetaData(a.left, tmpLeft) + a.left = tmpLeft + tmpRight = astFormat(r) + transferMetaData(a.comparators[0], tmpRight) + a.comparators = [tmpRight] + return a + elif t == ast.Call: + # TODO: this can be done much better + a.func = applyTransferLambda(a.func) + + allConstant = True + tmpArgs = [] + for i in range(len(a.args)): + tmpArgs.append(constantFolding(a.args[i])) + if type(tmpArgs[i]) not in [int, float, bool, str]: + allConstant = False + if len(a.keywords) > 0: + allConstant = False + if allConstant and (type(a.func) == ast.Name) and (a.func.id in builtInFunctions.keys()) and \ + (a.func.id not in ["range", "raw_input", "input", "open", "randint", "random", "slice"]): + try: + result = apply(eval(a.func.id), tmpArgs) + transferMetaData(a, astFormat(result)) + return result + except: + # Not gonna happen unless it crashes + #log("transformations\tconstantFolding\tFunction crashed: " + str(a.func.id), "bug") + pass + for i in range(len(a.args)): + tmpArg = astFormat(tmpArgs[i]) + transferMetaData(a.args[i], tmpArg) + a.args[i] = tmpArg + return a + # This needs to be separate because the attribute is a string + elif t == ast.Attribute: + a.value = applyTransferLambda(a.value) + return a + elif t == ast.Slice: + if a.lower != None: + a.lower = applyTransferLambda(a.lower) + if a.upper != None: + a.upper = applyTransferLambda(a.upper) + if a.step != None: + a.step = applyTransferLambda(a.step) + return a + elif t == ast.Num: + return a.n + elif t == ast.Bytes: + return a.s + elif t == ast.Str: + # Don't do things to filler strings + if len(a.s) > 0 and isTokenStepString(a.s): + return a + return a.s + elif t == ast.NameConstant: + if a.value == True: + return True + elif a.value == False: + return False + elif a.value == None: + return None + elif t == ast.Name: + return a + else: # All statements, ast.Lambda, ast.Dict, ast.Set, ast.Repr, ast.Attribute, ast.Subscript, etc. + return applyToChildren(a, applyTransferLambda) + +def isMutatingFunction(a): + """Given a function call, this checks whether it might change the program state when run""" + if type(a) != ast.Call: # Can only call this on Calls! + log("transformations\tisMutatingFunction\tNot a Call: " + str(type(a)), "bug") + return True + + # Map of all static namesets + funMaps = { "math" : mathFunctions, "string" : builtInStringFunctions, + "str" : builtInStringFunctions, "list" : staticListFunctions, + "dict" : staticDictFunctions } + typeMaps = { str : "string", list : "list", dict : "dict" } + if type(a.func) == ast.Name: + funDict = builtInFunctions + funName = a.func.id + elif type(a.func) == ast.Attribute: + if type(a.func.value) == ast.Name and a.func.value.id in funMaps: + funDict = funMaps[a.func.value.id] + funName = a.func.attr + # if the item is calling a function directly + elif eventualType(a.func.value) in typeMaps: + funDict = funMaps[typeMaps[eventualType(a.func.value)]] + funName = a.func.attr + else: + return True + else: + return True # we don't know, so yes + + # TODO: deal with student's functions + return funName not in funDict + +def allVariablesUsed(a): + if not isinstance(a, ast.AST): + return [] + elif type(a) == ast.Name: + return [a] + variables = [] + for child in ast.iter_child_nodes(a): + variables += allVariablesUsed(child) + return variables + +def allVariableNamesUsed(a): + """Gathers all the variable names used in the ast""" + if not isinstance(a, ast.AST): + return [] + elif type(a) == ast.Name: + return [a.id] + elif type(a) == ast.Assign: + """In assignments, ignore all pure names used- they're being assigned to, not used""" + variables = allVariableNamesUsed(a.value) + for target in a.targets: + if type(target) == ast.Name: + pass + elif type(target) in [ast.Tuple, ast.List]: + for elt in target.elts: + if type(elt) != ast.Name: + variables += allVariableNamesUsed(elt) + else: + variables += allVariableNamesUsed(target) + return variables + elif type(a) == ast.AugAssign: + variables = allVariableNamesUsed(a.value) + variables += allVariableNamesUsed(a.target) + return variables + variables = [] + for child in ast.iter_child_nodes(a): + variables += allVariableNamesUsed(child) + return variables + +def addPropTag(a, globalId): + if not isinstance(a, ast.AST): + return a + a.propagatedVariable = True + if hasattr(a, "global_id"): + a.variableGlobalId = globalId + return applyToChildren(a, lambda x : addPropTag(x, globalId)) + +def propagateValues(a, liveVars): + """Propagate the given values through the AST whenever their variables occur""" + if ((not isinstance(a, ast.AST) or len(liveVars.keys()) == 0)): + return a + + if type(a) == ast.Name: + # Propagate the value if we have it! + if a.id in liveVars: + val = copy.deepcopy(liveVars[a.id]) + val.loadedVariable = True + if hasattr(a, "global_id"): + val.variableGlobalId = a.global_id + return applyToChildren(val, lambda x : addPropTag(x, a.global_id)) + else: + return a + elif type(a) == ast.Call: + # If something is mutated, it cannot be propagated anymore + if isMutatingFunction(a): + allVars = allVariablesUsed(a) + for var in allVars: + if (eventualType(var) not in [int, float, bool, str]): + if (var.id in liveVars): + del liveVars[var.id] + currentLiveVars = list(liveVars.keys()) + for liveVar in currentLiveVars: + varsWithin = allVariableNamesUsed(liveVars[liveVar]) + if var.id in varsWithin: + del liveVars[liveVar] + return a + elif type(a.func) == ast.Name and a.func.id in liveVars and \ + eventualType(liveVars[a.func.id]) in [int, float, complex, bytes, bool, type(None)]: + # Special case: don't move a simple value to the front of a Call + # because it will cause a compiler error instead of a runtime error + a.args = propagateValues(a.args, liveVars) + a.keywords = propagateValues(a.keywords, liveVars) + return a + elif type(a) == ast.Attribute: + if type(a.value) == ast.Name and a.value.id in liveVars and \ + eventualType(liveVars[a.value.id]) in [int, float, complex, bytes, bool, type(None)]: + # Don't move for the same reason as above + return a + return applyToChildren(a, lambda x: propagateValues(x, liveVars)) + +def hasMutatingFunction(a): + """Checks to see if the ast has any potentially mutating functions""" + if not isinstance(a, ast.AST): + return False + for node in ast.walk(a): + if type(a) == ast.Call: + if isMutatingFunction(a): + return True + return False + +def clearBlockVars(a, liveVars): + """Clear all the vars set in this block out of the live vars""" + if (not isinstance(a, ast.AST)) or len(liveVars.keys()) == 0: + return + + if type(a) in [ast.Assign, ast.AugAssign]: + if type(a) == ast.Assign: + targets = gatherAssignedVars(a.targets) + else: + targets = gatherAssignedVars([a.target]) + for target in targets: + varId = None + if type(target) == ast.Name: + varId = target.id + elif type(target.value) == ast.Name: + varId = target.value.id + if varId in liveVars: + del liveVars[varId] + + liveKeys = list(liveVars.keys()) + for var in liveKeys: + # Remove the variable and any variables in which it is used + if varId in allVariableNamesUsed(liveVars[var]): + del liveVars[var] + return + elif type(a) == ast.Call: + if hasMutatingFunction(a): + for v in allVariablesUsed(a): + if eventualType(v) not in [int, float, bool, str]: + if v.id in liveVars: + del liveVars[v.id] + liveKeys = list(liveVars.keys()) + for var in liveKeys: + if v.id in allVariableNamesUsed(liveVars[var]): + del liveVars[var] + return + elif type(a) == ast.For: + names = [] + if type(a.target) == ast.Name: + names = [a.target.id] + elif type(a.target) in [ast.Tuple, ast.List]: + for elt in a.target.elts: + if type(elt) == ast.Name: + names.append(elt.id) + elif type(elt) == ast.Subscript: + if type(elt.value) == ast.Name: + names.append(elt.value.id) + else: + log("transformations\tclearBlockVars\tFor target subscript not a name: " + str(type(elt.value)), "bug") + else: + log("transformations\tclearBlockVars\tFor target not a name: " + str(type(elt)), "bug") + elif type(a.target) == ast.Subscript: + if type(a.target.value) == ast.Name: + names.append(a.target.value.id) + else: + log("transformations\tclearBlockVars\tFor target subscript not a name: " + str(type(a.target.value)), "bug") + else: + log("transformations\tclearBlockVars\tFor target not a name: " + str(type(a.target)), "bug") + for name in names: + if name in liveVars: + del liveVars[name] + + liveKeys = list(liveVars.keys()) + for var in liveKeys: + # Remove the variable and any variables in which it is used + if name in allVariableNamesUsed(liveVars[var]): + del liveVars[var] + + for child in ast.iter_child_nodes(a): + clearBlockVars(child, liveVars) + +def copyPropagation(a, liveVars=None, inLoop=False): + """Propagate variables into the tree, when possible""" + if liveVars == None: + liveVars = { } + if type(a) == ast.Module: + a.body = copyPropagation(a.body) + return a + if type(a) == ast.FunctionDef: + a.body = copyPropagation(a.body, liveVars=liveVars) + return a + + if type(a) == list: + i = 0 + while i < len(a): + deleteLine = False + if type(a[i]) == ast.FunctionDef: + a[i].body = copyPropagation(a[i].body, liveVars=copy.deepcopy(liveVars)) + elif type(a[i]) == ast.ClassDef: + # TODO: can we propagate values through everything after here? + for j in range(len(a[i].body)): + if type(a[i].body[j]) == ast.FunctionDef: + a[i].body[j] = copyPropagation(a[i].body[j]) + elif type(a[i]) == ast.Assign: + # In assignments, propagate values into the right side and move the left side into the live vars + a[i].value = propagateValues(a[i].value, liveVars) + target = a[i].targets[0] + + if type(target) in [ast.Name, ast.Subscript, ast.Attribute]: + varId = None + # In plain names, we can update the liveVars + if type(target) == ast.Name: + varId = target.id + if inLoop or couldCrash(a[i].value) or eventualType(a[i].value) not in [bool, int, float, str, tuple]: + # Remove this variable from the live vars + if varId in liveVars: + del liveVars[varId] + else: + liveVars[varId] = a[i].value + # For other values, we can at least clear out liveVars correctly + # TODO: can we expand this? + elif target.value == ast.Name: + varId = target.value.id + + # Now, update the live vars based on anything reset by the new target + liveKeys = list(liveVars.keys()) + for var in liveKeys: + # If the var we're replacing was used elsewhere, that value will no longer be the same + if varId in allVariableNamesUsed(liveVars[var]): + del liveVars[var] + elif type(target) in [ast.Tuple, ast.List]: + # Copy the values, if we can match them + if type(a[i].value) in [ast.Tuple, ast.List] and len(target.elts) == len(a[i].value.elts): + for j in range(len(target.elts)): + if type(target.elts[j]) == ast.Name: + if (not couldCrash(a[i].value.elts[j])): + liveVars[target.elts[j]] = a[i].value.elts[j] + else: + if target.elts[j] in liveVars: + del liveVars[target.elts[j]] + + # Then get rid of any overwrites + for e in target.elts: + if type(e) in [ast.Name, ast.Subscript, ast.Attribute]: + varId = None + if type(e) == ast.Name: + varId = e.id + elif type(e.value) == ast.Name: + varId = e.value.id + + liveKeys = list(liveVars.keys()) + for var in liveKeys: + if varId in allVariableNamesUsed(liveVars[var]): + del liveVars[var] + else: + log("transformations\tcopyPropagation\tWeird assign type: " + str(type(e)), "bug") + elif type(a[i]) == ast.AugAssign: + a[i].value = propagateValues(a[i].value, liveVars) + assns = gatherAssignedVarIds([a[i].target]) + for target in assns: + if target in liveVars: + del liveVars[target] + elif type(a[i]) == ast.For: + # FIRST, propagate values into the iter + if type(a[i].iter) != ast.Name: # if it IS a name, don't replace it! + # Otherwise, we propagate first since this is evaluated once + a[i].iter = propagateValues(a[i].iter, liveVars) + + # We reset the target variable, so reset the live vars + names = [] + if type(a[i].target) == ast.Name: + names = [a[i].target.id] + elif type(a[i].target) in [ast.Tuple, ast.List]: + for elt in a[i].target.elts: + if type(elt) == ast.Name: + names.append(elt.id) + elif type(elt) == ast.Subscript: + if type(elt.value) == ast.Name: + names.append(elt.value.id) + else: + log("transformations\tcopyPropagation\tFor target subscript not a name: " + str(type(elt.value)) + "\t" + printFunction(elt.value), "bug") + else: + log("transformations\tcopyPropagation\tFor target not a name: " + str(type(elt)) + "\t" + printFunction(elt), "bug") + elif type(a[i].target) == ast.Subscript: + if type(a[i].target.value) == ast.Name: + names.append(a[i].target.value.id) + else: + log("transformations\tcopyPropagation\tFor target subscript not a name: " + str(type(a[i].target.value)) + "\t" + printFunction(a[i].target.value), "bug") + else: + log("transformations\tcopyPropagation\tFor target not a name: " + str(type(a[i].target)) + "\t" + printFunction(a[i].target), "bug") + + for name in names: + liveKeys = list(liveVars.keys()) + for var in liveKeys: + if name in allVariableNamesUsed(liveVars[var]): + del liveVars[var] + if name in liveVars: + del liveVars[name] + clearBlockVars(a[i], liveVars) + a[i].body = copyPropagation(a[i].body, copy.deepcopy(liveVars), inLoop=True) + a[i].orelse = copyPropagation(a[i].orelse, copy.deepcopy(liveVars), inLoop=True) + elif type(a[i]) == ast.While: + clearBlockVars(a[i], liveVars) + a[i].test = propagateValues(a[i].test, liveVars) + a[i].body = copyPropagation(a[i].body, copy.deepcopy(liveVars), inLoop=True) + a[i].orelse = copyPropagation(a[i].orelse, copy.deepcopy(liveVars), inLoop=True) + elif type(a[i]) == ast.If: + a[i].test = propagateValues(a[i].test, liveVars) + liveVars1 = copy.deepcopy(liveVars) + liveVars2 = copy.deepcopy(liveVars) + a[i].body = copyPropagation(a[i].body, liveVars1) + a[i].orelse = copyPropagation(a[i].orelse, liveVars2) + liveVars.clear() + # We can keep any values that occur in both + for key in liveVars1: + if key in liveVars2: + if compareASTs(liveVars1[key], liveVars2[key], checkEquality=True) == 0: + liveVars[key] = liveVars1[key] + # TODO: think more deeply about how this should work + elif type(a[i]) == ast.Try: + a[i].body = copyPropagation(a[i].body, liveVars) + for handler in a[i].handlers: + handler.body = copyPropagation(handler.body, liveVars) + a[i].orelse = copyPropagation(a[i].orelse, liveVars) + a[i].finalbody = copyPropagation(a[i].finalbody, liveVars) + elif type(a[i]) == ast.With: + a[i].body = copyPropagation(a[i].body, liveVars) + # With regular statements, just propagate the values + elif type(a[i]) in [ast.Return, ast.Delete, ast.Raise, ast.Assert, ast.Expr]: + propagateValues(a[i], liveVars) + # Breaks and Continues mess everything up + elif type(a[i]) in [ast.Break, ast.Continue]: + break + # These are not affected by this function + elif type(a[i]) in [ast.Import, ast.ImportFrom, ast.Global, ast.Pass]: + pass + else: + log("transformations\tcopyPropagation\tNot implemented: " + str(type(a[i])), "bug") + i += 1 + return a + else: + log("transformations\tcopyPropagation\tNot a list: " + str(type(a)), "bug") + return a + +def deadCodeRemoval(a, liveVars=None, keepPrints=True, inLoop=False): + """Remove any code which will not be reached or used.""" + """LiveVars keeps track of the variables that will be necessary""" + if liveVars == None: + liveVars = set() + if type(a) == ast.Module: + # Remove functions that will be overwritten anyway + namesSeen = [] + i = len(a.body) - 1 + while i >= 0: + if type(a.body[i]) == ast.FunctionDef: + if a.body[i].name in namesSeen: + # SPECIAL CHECK! Actually, the function will cause the code to crash if some of the args have the same name. Don't delete it then. + argNames = [] + for arg in a.body[i].args.args: + if arg.arg in argNames: + break + else: + argNames.append(arg.arg) + else: # only remove this if the args won't break it + a.body.pop(i) + else: + namesSeen.append(a.body[i].name) + elif type(a.body[i]) == ast.Assign: + namesSeen += gatherAssignedVars(a.body[i].targets) + i -= 1 + liveVars |= set(namesSeen) # make sure all global names are used! + + if type(a) in [ast.Module, ast.FunctionDef]: + if type(a) == ast.Module and len(a.body) == 0: + return a # just don't mess with it + gid = a.body[0].global_id if len(a.body) > 0 and hasattr(a.body[0], "global_id") else None + a.body = deadCodeRemoval(a.body, liveVars=liveVars, keepPrints=keepPrints, inLoop=inLoop) + if len(a.body) == 0: + a.body = [ast.Pass(removedLines=True)] if gid == None else [ast.Pass(removedLines=True, global_id=gid)] + return a + + if type(a) == list: + i = len(a) - 1 + while i >= 0 and len(a) > 0: + if i >= len(a): + i = len(a) - 1 # just in case + stmt = a[i] + t = type(stmt) + # TODO: get rid of these if they aren't live + if t in [ast.FunctionDef, ast.ClassDef]: + newLiveVars = set() + gid = a[i].body[0].global_id if len(a[i].body) > 0 and hasattr(a[i].body[0], "global_id") else None + a[i] = deadCodeRemoval(a[i], liveVars=newLiveVars, keepPrints=keepPrints, inLoop=inLoop) + liveVars |= newLiveVars + # Empty functions are useless! + if len(a[i].body) == 0: + a[i].body = [ast.Pass(removedLines=True)] if gid == None else [ast.Pass(removedLines=True, global_id=gid)] + elif t == ast.Return: + # Get rid of everything that happens after this! + a = a[:i+1] + # Replace the variables + liveVars.clear() + liveVars |= set(allVariableNamesUsed(stmt)) + elif t in [ast.Delete, ast.Assert]: + # Just add all variables used + liveVars |= set(allVariableNamesUsed(stmt)) + elif t == ast.Assign: + # Check to see if the names being assigned are in the set of live variables + allDead = True + allTargets = gatherAssignedVars(stmt.targets) + allNamesUsed = allVariableNamesUsed(stmt.value) + for target in allTargets: + if type(target) == ast.Name and (target.id in liveVars or target.id in allNamesUsed): + if target.id in liveVars: + liveVars.remove(target.id) + allDead = False + elif type(target) in [ast.Subscript, ast.Attribute]: + liveVars |= set(allVariableNamesUsed(target)) + allDead = False + # Also, check if the variable itself is contained in the value, because that can crash too + # If none are used, we can delete this line. Otherwise, use the value's vars + if allDead and (not couldCrash(stmt)) and (not containsTokenStepString(stmt)): + a.pop(i) + else: + liveVars |= set(allVariableNamesUsed(stmt.value)) + elif t == ast.AugAssign: + liveVars |= set(allVariableNamesUsed(stmt.target)) + liveVars |= set(allVariableNamesUsed(stmt.value)) + elif t == ast.For: + # If there is no use of break, there's no reason to use else with the loop, + # so move the lines outside and go over them separately + if len(stmt.orelse) > 0 and countOccurances(stmt, ast.Break) == 0: + lines = stmt.orelse + stmt.orelse = [] + a[i:i+1] = [stmt] + lines + i += len(lines) + continue # don't subtract one + + targetNames = [] + if type(a[i].target) == ast.Name: + targetNames = [a[i].target.id] + elif type(a[i].target) in [ast.Tuple, ast.List]: + for elt in a[i].target.elts: + if type(elt) == ast.Name: + targetNames.append(elt.id) + elif type(elt) == ast.Subscript: + if type(elt.value) == ast.Name: + targetNames.append(elt.value.id) + else: + log("transformations\tdeadCodeRemoval\tFor target subscript not a name: " + str(type(elt.value)) + "\t" + printFunction(elt.value), "bug") + else: + log("transformations\tdeadCodeRemoval\tFor target not a name: " + str(type(elt)) + "\t" + printFunction(elt), "bug") + elif type(a[i].target) == ast.Subscript: + if type(a[i].target.value) == ast.Name: + targetNames.append(a[i].target.value.id) + else: + log("transformations\tdeadCodeRemoval\tFor target subscript not a name: " + str(type(a[i].target.value)) + "\t" + printFunction(a[i].target.value), "bug") + else: + log("transformations\tdeadCodeRemoval\tFor target not a name: " + str(type(a[i].target)) + "\t" + printFunction(a[i].target), "bug") + + # We need to make ALL variables in the loop live, since they update continuously + liveVars |= set(allVariableNamesUsed(stmt)) + gid = stmt.body[0].global_id if len(stmt.body) > 0 and hasattr(stmt.body[0], "global_id") else None + stmt.body = deadCodeRemoval(stmt.body, copy.deepcopy(liveVars), keepPrints=keepPrints, inLoop=True) + stmt.orelse = deadCodeRemoval(stmt.orelse, copy.deepcopy(liveVars), keepPrints=keepPrints, inLoop=inLoop) + # If the body is empty and we don't need the target, get rid of it! + if len(stmt.body) == 0: + for name in targetNames: + if name in liveVars: + stmt.body = [ast.Pass(removedLines=True)] if gid == None else [ast.Pass(removedLines=True, global_id=gid)] + break + else: + if couldCrash(stmt.iter) or containsTokenStepString(stmt.iter): + a[i] = ast.Expr(stmt.iter, collapsedExpr=True) + else: + a.pop(i) + if len(stmt.orelse) > 0: + a[i:i+1] = a[i] + stmt.orelse + + # The names are wiped UPDATE - NOPE, what if we never enter the loop? + #for name in targetNames: + # liveVars.remove(name) + elif t == ast.While: + # If there is no use of break, there's no reason to use else with the loop, + # so move the lines outside and go over them separately + if len(stmt.orelse) > 0 and countOccurances(stmt, ast.Break) == 0: + lines = stmt.orelse + stmt.orelse = [] + a[i:i+1] = [stmt] + lines + i += len(lines) + continue + + # We need to make ALL variables in the loop live, since they update continuously + liveVars |= set(allVariableNamesUsed(stmt)) + old_global_id = stmt.body[0].global_id + stmt.body = deadCodeRemoval(stmt.body, copy.deepcopy(liveVars), keepPrints=keepPrints, inLoop=True) + stmt.orelse = deadCodeRemoval(stmt.orelse, copy.deepcopy(liveVars), keepPrints=keepPrints, inLoop=inLoop) + # If the body is empty, get rid of it! + if len(stmt.body) == 0: + stmt.body = [ast.Pass(removedLines=True, global_id=old_global_id)] + elif t == ast.If: + # First, if True/False, just replace it with the lines + test = a[i].test + if type(test) == ast.NameConstant and test.value in [True, False]: + assignedVars = getAllAssignedVars(a[i]) + for var in assignedVars: + # UNLESS we have a weird variable assignment problem + if var.id[0] == "g" and hasattr(var, "originalId"): + log("canonicalize\tdeadCodeRemoval\tWeird global variable: " + printFunction(a[i]), "bug") + break + else: + if test.value == True: + a[i:i+1] = a[i].body + else: + a[i:i+1] = a[i].orelse + continue + # For if statements, see if you can shorten things + liveVars1 = copy.deepcopy(liveVars) + liveVars2 = copy.deepcopy(liveVars) + stmt.body = deadCodeRemoval(stmt.body, liveVars1, keepPrints=keepPrints, inLoop=inLoop) + stmt.orelse = deadCodeRemoval(stmt.orelse, liveVars2, keepPrints=keepPrints, inLoop=inLoop) + liveVars.clear() + allVars = liveVars1 | liveVars2 | set(allVariableNamesUsed(stmt.test)) + liveVars |= allVars + if len(stmt.body) == 0 and len(stmt.orelse) == 0: + # Get rid of the if and keep going + if couldCrash(stmt.test) or containsTokenStepString(stmt.test): + newStmt = ast.Expr(stmt.test, collapsedExpr=True) + transferMetaData(stmt, newStmt) + a[i] = newStmt + else: + a.pop(i) + i -= 1 + continue + if len(stmt.body) == 0: + # If the body is empty, switch it with the else + stmt.test = deMorganize(ast.UnaryOp(ast.Not(addedNotOp=True), stmt.test, addedNot=True)) + (stmt.body, stmt.orelse) = (stmt.orelse, stmt.body) + if len(stmt.orelse) == 0: + # See if we can make the rest of the function the else statement + if type(stmt.body[-1]) == type(a[-1]) == ast.Return: + # If the if is larger than the rest, switch them! + if len(stmt.body) > len(a[i+1:]): + stmt.test = deMorganize(ast.UnaryOp(ast.Not(addedNotOp=True), stmt.test, addedNot=True)) + (a[i+1:], stmt.body) = (stmt.body, a[i+1:]) + else: + # Check to see if we should switch the if and else parts + if len(stmt.body) > len(stmt.orelse): + stmt.test = deMorganize(ast.UnaryOp(ast.Not(addedNotOp=True), stmt.test, addedNot=True)) + (stmt.body, stmt.orelse) = (stmt.orelse, stmt.body) + elif t == ast.Import: + if len(stmt.names) == 0: + a.pop(i) + elif t == ast.Global: + j = 0 + while j < len(a.names): + if a.names[j] not in liveVars: + a.names.pop(j) + else: + j += 1 + elif t == ast.Expr: + # Remove the line if it won't crash things. + if couldCrash(stmt) or containsTokenStepString(stmt): + liveVars |= set(allVariableNamesUsed(stmt)) + else: + # check whether any of these variables might crash the program + # I know, it's weird, but occasionally a student might use a var before defining it + allVars = allVariableNamesUsed(stmt) + for j in range(i): + if type(a[j]) == ast.Assign: + for id in gatherAssignedVarIds(a[j].targets): + if id in allVars: + allVars.remove(id) + if len(allVars) > 0: + liveVars |= set(allVariableNamesUsed(stmt)) + else: + a.pop(i) + # for now, just be careful with these types of statements + elif t in [ast.With, ast.Raise, ast.Try]: + liveVars |= set(allVariableNamesUsed(stmt)) + elif t == ast.Pass: + a.pop(i) # pass does *nothing* + elif t in [ast.Continue, ast.Break]: + if inLoop: # If we're in a loop, nothing that follows matters! Otherwise, leave it alone, this will just crash. + a = a[:i+1] + break + # We don't know what they're doing- leave it alone + elif t in [ast.ImportFrom]: + pass + # Have not yet implemented these + else: + log("transformations\tdeadCodeRemoval\tNot implemented: " + str(type(stmt)), "bug") + i -= 1 + return a + else: + log("transformations\tdeadCodeRemoval\tNot a list: " + str(a), "bug") + return a + +### ORDERING FUNCTIONS ### + +def getKeyDict(d, key): + if key not in d: + d[key] = { "self" : key } + return d[key] + +def traverseTrail(d, trail): + temp = d + for key in trail: + temp = temp[key] + return temp + +def areDisjoint(a, b): + """Are the sets of values that satisfy these two boolean constraints disjoint?""" + # The easiest way to be disjoint is to have comparisons that cover different areas + if type(a) == type(b) == ast.Compare: + aop = a.ops[0] + bop = b.ops[0] + aLeft = a.left + aRight = a.comparators[0] + bLeft = b.left + bRight = b.comparators[0] + alblComp = compareASTs(aLeft, bLeft, checkEquality=True) + albrComp = compareASTs(aLeft, bRight, checkEquality=True) + arblComp = compareASTs(aRight, bLeft, checkEquality=True) + arbrComp = compareASTs(aRight, bRight, checkEquality=True) + altype = type(aLeft) in [ast.Num, ast.Str] + artype = type(aRight) in [ast.Num, ast.Str] + bltype = type(bLeft) in [ast.Num, ast.Str] + brtype = type(bRight) in [ast.Num, ast.Str] + + if (type(aop) == ast.Eq and type(bop) == ast.NotEq) or \ + (type(bop) == ast.Eq and type(aop) == ast.NotEq): + # x == y, x != y + if (alblComp == 0 and arbrComp == 0) or (albrComp == 0 and arblComp == 0): + return True + elif type(aop) == type(bop) == ast.Eq: + if (alblComp == 0 and arbrComp == 0) or (albrComp == 0 and arblComp == 0): + return False + # x = num1, x = num2 + elif alblComp == 0 and artype and brtype: + return True + elif albrComp == 0 and artype and bltype: + return True + elif arblComp == 0 and altype and brtype: + return True + elif arbrComp == 0 and altype and bltype: + return True + elif (type(aop) == ast.Lt and type(bop) == ast.GtE) or \ + (type(aop) == ast.Gt and type(bop) == ast.LtE) or \ + (type(aop) == ast.LtE and type(bop) == ast.Gt) or \ + (type(aop) == ast.GtE and type(bop) == ast.Lt) or \ + (type(aop) == ast.Is and type(bop) == ast.IsNot) or \ + (type(aop) == ast.IsNot and type(bop) == ast.Is) or \ + (type(aop) == ast.In and type(bop) == ast.NotIn) or \ + (type(aop) == ast.NotIn and type(bop) == ast.In): + if alblComp == 0 and arbrComp == 0: + return True + elif (type(aop) == ast.Lt and type(bop) == ast.LtE) or \ + (type(aop) == ast.Gt and type(bop) == ast.GtE) or \ + (type(aop) == ast.LtE and type(bop) == ast.Lt) or \ + (type(aop) == ast.GtE and type(bop) == ast.Gt): + if albrComp == 0 and arblComp == 0: + return True + elif type(a) == type(b) == ast.BoolOp: + return False # for now- TODO: when is this not true? + elif type(a) == ast.UnaryOp and type(a.op) == ast.Not: + if compareASTs(a.operand, b, checkEquality=True) == 0: + return True + elif type(b) == ast.UnaryOp and type(b.op) == ast.Not: + if compareASTs(b.operand, a, checkEquality=True) == 0: + return True + return False + +def crashesOn(a): + """Determines where the expression might crash""" + # TODO: integrate typeCrashes + if not isinstance(a, ast.AST): + return [] + if type(a) == ast.BinOp: + l = eventualType(a.left) + r = eventualType(a.right) + if type(a.op) == ast.Add: + if not ((l == r == str) or (l in [int, float] and r in [int, float])): + return [a] + elif type(a.op) == ast.Mult: + if not ((l == str and r == int) or (l == int and r == str) or \ + (l in [int, float] and r in [int, float])): + return [a] + elif type(a.op) in [ast.Sub, ast.Pow, ast.LShift, ast.RShift, ast.BitOr, ast.BitXor, ast.BitAnd]: + if l not in [int, float] or r not in [int, float]: + return [a] + else: # ast.Div, ast.FloorDiv, ast.Mod + if (type(a.right) != ast.Num or a.right.n == 0) or \ + (l not in [int, float] or r not in [int, float]): + return [a] + elif type(a) == ast.UnaryOp: + if type(a.op) in [ast.UAdd, ast.USub]: + if eventualType(a.operand) not in [int, float]: + return [a] + elif type(a.op) == ast.Invert: + if eventualType(a.operand) != int: + return [a] + elif type(a) == ast.Compare: + if len(a.ops) != len(a.comparators): + return [a] + elif type(a.ops[0]) in [ast.In, ast.NotIn] and not isIterableType(eventualType(a.comparators[0])): + return [a] + elif type(a.ops[0]) in [ast.Lt, ast.LtE, ast.Gt, ast.GtE]: + # In Python3, you can't compare different types. BOOOOOO!! + firstType = eventualType(a.left) + if firstType == None: + return [a] + for comp in a.comparators: + if eventualType(comp) != firstType: + return [a] + elif type(a) == ast.Call: + env = [] # TODO: what if the environments aren't imported? + funMaps = { "math" : mathFunctions, "string" : builtInStringFunctions } + safeFunMaps = { "math" : safeMathFunctions, "string" : safeStringFunctions } + if type(a.func) == ast.Name: + funDict = builtInFunctions + safeFuns = builtInSafeFunctions + funName = a.func.id + elif type(a.func) == ast.Attribute: + if type(a.func.value) == ast.Name and a.func.value.id in funMaps: + funDict = funMaps[a.func.value.id] + safeFuns = safeFunMaps[a.func.value.id] + funName = a.func.attr + elif eventualType(a.func.value) == str: + funDict = funMaps["string"] + safeFuns = safeFunMaps["string"] + funName = a.func.attr + else: # including list and dict + return [a] + else: + return [a] + + runOnce = 0 # So we can break + while (runOnce == 0): + if funName in safeFuns: + argTypes = [] + for i in range(len(a.args)): + eventual = eventualType(a.args[i]) + if eventual == None: + return [a] + argTypes.append(eventual) + + if funName in ["max", "min"]: + break # Special functions + + for key in funDict[funName]: + if len(key) != len(argTypes): + continue + for i in range(len(key)): + if not (key[i] == argTypes[i] or issubclass(argTypes[i], key[i])): + break + else: + break + else: + return [a] + break # found one that works + else: + return [a] + elif type(a) == ast.Subscript: + if eventualType(a.value) not in [str, list, tuple]: + return [a] + elif type(a) == ast.Name: + # If it's an undefined variable, it might crash + if hasattr(a, "randomVar"): + return [a] + elif type(a) == ast.Slice: + if a.lower != None and eventualType(a.lower) != int: + return [a] + if a.upper != None and eventualType(a.upper) != int: + return [a] + if a.step != None and eventualType(a.step) != int: + return [a] + elif type(a) in [ast.Assert, ast.Import, ast.ImportFrom, ast.Attribute, ast.Index]: + return [a] + + allCrashes = [] + for child in ast.iter_child_nodes(a): + allCrashes += crashesOn(child) + return allCrashes + +def isNegation(a, b): + """Is a the negation of b?""" + return compareASTs(deMorganize(ast.UnaryOp(ast.Not(), deepcopy(a))), b, checkEquality=True) == 0 + +def reverse(op): + """Reverse the direction of the comparison for normalization purposes""" + rev = not op.reversed if hasattr(op, "reversed") else True + if type(op) == ast.Gt: + newOp = ast.Lt() + transferMetaData(op, newOp) + newOp.reversed = rev + return newOp + elif type(op) == ast.GtE: + newOp = ast.LtE() + transferMetaData(op, newOp) + newOp.reversed = rev + return newOp + else: + return op # Do not change! + +def orderCommutativeOperations(a): + """Order all expressions that are in commutative operations""" + """TODO: add commutative function lines?""" + if not isinstance(a, ast.AST): + return a + # If branches can be commutative as long as their tests are disjoint + if type(a) == ast.If: + a = applyToChildren(a, orderCommutativeOperations) + # If the else is (strictly) shorter than the body, switch them + if len(a.orelse) != 0 and len(a.body) > len(a.orelse): + newTest = ast.UnaryOp(ast.Not(addedNotOp=True), a.test) + transferMetaData(a.test, newTest) + newTest.negated = True + newTest = deMorganize(newTest) + a.test = newTest + (a.body,a.orelse) = (a.orelse,a.body) + + # Then collect all the branches. The leftover orelse is the final else + branches = [(a.test, a.body, a.global_id)] + orElse = a.orelse + while len(orElse) == 1 and type(orElse[0]) == ast.If: + branches.append((orElse[0].test, orElse[0].body, orElse[0].global_id)) + orElse = orElse[0].orelse + + # If we have branches to order... + if len(branches) != 1: + # Sort the branches based on their tests + # We have to sort carefully because of the possibility for crashing + isSorted = False + while not isSorted: + isSorted = True + for i in range(len(branches)-1): + # First, do we even want to swap these two? + # Branch tests MUST be disjoint to be swapped- otherwise, we break semantics + if areDisjoint(branches[i][0], branches[i+1][0]) and \ + compareASTs(branches[i][0], branches[i+1][0]) > 0: + if not (couldCrash(branches[i][0]) or couldCrash(branches[i+1][0])): + (branches[i],branches[i+1]) = (branches[i+1],branches[i]) + isSorted = False + # Two values can be swapped if they crash on the SAME thing + elif couldCrash(branches[i][0]) and couldCrash(branches[i+1][0]): + # Check to see if they crash on the same things + l1 = sorted(crashesOn(branches[i][0]), key=functools.cmp_to_key(compareASTs)) + l2 = sorted(crashesOn(branches[i+1][0]), key=functools.cmp_to_key(compareASTs)) + if compareASTs(l1, l2, checkEquality=True) == 0: + (branches[i],branches[i+1]) = (branches[i+1],branches[i]) + isSorted = False + # Do our last two branches nicely form an if/else already? + if len(orElse) == 0 and isNegation(branches[-1][0], branches[-2][0]): + starter = branches[-1][1] # skip the if + else: + starter = [ast.If(branches[-1][0], branches[-1][1], orElse, global_id=branches[-1][2])] + # Create the new conditional tree + for i in range(len(branches)-2, -1, -1): + starter = [ast.If(branches[i][0], branches[i][1], starter, global_id=branches[i][2])] + a = starter[0] + return a + elif type(a) == ast.BoolOp: + # If all the values are booleans and won't crash, we can sort them + canSort = True + for i in range(len(a.values)): + a.values[i] = orderCommutativeOperations(a.values[i]) + if couldCrash(a.values[i]) or eventualType(a.values[i]) != bool or containsTokenStepString(a.values[i]): + canSort = False + + if canSort: + a.values = sorted(a.values, key=functools.cmp_to_key(compareASTs)) + else: + # Even if there are some problems, we can partially sort. See above + isSorted = False + while not isSorted: + isSorted = True + for i in range(len(a.values)-1): + if compareASTs(a.values[i], a.values[i+1]) > 0 and \ + eventualType(a.values[i]) == bool and eventualType(a.values[i+1]) == bool: + if not (couldCrash(a.values[i]) or couldCrash(a.values[i+1])): + (a.values[i],a.values[i+1]) = (a.values[i+1],a.values[i]) + isSorted = False + # Two values can also be swapped if they crash on the SAME thing + elif couldCrash(a.values[i]) and couldCrash(a.values[i+1]): + # Check to see if they crash on the same things + l1 = sorted(crashesOn(a.values[i]), key=functools.cmp_to_key(compareASTs)) + l2 = sorted(crashesOn(a.values[i+1]), key=functools.cmp_to_key(compareASTs)) + if compareASTs(l1, l2, checkEquality=True) == 0: + (a.values[i],a.values[i+1]) = (a.values[i+1],a.values[i]) + isSorted = False + return a + elif type(a) == ast.BinOp: + top = type(a.op) + l = a.left = orderCommutativeOperations(a.left) + r = a.right = orderCommutativeOperations(a.right) + + # Don't reorder if we're currently walking through hint steps + if containsTokenStepString(l) or containsTokenStepString(r): + return a + + # TODO: what about possible crashes? + # Certain operands are commutative + if (top in [ast.Mult, ast.BitOr, ast.BitXor, ast.BitAnd]) or \ + ((top == ast.Add) and ((eventualType(l) in [int, float, bool]) or \ + (eventualType(r) in [int, float, bool]))): + # Break the chain of binary operations into a list of the + # operands over the same op, then sort the operands + operands = [[l, a.op], [r, None]] + changeMade = True + i = 0 + while i < len(operands): + [operand, op] = operands[i] + if type(operand) == ast.BinOp and type(operand.op) == top: + operands[i:i+1] = [[operand.left, operand.op], [operand.right, op]] + else: + i += 1 + operands = sorted(operands, key=functools.cmp_to_key(lambda x,y : compareASTs(x[0], y[0]))) + for i in range(len(operands)-1): # push all the ops forward + if operands[i][1] == None: + operands[i][1] = operands[i+1][1] + operands[i+1][1] = None + # Then put them back into a single expression, descending to the left + left = operands[0][0] + for i in range(1, len(operands)): + left = ast.BinOp(left, operands[i-1][1], operands[i][0], orderedBinOp=True) + transferMetaData(a, left) + return left + elif top == ast.Add: + # This might be concatenation, not addition + if type( r ) == ast.BinOp and type(r.op) == top: + # We want the operators to descend to the left + a.left = orderCommutativeOperations(ast.BinOp(l, r.op, r.left, global_id=r.global_id)) + a.right = r.right + return a + elif type(a) == ast.Dict: + for i in range(len(a.keys)): + a.keys[i] = orderCommutativeOperations(a.keys[i]) + a.values[i] = orderCommutativeOperations(a.values[i]) + + pairs = list(zip(a.keys, a.values)) + pairs.sort(key=functools.cmp_to_key(lambda x,y : compareASTs(x[0],y[0]))) # sort by keys + k, v = zip(*pairs) if len(pairs) > 0 else ([], []) + a.keys = list(k) + a.values = list(v) + return a + elif type(a) == ast.Compare: + l = a.left = orderCommutativeOperations(a.left) + r = orderCommutativeOperations(a.comparators[0]) + a.comparators = [r] + + # Don't reorder when we're doing hint steps + if containsTokenStepString(l) or containsTokenStepString(r): + return a + + if (type(a.ops[0]) in [ast.Eq, ast.NotEq]): + # Equals and not-equals are commutative + if compareASTs(l, r) > 0: + a.left, a.comparators[0] = a.comparators[0], a.left + elif (type(a.ops[0]) in [ast.Gt, ast.GtE]): + # We'll always use < and <=, just so everything's the same + a.ops = [reverse(a.ops[0])] + a.left, a.comparators[0] = a.comparators[0], a.left + elif (type(a.ops[0]) in [ast.In, ast.NotIn]): + if type(r) == ast.List: + # If it's a list of items, sort the list + # TODO: should we implement crashable sorting here? + for i in range(len(r.elts)): + if couldCrash(r.elts[i]): + break # don't sort if there'a a crash! + else: + r.elts = sorted(r.elts, key=functools.cmp_to_key(compareASTs)) + # Then remove duplicates + i = 0 + while i < len(r.elts) - 1: + if compareASTs(r.elts[i], r.elts[i+1], checkEquality=True) == 0: + r.elts.pop(i+1) + else: + i += 1 + return a + elif type(a) == ast.Call: + if type(a.func) == ast.Name: + # These functions are commutative and show up a lot + if a.func.id in ["min", "max"]: + crashable = False + for i in range(len(a.args)): + a.args[i] = orderCommutativeOperations(a.args[i]) + if couldCrash(a.args[i]) or containsTokenStepString(a.args[i]): + crashable = True + # TODO: crashable sorting here? + if not crashable: + a.args = sorted(a.args, key=functools.cmp_to_key(compareASTs)) + return a + return applyToChildren(a, orderCommutativeOperations) + +def deMorganize(a): + """Apply De Morgan's law throughout the code in order to canonicalize""" + if not isinstance(a, ast.AST): + return a + # We only care about statements beginning with not + if type(a) == ast.UnaryOp and type(a.op) == ast.Not: + oper = a.operand + top = type(oper) + + # not (blah and gah) == (not blah or not gah) + if top == ast.BoolOp: + oper.op = negate(oper.op) + for i in range(len(oper.values)): + oper.values[i] = deMorganize(negate(oper.values[i])) + oper.negated = not oper.negated if hasattr(oper, "negated") else True + transferMetaData(a, oper) + return oper + # not a < b == a >= b + elif top == ast.Compare: + oper.left = deMorganize(oper.left) + oper.ops = [negate(oper.ops[0])] + oper.comparators = [deMorganize(oper.comparators[0])] + oper.negated = not oper.negated if hasattr(oper, "negated") else True + transferMetaData(a, oper) + return oper + # not not blah == blah + elif top == ast.UnaryOp and type(oper.op) == ast.Not: + oper.operand = deMorganize(oper.operand) + if eventualType(oper.operand) != bool: + return a + oper.operand.negated = not oper.operand.negated if hasattr(oper.operand, "negated") else True + return oper.operand + elif top == ast.NameConstant: + if oper.value in [True, False]: + oper = negate(oper) + transferMetaData(a, oper) + return oper + elif oper.value == None: + tmp = ast.NameConstant(True) + transferMetaData(a, tmp) + tmp.negated = True + return tmp + else: + log("Unknown NameConstant: " + str(oper.value), "bug") + + return applyToChildren(a, deMorganize) + +##### CLEANUP FUNCTIONS ##### + +def cleanupEquals(a): + """Gets rid of silly blah == True statements that students make""" + if not isinstance(a, ast.AST): + return a + if type(a) == ast.Call: + a.func = cleanupEquals(a.func) + for i in range(len(a.args)): + # But test expressions don't carry through to function arguments + a.args[i] = cleanupEquals(a.args[i]) + return a + elif type(a) == ast.Compare and type(a.ops[0]) in [ast.Eq, ast.NotEq]: + l = a.left = cleanupEquals(a.left) + r = cleanupEquals(a.comparators[0]) + a.comparators = [r] + if type(l) == ast.NameConstant and l.value in [True, False]: + (l,r) = (r,l) + # If we have (boolean expression) == True + if type(r) == ast.NameConstant and r.value in [True, False] and (eventualType(l) == bool): + # Matching types + if (type(a.ops[0]) == ast.Eq and r.value == True) or \ + (type(a.ops[0]) == ast.NotEq and r.value == False): + transferMetaData(a, l) # make sure to keep the original location + return l + else: + tmp = ast.UnaryOp(ast.Not(addedNotOp=True), l) + transferMetaData(a, tmp) + return tmp + else: + return a + else: + return applyToChildren(a, cleanupEquals) + +def cleanupBoolOps(a): + """When possible, combine adjacent boolean expressions""" + """Note- we are assuming that all ops are the first op (as is done in the simplify function)""" + if not isinstance(a, ast.AST): + return a + if type(a) == ast.BoolOp: + allTypesWork = True + for i in range(len(a.values)): + a.values[i] = cleanupBoolOps(a.values[i]) + if eventualType(a.values[i]) != bool or hasattr(a.values[i], "multiComp"): + allTypesWork = False + + # We can't reduce if the types aren't all booleans + if not allTypesWork: + return a + + i = 0 + while i < len(a.values) - 1: + current = a.values[i] + next = a.values[i+1] + # (a and b and c and d) or (a and e and d) == a and ((b and c) or e) and d + if type(current) == type(next) == ast.BoolOp: + if type(current.op) == type(next.op): + minlength = min(len(current.values), len(next.values)) # shortest length + + # First, check for all identical values from the front + j = 0 + while j < minlength: + if compareASTs(current.values[j], next.values[j], checkEquality=True) != 0: + break + j += 1 + + # Same values in both, so get rid of the latter line + if j == len(current.values) == len(next.values): + a.values.pop(i+1) + continue + i += 1 + ### If reduced to one item, just return that item + return a.values[0] if (len(a.values) == 1) else a + return applyToChildren(a, cleanupBoolOps) + +def cleanupRanges(a): + """Remove any range shenanigans, because Python lets you include unneccessary values""" + if not isinstance(a, ast.AST): + return a + if type(a) == ast.Call: + if type(a.func) == ast.Name: + if a.func.id in ["range"]: + if len(a.args) == 3: + # The step defaults to 1! + if type(a.args[2]) == ast.Num and a.args[2].n == 1: + a.args = a.args[:-1] + if len(a.args) == 2: + # The start defaults to 0! + if type(a.args[0]) == ast.Num and a.args[0].n == 0: + a.args = a.args[1:] + return applyToChildren(a, cleanupRanges) + +def cleanupSlices(a): + """Remove any slice shenanigans, because Python lets you include unneccessary values""" + if not isinstance(a, ast.AST): + return a + if type(a) == ast.Subscript: + if type(a.slice) == ast.Slice: + # Lower defaults to 0 + if a.slice.lower != None and type(a.slice.lower) == ast.Num and a.slice.lower.n == 0: + a.slice.lower = None + # Upper defaults to len(value) + if a.slice.upper != None and type(a.slice.upper) == ast.Call and \ + type(a.slice.upper.func) == ast.Name and a.slice.upper.func.id == "len": + if compareASTs(a.value, a.slice.upper.args[0], checkEquality=True) == 0: + a.slice.upper = None + # Step defaults to 1 + if a.slice.step != None and type(a.slice.step) == ast.Num and a.slice.step.n == 1: + a.slice.step = None + return applyToChildren(a, cleanupSlices) + +def cleanupTypes(a): + """Remove any unneccessary type mappings""" + if not isinstance(a, ast.AST): + return a + # No need to cast something if it'll be changed anyway by a binary operation + if type(a) == ast.BinOp: + a.left = cleanupTypes(a.left) + a.right = cleanupTypes(a.right) + # Ints become floats naturally + if eventualType(a.left) == eventualType(a.right) == float: + if type(a.right) == ast.Call and type(a.right.func) == ast.Name and \ + a.right.func.id == "float" and len(a.right.args) == 1 and len(a.right.keywords) == 0 and \ + eventualType(a.right.args[0]) in [int, float]: + a.right = a.right.args[0] + elif type(a.left) == ast.Call and type(a.left.func) == ast.Name and \ + a.left.func.id == "float" and len(a.left.args) == 1 and len(a.left.keywords) == 0 and \ + eventualType(a.left.args[0]) in [int, float]: + a.left = a.left.args[0] + return a + elif type(a) == ast.Call and type(a.func) == ast.Name and len(a.args) == 1 and len(a.keywords) == 0: + a.func = cleanupTypes(a.func) + a.args = [cleanupTypes(a.args[0])] + # If the type already matches, no need to cast it + funName = a.func.id + argType = eventualType(a.args[0]) + if type(a.func) == ast.Name: + if (funName == "float" and argType == float) or \ + (funName == "int" and argType == int) or \ + (funName == "bool" and argType == bool) or \ + (funName == "str" and argType == str): + return a.args[0] + return applyToChildren(a, cleanupTypes) + +def turnPositive(a): + """Take a negative number and make it positive""" + if type(a) == ast.UnaryOp and type(a.op) == ast.USub: + return a.operand + elif type(a) == ast.Num and type(a.n) != complex and a.n < 0: + a.n = a.n * -1 + return a + else: + log("transformations\tturnPositive\tFailure: " + str(a), "bug") + return a + +def isNegative(a): + """Is the give number negative?""" + if type(a) == ast.UnaryOp and type(a.op) == ast.USub: + return True + elif type(a) == ast.Num and type(a.n) != complex and a.n < 0: + return True + else: + return False + +def cleanupNegations(a): + """Remove unneccessary negations""" + if not isinstance(a, ast.AST): + return a + elif type(a) == ast.BinOp: + a.left = cleanupNegations(a.left) + a.right = cleanupNegations(a.right) + + if type(a.op) == ast.Add: + # x + (-y) + if isNegative(a.right): + a.right = turnPositive(a.right) + a.op = ast.Sub(global_id=a.op.global_id, num_negated=True) + return a + # (-x) + y + elif isNegative(a.left): + if couldCrash(a.left) and couldCrash(a.right): + return a # can't switch if it'll change the message + else: + (a.left,a.right) = (a.right,turnPositive(a.left)) + a.op = ast.Sub(global_id=a.op.global_id, num_negated=True) + return a + elif type(a.op) == ast.Sub: + # x - (-y) + if isNegative(a.right): + a.right = turnPositive(a.right) + a.op = ast.Add(global_id=a.op.global_id, num_negated=True) + return a + elif type(a.right) == ast.BinOp: + # x - (y + z) = x + (-y - z) + if type(a.right.op) == ast.Add: + a.right.left = cleanupNegations(ast.UnaryOp(ast.USub(addedOtherOp=True), a.right.left, addedOther=True)) + a.right.op = ast.Sub(global_id=a.right.op.global_id, num_negated=True) + a.op = ast.Add(global_id=a.op.global_id, num_negated=True) + return a + # x - (y - z) = x + (-y + z) = x + (z - y) + elif type(a.right.op) == ast.Sub: + if couldCrash(a.right.left) and couldCrash(a.right.right): + a.right.left = cleanupNegations(ast.UnaryOp(ast.USub(addedOtherOp=True), a.right.left, addedOther=True)) + a.right.op = ast.Add(global_id=a.right.op.global_id, num_negated=True) + a.op = ast.Add(global_id=a.op.global_id, num_negated=True) + return a + else: + (a.right.left, a.right.right) = (a.right.right, a.right.left) + a.op = ast.Add(global_id=a.op.global_id, num_negated=True) + return a + # Move negations to the outer part of multiplications + elif type(a.op) == ast.Mult: + # -x * -y + if isNegative(a.left) and isNegative(a.right): + a.left = turnPositive(a.left) + a.right = turnPositive(a.right) + return a + # -x * y = -(x*y) + elif isNegative(a.left): + if eventualType(a.right) in [int, float]: + a.left = turnPositive(a.left) + return cleanupNegations(ast.UnaryOp(ast.USub(addedOtherOp=True), a, addedOther=True)) + # x * -y = -(x*y) + elif isNegative(a.right): + if eventualType(a.left) in [int, float]: + a.right = turnPositive(a.right) + return cleanupNegations(ast.UnaryOp(ast.USub(addedOtherOp=True), a, addedOther=True)) + elif type(a.op) in [ast.Div, ast.FloorDiv]: + if isNegative(a.left) and isNegative(a.right): + a.left = turnPositive(a.left) + a.right = turnPositive(a.right) + return a + return a + elif type(a) == ast.UnaryOp: + a.operand = cleanupNegations(a.operand) + if type(a.op) == ast.USub: + if type(a.operand) == ast.BinOp: + # -(x + y) = -x - y + if type(a.operand.op) == ast.Add: + a.operand.left = cleanupNegations(ast.UnaryOp(ast.USub(addedOtherOp=True), a.operand.left, addedOther=True)) + a.operand.op = ast.Sub(global_id=a.operand.op.global_id, num_negated=True) + transferMetaData(a, a.operand) + return a.operand + # -(x - y) = -x + y = y - x + elif type(a.operand.op) == ast.Sub: + if couldCrash(a.operand.left) and couldCrash(a.operand.right): + a.operand.left = cleanupNegations(ast.UnaryOp(ast.USub(addedOtherOp=True), a.operand.left, addedOther=True)) + a.operand.op = ast.Add(global_id=a.operand.op.global_id, num_negated=True) + transferMetaData(a, a.operand) + return a.operand + else: + (a.operand.left,a.operand.right) = (a.operand.right,a.operand.left) + transferMetaData(a, a.operand) + return a.operand + return a + # Special case for absolute value + elif type(a) == ast.Call and type(a.func) == ast.Name and a.func.id == "abs" and len(a.args) == 1: + a.args[0] = cleanupNegations(a.args[0]) + if type(a.args[0]) == ast.UnaryOp and type(a.args[0].op) == ast.USub: + a.args[0] = a.args[0].operand + elif type(a.args[0]) == ast.BinOp and type(a.args[0].op) == ast.Sub: + if not (couldCrash(a.args[0].left) and couldCrash(a.args[0].right)) and \ + compareASTs(a.args[0].left, a.args[0].right) > 0: + (a.args[0].left,a.args[0].right) = (a.args[0].right,a.args[0].left) + return a + else: + return applyToChildren(a, cleanupNegations) + +### CONDITIONAL TRANSFORMATIONS ### + +def combineConditionals(a): + """When possible, combine conditional branches""" + if not isinstance(a, ast.AST): + return a + elif type(a) == ast.If: + for i in range(len(a.body)): + a.body[i] = combineConditionals(a.body[i]) + for i in range(len(a.orelse)): + a.orelse[i] = combineConditionals(a.orelse[i]) + + # if a: if b: x can be - if a and b: x + if (len(a.orelse) == 0) and (len(a.body) == 1) and \ + (type(a.body[0]) == ast.If) and (len(a.body[0].orelse) == 0): + a.test = ast.BoolOp(ast.And(combinedConditionalOp=True), [a.test, a.body[0].test], combinedConditional=True) + a.body = a.body[0].body + # if a: x elif b: x can be - if a or b: x + elif (len(a.orelse) == 1) and \ + (type(a.orelse[0]) == ast.If) and (len(a.orelse[0].orelse) == 0): + if compareASTs(a.body, a.orelse[0].body, checkEquality=True) == 0: + a.test = ast.BoolOp(ast.Or(combinedConditionalOp=True), [a.test, a.orelse[0].test], combinedConditional=True) + a.orelse = [] + return a + else: + return applyToChildren(a, combineConditionals) + +def staticVars(l, vars): + """Determines whether the given lines change the given variables""" + # First, if one of the variables can be modified, there might be a problem + mutableVars = [] + for var in vars: + if (not (hasattr(var, "type") and (var.type in [int, float, str, bool]))): + mutableVars.append(var) + + for i in range(len(l)): + if type(l[i]) == ast.Assign: + for var in vars: + if var.id in allVariableNamesUsed(l[i].targets[0]): + return False + elif type(l[i]) == ast.AugAssign: + for var in vars: + if var.id in allVariableNamesUsed(l[i].target): + return False + elif type(l[i]) in [ast.If, ast.While]: + if not (staticVars(l[i].body, vars) and staticVars(l[i].orelse, vars)): + return False + elif type(l[i]) == ast.For: + for var in vars: + if var.id in allVariableNamesUsed(l[i].target): + return False + if not (staticVars(l[i].body, vars) and staticVars(l[i].orelse, vars)): + return False + elif type(l[i]) in [ast.FunctionDef, ast.ClassDef, ast.Try, ast.With]: + log("transformations\tstaticVars\tMissing type: " + str(type(l[i])), "bug") + + # If a mutable variable is used, we can't trust it + for var in mutableVars: + if var.id in allVariableNamesUsed(l[i]): + return False + return True + +def getIfBranches(a): + """Gets all the branches of an if statement. Will only work if each else has a single line""" + if type(a) != ast.If: + return None + + if len(a.orelse) == 0: + return [a] + elif len(a.orelse) == 1: + tmp = getIfBranches(a.orelse[0]) + if tmp == None: + return None + return [a] + tmp + else: + return None + +def allVariablesUsed(a): + """Gathers all the variable names used in the ast""" + if type(a) == list: + variables = [] + for x in a: + variables += allVariablesUsed(x) + return variables + + if not isinstance(a, ast.AST): + return [] + elif type(a) == ast.Name: + return [a] + elif type(a) == ast.Assign: + variables = allVariablesUsed(a.value) + for target in a.targets: + if type(target) == ast.Name: + pass + elif type(target) in [ast.Tuple, ast.List]: + for elt in target.elts: + if type(elt) == ast.Name: + pass + else: + variables += allVariablesUsed(elt) + else: + variables += allVariablesUsed(target) + return variables + else: + variables = [] + for child in ast.iter_child_nodes(a): + variables += allVariablesUsed(child) + return variables + +def conditionalRedundancy(a): + """When possible, remove redundant lines from conditionals and combine conditionals.""" + if type(a) == ast.Module: + for i in range(len(a.body)): + if type(a.body[i]) == ast.FunctionDef: + a.body[i] = conditionalRedundancy(a.body[i]) + return a + elif type(a) == ast.FunctionDef: + a.body = conditionalRedundancy(a.body) + return a + + if type(a) == list: + i = 0 + while i < len(a): + stmt = a[i] + if type(stmt) == ast.If: + stmt.body = conditionalRedundancy(stmt.body) + stmt.orelse = conditionalRedundancy(stmt.orelse) + + # If a line appears in both, move it outside the conditionals + if len(stmt.body) > 0 and len(stmt.orelse) > 0 and compareASTs(stmt.body[-1], stmt.orelse[-1], checkEquality=True) == 0: + nextLine = stmt.body[-1] + nextLine.second_global_id = stmt.orelse[-1].global_id + stmt.body = stmt.body[:-1] + stmt.orelse = stmt.orelse[:-1] + stmt.moved_line = nextLine.global_id + # Remove the if statement if both if and else are empty + if len(stmt.body) == 0 and len(stmt.orelse) == 0: + newLine = ast.Expr(stmt.test) + transferMetaData(stmt, newLine) + a[i:i+1] = [newLine, nextLine] + # Switch if and else if if is empty + elif len(stmt.body) == 0: + stmt.test = ast.UnaryOp(ast.Not(addedNotOp=True), stmt.test, addedNot=True) + stmt.body = stmt.orelse + stmt.orelse = [] + a[i:i+1] = [stmt, nextLine] + else: + a[i:i+1] = [stmt, nextLine] + continue # skip incrementing so that we check the conditional again + # Join adjacent, disjoint ifs + elif i+1 < len(a) and type(a[i+1]) == ast.If: + branches1 = getIfBranches(stmt) + branches2 = getIfBranches(a[i+1]) + if branches1 != None and branches2 != None: + # First, check whether any vars used in the second set of branches will be changed by the first set + testVars = sum(map(lambda b : allVariablesUsed(b.test), branches2), []) + for branch in branches1: + if not staticVars(branch.body, testVars): + break + else: + branchCombos = [(x,y) for y in branches2 for x in branches1] + for (branch1,branch2) in branchCombos: + if not areDisjoint(branch1.test, branch2.test): + break + else: + # We know the last else branch is empty- fill it with the next tree! + branches1[-1].orelse = [a[i+1]] + a.pop(i+1) + continue # check this conditional again + elif type(stmt) == ast.FunctionDef: + stmt.body = conditionalRedundancy(stmt.body) + elif type(stmt) in [ast.While, ast.For]: + stmt.body = conditionalRedundancy(stmt.body) + stmt.orelse = conditionalRedundancy(stmt.orelse) + elif type(stmt) == ast.ClassDef: + for x in range(len(stmt.body)): + if type(stmt.body[x]) == ast.FunctionDef: + stmt.body[x] = conditionalRedundancy(stmt.body[x]) + elif type(stmt) == ast.Try: + stmt.body = conditionalRedundancy(stmt.body) + for x in range(len(stmt.handlers)): + stmt.handlers[x].body = conditionalRedundancy(stmt.handlers[x].body) + stmt.orelse = conditionalRedundancy(stmt.orelse) + stmt.finalbody = conditionalRedundancy(stmt.finalbody) + elif type(stmt) == ast.With: + stmt.body = conditionalRedundancy(stmt.body) + else: + pass + i += 1 + return a + else: + log("transformations\tconditionalRedundancy\tStrange type: " + str(type(a)), "bug") + +def collapseConditionals(a): + """When possible, combine adjacent conditionals""" + if type(a) == ast.Module: + for i in range(len(a.body)): + if type(a.body[i]) == ast.FunctionDef: + a.body[i] = collapseConditionals(a.body[i]) + return a + elif type(a) == ast.FunctionDef: + a.body = collapseConditionals(a.body) + return a + + if type(a) == list: + l = a + i = len(l) - 1 + + # Go through the lines backwards, since we're collapsing conditionals upwards + while i >= 0: + stmt = l[i] + if type(stmt) == ast.If: + stmt.body = collapseConditionals(stmt.body) + stmt.orelse = collapseConditionals(stmt.orelse) + + # First, check to see if we can collapse across the if and its else + if len(l[i].body) == 1 and len(l[i].orelse) == 1: + ifLine = l[i].body[0] + elseLine = l[i].orelse[0] + + # This only works for Assign and Return + if type(ifLine) == type(elseLine) == ast.Assign and \ + compareASTs(ifLine.targets, elseLine.targets, checkEquality=True) == 0: + pass + elif type(ifLine) == ast.Return and type(elseLine) == ast.Return: + pass + else: + i -= 1 + continue # skip past this + + if type(ifLine.value) == type(elseLine.value) == ast.Name and \ + ifLine.value.id in ['True', 'False'] and elseLine.value.id in ['True', 'False']: + if ifLine.value.id == elseLine.value.id: + # If they both return the same thing, just replace the if altogether. + # But keep the test in case it crashes- we'll remove it later + ifLine.global_id = None # we're replacing the whole if statement + l[i:i+1] = [ast.Expr(l[i].test, addedOther=True, moved_line=ifLine.global_id), ifLine] + elif eventualType(l[i].test) == bool: + testVal = l[i].test + if ifLine.value.id == 'True': + newVal = testVal + else: + newVal = ast.UnaryOp(ast.Not(addedNotOp=True), testVal, negated=True, collapsedExpr=True) + + if type(ifLine) == ast.Assign: + newLine = ast.Assign(ifLine.targets, newVal) + else: + newLine = ast.Return(newVal) + transferMetaData(l[i], newLine) + l[i] = newLine + # Next, check to see if we can collapse across the if and surrounding lines + elif len(l[i].body) == 1 and len(l[i].orelse) == 0: + ifLine = l[i].body[0] + # First, check to see if the current and prior have the same return bodies + if i != 0 and type(l[i-1]) == ast.If and \ + len(l[i-1].body) == 1 and len(l[i-1].orelse) == 0 and \ + type(ifLine) == ast.Return and compareASTs(ifLine, l[i-1].body[0], checkEquality=True) == 0: + # If they do, combine their tests with an Or and get rid of this line + l[i-1].test = ast.BoolOp(ast.Or(combinedConditionalOp=True), [l[i-1].test, l[i].test], combinedConditional=True) + l[i-1].second_global_id = l[i].global_id + l.pop(i) + # Then, check whether the current and latter lines have the same returns + elif i != len(l) - 1 and type(ifLine) == type(l[i+1]) == ast.Return and \ + type(ifLine.value) == type(l[i+1].value) == ast.Name and \ + ifLine.value.id in ['True', 'False'] and l[i+1].value.id in ['True', 'False']: + if ifLine.value.id == l[i+1].value.id: + # No point in keeping the if line- just use the return + l[i] = ast.Expr(l[i].test, addedOther=True) + else: + if eventualType(l[i].test) == bool: + testVal = l[i].test + if ifLine.value.id == 'True': + newLine = ast.Return(testVal) + else: + newLine = ast.Return(ast.UnaryOp(ast.Not(addedNotOp=True), testVal, negated=True, collapsedExpr=True)) + transferMetaData(l[i], newLine) + l[i] = newLine + l.pop(i+1) # get rid of the extra return + elif type(stmt) == ast.FunctionDef: + stmt.body = collapseConditionals(stmt.body) + elif type(stmt) in [ast.While, ast.For]: + stmt.body = collapseConditionals(stmt.body) + stmt.orelse = collapseConditionals(stmt.orelse) + elif type(stmt) == ast.ClassDef: + for i in range(len(stmt.body)): + if type(stmt.body[i]) == ast.FunctionDef: + stmt.body[i] = collapseConditionals(stmt.body[i]) + elif type(stmt) == ast.Try: + stmt.body = collapseConditionals(stmt.body) + for i in range(len(stmt.handlers)): + stmt.handlers[i].body = collapseConditionals(stmt.handlers[i].body) + stmt.orelse = collapseConditionals(stmt.orelse) + stmt.finalbody = collapseConditionals(stmt.finalbody) + elif type(stmt) == ast.With: + stmt.body = collapseConditionals(stmt.body) + else: + pass + i -= 1 + return l + else: + log("transformations\tcollapseConditionals\tStrange type: " + str(type(a)), "bug") |