diff options
Diffstat (limited to 'canonicalize/transformations.py')
-rw-r--r-- | canonicalize/transformations.py | 2775 |
1 files changed, 2775 insertions, 0 deletions
diff --git a/canonicalize/transformations.py b/canonicalize/transformations.py new file mode 100644 index 0000000..7ae1e40 --- /dev/null +++ b/canonicalize/transformations.py @@ -0,0 +1,2775 @@ +import ast, copy, functools + +from .tools import log +from .namesets import * +from .astTools import * + +### VARIABLE ANONYMIZATION ### + +def updateVariableNames(a, varMap, scopeName, randomCounter, imports): + if not isinstance(a, ast.AST): + return + + if type(a) in [ast.FunctionDef, ast.ClassDef]: + if a.name in varMap: + if not hasattr(a, "originalId"): + a.originalId = a.name + a.name = varMap[a.name] + anonymizeStatementNames(a, varMap, "_" + a.name, imports) + elif type(a) == ast.arg: + if a.arg not in varMap and not (builtInName(a.arg) or importedName(a.arg, imports)): + log("Can't assign to arg?", "bug") + if a.arg in varMap: + if not hasattr(a, "originalId"): + a.originalId = a.arg + if varMap[a.arg][0] == "r": + a.randomVar = True # so we know it can crash + if a.arg == varMap[a.arg]: + # Check whether this is a given name + if not isAnonVariable(varMap[a.arg]): + a.dontChangeName = True + a.arg = varMap[a.arg] + elif type(a) == ast.Name: + if a.id not in varMap and not (builtInName(a.id) or importedName(a.id, imports)): + varMap[a.id] = "r" + str(randomCounter[0]) + scopeName + randomCounter[0] += 1 + if a.id in varMap: + if not hasattr(a, "originalId"): + a.originalId = a.id + if varMap[a.id][0] == "r": + a.randomVar = True # so we know it can crash + if a.id == varMap[a.id]: + # Check whether this is a given name + if not isAnonVariable(varMap[a.id]): + a.dontChangeName = True + a.id = varMap[a.id] + else: + for child in ast.iter_child_nodes(a): + updateVariableNames(child, varMap, scopeName, randomCounter, imports) + +def gatherLocalScope(a, globalMap, scopeName, imports, goBackwards=False): + localMap = { } + varLetter = "g" if type(a) == ast.Module else "v" + paramCounter = 0 + localCounter = 0 + if type(a) == ast.FunctionDef: + for param in a.args.args: + if type(param) == ast.arg: + if not (builtInName(param.arg) or importedName(param.arg, imports)) and param.arg not in localMap and param.arg not in globalMap: + localMap[param.arg] = "p" + str(paramCounter) + scopeName + paramCounter += 1 + else: + log("transformations\tanonymizeFunctionNames\tWeird parameter type: " + str(type(param)), "bug") + if goBackwards: + items = a.body[::-1] # go through this backwards to get the used function names + else: + items = a.body[:] + # First, go through and create the globalMap + while len(items) > 0: + item = items[0] + if type(item) in [ast.FunctionDef, ast.ClassDef]: + if not (builtInName(item.name) or importedName(item.name, imports)): + if item.name not in localMap and item.name not in globalMap: + localMap[item.name] = "helper_" + varLetter + str(localCounter) + scopeName + localCounter += 1 + else: + item.dontChangeName = True + + # If there are any variables in this node, find and label them + if type(item) in [ ast.Assign, ast.AugAssign, ast.For, ast.With, + ast.Lambda, ast.comprehension, ast.ExceptHandler]: + assns = [] + if type(item) == ast.Assign: + assns = item.targets + elif type(item) == ast.AugAssign: + assns = [item.target] + elif type(item) == ast.For: + assns = [item.target] + elif type(item) == ast.With: + if item.optional_vars != None: + assns = [item.optional_vars] + elif type(item) == ast.Lambda: + assns = item.args.args + elif type(item) == ast.comprehension: + assns = [item.target] + elif type(item) == ast.ExceptHandler: + if item.name != None: + assns = [item.name] + + for assn in assns: + if type(assn) == ast.Name: + if not (builtInName(assn.id) or importedName(assn.id, imports)): + if assn.id not in localMap and assn.id not in globalMap: + localMap[assn.id] = varLetter + str(localCounter) + scopeName + localCounter += 1 + + if type(item) in [ast.For, ast.While, ast.If]: + items += item.body + item.orelse + elif type(item) in [ast.With, ast.ExceptHandler]: + items += item.body + elif type(item) == ast.Try: + items += item.body + item.handlers + item.orelse + item.finalbody + items = items[1:] + return localMap + +def anonymizeStatementNames(a, globalMap, scopeName, imports, goBackwards=False): + """Gather the local variables, then update variable names in each line""" + localMap = gatherLocalScope(a, globalMap, scopeName, imports, goBackwards=goBackwards) + varMap = { } + varMap.update(globalMap) + varMap.update(localMap) + randomCounter = [0] + functionsSeen = [] + if type(a) == ast.FunctionDef: + for arg in a.args.args: + updateVariableNames(arg, varMap, scopeName, randomCounter, imports) + for line in a.body: + updateVariableNames(line, varMap, scopeName, randomCounter, imports) + +def anonymizeNames(a, namesToKeep, imports): + """Anonymize all of variables/names that occur in the given AST""" + """If we run this on an anonymized AST, it will fix the names again to get rid of any gaps!""" + if type(a) != ast.Module: + return a + globalMap = { } + for var in namesToKeep: + globalMap[var] = var + anonymizeStatementNames(a, globalMap, "", imports, goBackwards=True) + return a + +def propogateNameMetadata(a, namesToKeep, imports): + """Propogates name metadata through a state. We assume that the names are all properly formatted""" + if type(a) == list: + for child in a: + child = propogateNameMetadata(child, namesToKeep, imports) + return a + elif not isinstance(a, ast.AST): + return a + if type(a) == ast.Name: + if (builtInName(a.id) or importedName(a.id, imports)): + pass + elif a.id in namesToKeep: + a.dontChangeName = True + else: + if not hasattr(a, "originalId"): + a.originalId = a.id + if not isAnonVariable(a.id): + a.dontChangeName = True # it's a name we shouldn't mess with + elif type(a) == ast.arg: + if (builtInName(a.arg) or importedName(a.arg, imports)): + pass + elif a.arg in namesToKeep: + a.dontChangeName = True + else: + if not hasattr(a, "originalId"): + a.originalId = a.arg + if not isAnonVariable(a.arg): + a.dontChangeName = True # it's a name we shouldn't mess with + for child in ast.iter_child_nodes(a): + child = propogateNameMetadata(child, namesToKeep, imports) + return a + +### HELPER FOLDING ### + +def findHelperFunction(a, helperId, helperCount): + """Finds the first helper function used in the ast""" + if not isinstance(a, ast.AST): + return None + + # Check all the children, so that we don't end up with a recursive problem + for child in ast.iter_child_nodes(a): + f = findHelperFunction(child, helperId, helperCount) + if f != None: + return f + # Then check if this is the right call + if type(a) == ast.Call: + if type(a.func) == ast.Name and a.func.id == helperId: + if helperCount[0] > 0: + helperCount[0] -= 1 + else: + return a + return None + +def individualizeVariables(a, variablePairs, idNum, imports): + """Replace variable names with new individualized ones (for inlining methods)""" + if not isinstance(a, ast.AST): + return + + if type(a) == ast.Name: + if a.id not in variablePairs and not (builtInName(a.id) or importedName(a.id, imports)): + name = "_var_" + a.id + "_" + str(idNum[0]) + variablePairs[a.id] = name + if a.id in variablePairs: + a.id = variablePairs[a.id] + # Override built-in names when they're assigned to. + elif type(a) == ast.Assign and type(a.targets[0]) == ast.Name: + if a.targets[0].id not in variablePairs: + name = "_var_" + a.targets[0].id + "_" + str(idNum[0]) + variablePairs[a.targets[0].id] = name + elif type(a) == ast.arguments: + for arg in a.args: + if type(arg) == ast.arg: + name = "_arg_" + arg.arg + "_" + str(idNum[0]) + variablePairs[arg.arg] = name + arg.arg = variablePairs[arg.arg] + return + elif type(a) == ast.Call: + if type(a.func) == ast.Name: + variablePairs[a.func.id] = a.func.id # save the function name! + + for child in ast.iter_child_nodes(a): + individualizeVariables(child, variablePairs, idNum, imports) + +def replaceAst(a, old, new, shouldReplace): + """Replace the old value with the new one, but only once! That's what the shouldReplace variable is for.""" + if not isinstance(a, ast.AST) or not shouldReplace[0]: + return a + elif compareASTs(a, old, checkEquality=True) == 0: + shouldReplace[0] = False + return new + return applyToChildren(a, lambda x: replaceAst(x, old, new, shouldReplace)) + +def mapHelper(a, helper, idNum, imports): + """Map the helper function into this function, if it's used. idNum gives us the current id""" + if type(a) in [ast.FunctionDef, ast.ClassDef]: + a.body = mapHelper(a.body, helper, idNum, imports) + elif type(a) in [ast.For, ast.While, ast.If]: + a.body = mapHelper(a.body, helper, idNum, imports) + a.orelse = mapHelper(a.orelse, helper, idNum, imports) + if type(a) != list: + return a # we only deal with lists + + i = 0 + body = a + while i < len(body): + if type(body[i]) in [ast.FunctionDef, ast.ClassDef, ast.For, ast.While, ast.If]: + body[i] = mapHelper(body[i], helper, idNum, imports) # deal with blocks first + # While we still need to replace a function being called + helperCount = 0 + while countVariables(body[i], helper.name) > helperCount: + callExpression = findHelperFunction(body[i], helper.name, [helperCount]) + if callExpression == None: + break + idNum[0] += 1 + variablePairs = {} + + # First, update the method's variables + methodArgs = deepcopy(helper.args) + individualizeVariables(methodArgs, variablePairs, idNum, imports) + methodLines = deepcopyList(helper.body) + for j in range(len(methodLines)): + individualizeVariables(methodLines[j], variablePairs, idNum, imports) + + # Then, move each of the parameters into an assignment + argLines = [] + badArgs = False + for j in range(len(callExpression.args)): + arg = methodArgs.args[j] + value = callExpression.args[j] + # If it only occurs once, and in the return statement.. + if countVariables(methodLines, arg.arg) == 1 and countVariables(methodLines[-1], arg.arg) == 1: + var = ast.Name(arg.arg, ast.Load()) + transferMetaData(arg, var) + methodLines[-1] = replaceAst(methodLines[-1], var, value, [True]) + elif couldCrash(value): + badArgs = True + break + else: + if type(arg) == ast.arg: + var = ast.Name(arg.arg, ast.Store(), helperVar=True) + transferMetaData(arg, var) + varAssign = ast.Assign([var], value, global_id=arg.global_id, helperParamAssign=True) + argLines.append(varAssign) + else: # what other type would it be? + log("transformations\tmapHelper\tWeird arg: " + str(type(arg)), "bug") + if badArgs: + helperCount += 1 + continue + + # Finally, carry over the final value into the correct position + if countOccurances(helper, ast.Return) == 1: + returnLine = [replaceAst(body[i], callExpression, methodLines[-1].value, [True])] + methodLines.pop(-1) + else: + comp = compareASTs(body[i].value, callExpression, checkEquality=True) == 0 + if type(body[i]) == ast.Return and comp: + # In the case of a return, we can return None, it's cool + callVal = ast.NameConstant(None, helperReturnVal=True) + transferMetaData(callExpression, callVal) + returnVal = ast.Return(callVal, global_id=body[i].global_id, helperReturnAssign=True) + returnLine = [returnVal] + elif type(body[i]) == ast.Assign and comp: + # Assigns can be assigned to None + callVal = ast.NameConstant(None, helperReturnVal=True) + transferMetaData(callExpression, callVal) + assignVal = ast.Assign(body[i].targets, callVal, global_id=body[i].global_id, helperReturnAssign=True) + returnLine = [assignVal] + elif type(body[i]) == ast.Expr and comp: + # Delete the line if it's by itself + returnLine = [] + else: + # Otherwise, what is going on?!? + returnLine = [] + log("transformations\tmapHelper\tWeird removeCall: " + str(type(body[i])), "bug") + #log("transformations\tmapHelper\tMapped helper function: " + helper.name, "bug") + body[i:i+1] = argLines + methodLines + returnLine + i += 1 + return a + +def mapVariable(a, varId, assn): + """Map the variable assignment into the function, if it's needed""" + if type(a) != ast.FunctionDef: + return a + + for arg in a.args.args: + if arg.arg == varId: + return a # overriden by local variable + for i in range(len(a.body)): + line = a.body[i] + if type(line) == ast.Assign: + for target in line.targets: + if type(target) == ast.Name and target.id == varId: + break + elif type(target) in [ast.Tuple, ast.List]: + for elt in target.elts: + if type(elt) == ast.Name and elt.id == varId: + break + if countVariables(line, varId) > 0: + a.body[i:i+1] = [deepcopy(assn), line] + break + return a + +def helperFolding(a, mainFun, imports): + """When possible, fold the functions used in this module into their callers""" + if type(a) != ast.Module: + return a + + globalCounter = [0] + body = a.body + i = 0 + while i < len(body): + item = body[i] + if type(item) == ast.FunctionDef: + if item.name != mainFun: + # We only want non-recursive functions that have a single return which occurs at the end + # Also, we don't want any functions with parameters that are changed during the function + returnOccurs = countOccurances(item, ast.Return) + if countVariables(item, item.name) == 0 and returnOccurs <= 1 and type(item.body[-1]) == ast.Return: + allArgs = [] + for arg in item.args.args: + allArgs.append(arg.arg) + for tmpA in ast.walk(item): + if type(tmpA) in [ast.Assign, ast.AugAssign]: + allGood = True + assignedIds = gatherAssignedVarIds(tmpA.targets if type(tmpA) == ast.Assign else [tmpA.target]) + for arg in allArgs: + if arg in assignedIds: + allGood = False + break + if not allGood: + break + else: + for j in range(len(item.body)-1): + if couldCrash(item.body[j]): + break + else: + # If we satisfy these requirements, translate the body of the function into all functions that call it + gone = True + used = False + for j in range(len(body)): + if i != j and type(body[j]) == ast.FunctionDef: + if countVariables(body[j], item.name) > 0: + used = True + mapHelper(body[j], item, globalCounter, imports) + if countVariables(body[j], item.name) > 0: + gone = False + if used and gone: + body.pop(i) + continue + elif type(item) == ast.Assign: + # Is it ever changed in the global area? + if len(item.targets) == 1 and type(item.targets[0]) == ast.Name: + if eventualType(item.value) in [int, float, bool, str]: # if it isn't mutable + for j in range(i+1, len(body)): + line = body[j] + if type(line) == ast.FunctionDef: + if countOccurances(line, ast.Global) > 0: # TODO: improve this + break + if item.targets[0].id in getAllAssignedVarIds(line): + break + else: # if in scope + if countVariables(line, item.targets[0].id) > 0: + break + else: + # Variable never appears again- we can map it! + for j in range(i+1, len(body)): + line = body[j] + if type(line) == ast.FunctionDef: + mapVariable(line, item.targets[0].id, item) + body.pop(i) + continue + i += 1 + return a + +### AST PREPARATION ### + +def listNotEmpty(a): + """Determines that the iterable is NOT empty, if we can know that""" + """Used for For objects""" + if not isinstance(a, ast.AST): + return False + if type(a) == ast.Call: + if type(a.func) == ast.Name and a.func.id in ["range"]: + if len(a.args) == 1: # range(x) + return type(a.args[0]) == ast.Num and type(a.args[0].n) != complex and a.args[0].n > 0 + elif len(a.args) == 2: # range(start, x) + if type(a.args[0]) == ast.Num and type(a.args[1]) == ast.Num and \ + type(a.args[0].n) != complex and type(a.args[1].n) != complex and \ + a.args[0].n < a.args[1].n: + return True + elif type(a.args[1]) == ast.BinOp and type(a.args[1].op) == ast.Add: + if type(a.args[1].right) == ast.Num and type(a.args[1].right) != complex and a.args[1].right.n > 0 and \ + compareASTs(a.args[0], a.args[1].left, checkEquality=True) == 0: + return True + elif type(a.args[1].left) == ast.Num and type(a.args[1].left) != complex and a.args[1].left.n > 0 and \ + compareASTs(a.args[0], a.args[1].right, checkEquality=True) == 0: + return True + elif type(a) in [ast.List, ast.Tuple]: + return len(a.elts) > 0 + elif type(a) == ast.Str: + return len(a.s) > 0 + return False + +def simplifyUpdateId(var, variableMap, idNum): + """Update the varID of a new variable""" + if type(var) not in [ast.Name, ast.arg]: + return var + idVar = var.id if type(var) == ast.Name else var.arg + if not hasattr(var, "varID"): + if idVar in variableMap: + var.varID = variableMap[idVar][1] + else: + var.varID = idNum[0] + idNum[0] += 1 + +def simplify_multicomp(a): + if type(a) == ast.Compare and len(a.ops) > 1: + # Only do one comparator at a time. If we don't do this, things get messy! + comps = [a.left] + a.comparators + values = [ ] + # Compare each of the pairs + for i in range(len(a.ops)): + if i > 0: + # Label all nodes as middle parts so we can recognize them later + assignPropertyToAll(comps[i], "multiCompMiddle") + values.append(ast.Compare(comps[i], [a.ops[i]], [deepcopy(comps[i+1])], multiCompPart=True)) + # Combine comparisons with and operators + boolOp = ast.And(multiCompOp=True) + boolopVal = ast.BoolOp(boolOp, values, multiComp=True, global_id=a.global_id) + return boolopVal + return a + +def simplify(a): + """This function simplifies the usual Python AST to make it usable by our functions.""" + if not isinstance(a, ast.AST): + return a + elif type(a) == ast.Assign: + if len(a.targets) > 1: + # Go through all targets and assign them on separate lines + lines = [ast.Assign([a.targets[-1]], a.value, global_id=a.global_id)] + for i in range(len(a.targets)-1, 0, -1): + t = a.targets[i] + if type(t) == ast.Name: + loadedTarget = ast.Name(t.id, ast.Load()) + elif type(t) == ast.Subscript: + loadedTarget = ast.Subscript(deepcopy(t.value), deepcopy(t.slice), ast.Load()) + elif type(t) == ast.Attribute: + loadedTarget = ast.Attribute(deepcopy(t.value), t.attr, ast.Load()) + elif type(t) == ast.Tuple: + loadedTarget = ast.Tuple(deepcopy(t.elts), ast.Load()) + elif type(t) == ast.List: + loadedTarget = ast.List(deepcopy(t.elts), ast.Load()) + else: + log("transformations\tsimplify\tOdd loadedTarget: " + str(type(t)), "bug") + transferMetaData(t, loadedTarget) + loadedTarget.global_id = t.global_id + + lines.append(ast.Assign([a.targets[i-1]], loadedTarget, global_id=a.global_id)) + else: + lines = [a] + + i = 0 + while i < len(lines): + # For each line, figure out type and varID + lines[i].value = simplify(lines[i].value) + t = lines[i].targets[0] + if type(t) in [ast.Tuple, ast.List]: + val = lines[i].value + # If the items are being assigned separately, with no dependance on each other, + # separate out the elements of the tuple + if type(val) in [ast.Tuple, ast.List] and len(t.elts) == len(val.elts) and \ + allVariableNamesUsed(val) == []: + listLines = [] + for j in range(len(t.elts)): + assignVal = ast.Assign([t.elts[j]], val.elts[j], global_id = lines[i].global_id) + listLines += simplify(assignVal) + lines[i:i+1] = listLines + i += len(listLines) - 1 + i += 1 + return lines + elif type(a) == ast.AugAssign: + # Turn all AugAssigns into Assigns + a.target = simplify(a.target) + if eventualType(a.target) not in [bool, int, str, float]: + # Can't get rid of AugAssign, in case the += is different + a.value = simplify(a.value) + return a + if type(a.target) == ast.Name: + loadedTarget = ast.Name(a.target.id, ast.Load()) + elif type(a.target) == ast.Subscript: + loadedTarget = ast.Subscript(deepcopy(a.target.value), deepcopy(a.target.slice), ast.Load()) + elif type(a.target) == ast.Attribute: + loadedTarget = ast.Attribute(deepcopy(a.target.value), a.target.attr, ast.Load()) + elif type(a.target) == ast.Tuple: + loadedTarget = ast.Tuple(deepcopy(a.target.elts), ast.Load()) + elif type(a.target) == ast.List: + loadedTarget = ast.List(deepcopy(a.target.elts), ast.Load()) + else: + log("transformations\tsimplify\tOdd AugAssign target: " + str(type(a.target)), "bug") + transferMetaData(a.target, loadedTarget) + loadedTarget.global_id = a.target.global_id + a.target.augAssignVal = True # for later recognition + loadedTarget.augAssignVal = True + assignVal = ast.Assign([a.target], ast.BinOp(loadedTarget, a.op, a.value, augAssignBinOp=True), global_id=a.global_id) + return simplify(assignVal) + elif type(a) == ast.Compare and len(a.ops) > 1: + return simplify(simplify_multicomp(a)) + return applyToChildren(a, lambda x : simplify(x)) + +def propogateMetadata(a, argTypes, variableMap, idNum): + """This function propogates metadata about type throughout the AST. + argTypes lets us institute types for each function's args + variableMap maps variable ids to their types and id numbers + idNum gives us a global number to use for variable ids""" + if not isinstance(a, ast.AST): + return a + elif type(a) == ast.FunctionDef: + if a.name in argTypes: + theseTypes = argTypes[a.name] + else: + theseTypes = [] + variableMap = copy.deepcopy(variableMap) # variables shouldn't affect variables outside + idNum = copy.deepcopy(idNum) + for i in range(len(a.args.args)): + arg = a.args.args[i] + if type(arg) == ast.arg: + simplifyUpdateId(arg, variableMap, idNum) + # Match the args if possible + if len(a.args.args) == len(theseTypes) and \ + theseTypes[i] in ["int", "str", "float", "bool", "list"]: + arg.type = eval(theseTypes[i]) + variableMap[arg.arg] = (arg.type, arg.varID) + else: + arg.type = None + variableMap[arg.arg] = (None, arg.varID) + else: + log("transformations\tpropogateMetadata\tWeird type in args: " + str(type(arg)), "bug") + + newBody = [] + for line in a.body: + newLine = propogateMetadata(line, argTypes, variableMap, idNum) + if type(newLine) == list: + newBody += newLine + else: + newBody.append(newLine) + a.body = newBody + return a + elif type(a) == ast.Assign: + val = a.value = propogateMetadata(a.value, argTypes, variableMap, idNum) + if len(a.targets) == 1: + t = a.targets[0] + if type(t) == ast.Name: + simplifyUpdateId(t, variableMap, idNum) + varType = eventualType(val) + t.type = varType + variableMap[t.id] = (varType, t.varID) # update type + elif type(t) in [ast.Tuple, ast.List]: + # If the items are being assigned separately, with no dependance on each other, + # assign the appropriate types + getTypes = type(val) in [ast.Tuple, ast.List] and len(t.elts) == len(val.elts) and len(allVariableNamesUsed(val)) == 0 + for j in range(len(t.elts)): + t2 = t.elts[j] + givenType = eventualType(val.elts[j]) if getTypes else None + if type(t2) == ast.Name: + simplifyUpdateId(t2, variableMap, idNum) + t2.type = givenType + variableMap[t2.id] = (givenType, t2.varID) + elif type(t.elts[j]) in [ast.Subscript, ast.Attribute, ast.Tuple, ast.List]: + pass + else: + log("transformations\tpropogateMetadata\tOdd listTarget: " + str(type(t.elts[j])), "bug") + return a + elif type(a) == ast.AugAssign: + # Turn all AugAssigns into Assigns + t = a.target = propogateMetadata(a.target, argTypes, variableMap, idNum) + a.value = propogateMetadata(a.value, argTypes, variableMap, idNum) + if type(t) == ast.Name: + if eventualType(a.target) not in [bool, int, str, float]: + finalType = None + else: + if type(a.target) == ast.Name: + loadedTarget = ast.Name(a.target.id, ast.Load()) + elif type(a.target) == ast.Subscript: + loadedTarget = ast.Subscript(deepcopy(a.target.value), deepcopy(a.target.slice), ast.Load()) + elif type(a.target) == ast.Attribute: + loadedTarget = ast.Attribute(deepcopy(a.target.value), a.target.attr, ast.Load()) + elif type(a.target) == ast.Tuple: + loadedTarget = ast.Tuple(deepcopy(a.target.elts), ast.Load()) + elif type(a.target) == ast.List: + loadedTarget = ast.List(deepcopy(a.target.elts), ast.Load()) + else: + log("transformations\tsimplify\tOdd AugAssign target: " + str(type(a.target)), "bug") + transferMetaData(a.target, loadedTarget) + actualValue = ast.BinOp(loadedTarget, a.op, a.value) + finalType = eventualType(actualValue) + + simplifyUpdateId(t, variableMap, idNum) + varType = finalType + t.type = varType + variableMap[t.id] = (varType, t.varID) # update type + return a + elif type(a) == ast.For: # START HERE + a.iter = propogateMetadata(a.iter, argTypes, variableMap, idNum) + if type(a.target) == ast.Name: + simplifyUpdateId(a.target, variableMap, idNum) + ev = eventualType(a.iter) + if ev == str: + a.target.type = str + # we know ranges are made of ints + elif type(a.iter) == ast.Call and type(a.iter.func) == ast.Name and (a.iter.func.id) == "range": + a.target.type = int + else: + a.target.type = None + variableMap[a.target.id] = (a.target.type, a.target.varID) + a.target = propogateMetadata(a.target, argTypes, variableMap, idNum) + + # Go through the body and orelse to map out more variables + body = [] + bodyMap = copy.deepcopy(variableMap) + for i in range(len(a.body)): + result = propogateMetadata(a.body[i], argTypes, bodyMap, idNum) + if type(result) == list: + body += result + else: + body.append(result) + a.body = body + + orelse = [] + orelseMap = copy.deepcopy(variableMap) + for var in bodyMap: + if var not in orelseMap: + orelseMap[var] = (42, bodyMap[var][1]) + for i in range(len(a.orelse)): + result = propogateMetadata(a.orelse[i], argTypes, orelseMap, idNum) + if type(result) == list: + orelse += result + else: + orelse.append(result) + a.orelse = orelse + + keys = list(variableMap.keys()) + for key in keys: # reset types of changed keys + if key not in bodyMap or bodyMap[key] != variableMap[key] or \ + key not in orelseMap or orelseMap[key] != variableMap[key]: + variableMap[key] = (None, variableMap[key][1]) + if countOccurances(a.body, ast.Break) == 0: # We will definitely enter the else! + for key in orelseMap: + # If we KNOW it will be this type + if key not in bodyMap or bodyMap[key] == orelseMap[key]: + variableMap[key] = orelseMap[key] + + # If we KNOW it will run at least once + if listNotEmpty(a.iter): + for key in bodyMap: + if key in variableMap: + continue + # type might be changed + elif key in orelseMap and orelseMap[key] != bodyMap[key]: + continue + variableMap[key] = bodyMap[key] + return a + elif type(a) == ast.While: + body = [] + bodyMap = copy.deepcopy(variableMap) + for i in range(len(a.body)): + result = propogateMetadata(a.body[i], argTypes, bodyMap, idNum) + if type(result) == list: + body += result + else: + body.append(result) + a.body = body + + orelse = [] + orelseMap = copy.deepcopy(variableMap) + for var in bodyMap: + if var not in orelseMap: + orelseMap[var] = (None, bodyMap[var][1]) + for i in range(len(a.orelse)): + result = propogateMetadata(a.orelse[i], argTypes, orelseMap, idNum) + if type(result) == list: + orelse += result + else: + orelse.append(result) + a.orelse = orelse + + keys = list(variableMap.keys()) + for key in keys: + if key not in bodyMap or bodyMap[key] != variableMap[key] or \ + key not in orelseMap or orelseMap[key] != variableMap[key]: + variableMap[key] = (None, variableMap[key][1]) + + a.test = propogateMetadata(a.test, argTypes, variableMap, idNum) + return a + elif type(a) == ast.If: + a.test = propogateMetadata(a.test, argTypes, variableMap, idNum) + variableMap1 = copy.deepcopy(variableMap) + variableMap2 = copy.deepcopy(variableMap) + + body = [] + for i in range(len(a.body)): + result = propogateMetadata(a.body[i], argTypes, variableMap1, idNum) + if type(result) == list: + body += result + else: + body.append(result) + a.body = body + + for var in variableMap1: + if var not in variableMap2: + variableMap2[var] = (42, variableMap1[var][1]) + + orelse = [] + for i in range(len(a.orelse)): + result = propogateMetadata(a.orelse[i], argTypes, variableMap2, idNum) + if type(result) == list: + orelse += result + else: + orelse.append(result) + a.orelse = orelse + + variableMap.clear() + for key in variableMap1: + if key in variableMap2: + if variableMap1[key] == variableMap2[key]: + variableMap[key] = variableMap1[key] + elif variableMap1[key][1] != variableMap2[key][1]: + log("transformations\tsimplify\tvarId mismatch", "bug") + else: # type mismatch + if variableMap2[key][0] == 42: + variableMap[key] = variableMap1[key] + else: + variableMap[key] = (None, variableMap1[key][1]) + elif len(a.orelse) > 0 and type(a.orelse[-1]) == ast.Return: # if the else exits, it doesn't matter + variableMap[key] = variableMap1[key] + for key in variableMap2: + if key not in variableMap1 and len(a.body) > 0 and type(a.body[-1]) == ast.Return: # if the if exits, it doesn't matter + variableMap[key] = variableMap2[key] + return a + elif type(a) == ast.Name: + if a.id in variableMap and variableMap[a.id][0] != 42: + a.type = variableMap[a.id][0] + a.varID = variableMap[a.id][1] + else: + if a.id in variableMap: + a.varID = variableMap[a.id][1] + a.ctx = propogateMetadata(a.ctx, argTypes, variableMap, idNum) + return a + elif type(a) == ast.AugLoad: + return ast.Load() + elif type(a) == ast.AugStore: + return ast.Store() + return applyToChildren(a, lambda x : propogateMetadata(x, argTypes, variableMap, idNum)) + +### SIMPLIFYING FUNCTIONS ### + +def applyTransferLambda(x): + """Simplify an expression by applying constant folding, re-formatting to an AST, and then tranferring the metadata appropriately.""" + if x == None: + return x + tmp = astFormat(constantFolding(x)) + if hasattr(tmp, "global_id") and hasattr(x, "global_id") and tmp.global_id != x.global_id: + return tmp # don't do the transfer, this already has its own metadata + else: + transferMetaData(x, tmp) + return tmp + +def constantFolding(a): + """In constant folding, we evaluate all constant expressions instead of doing operations at runtime""" + if not isinstance(a, ast.AST): + return a + t = type(a) + if t in [ast.FunctionDef, ast.ClassDef]: + for i in range(len(a.body)): + a.body[i] = applyTransferLambda(a.body[i]) + return a + elif t in [ast.Import, ast.ImportFrom, ast.Global]: + return a + elif t == ast.BoolOp: + # Condense the boolean's values + newValues = [] + ranks = [] + count = 0 + for val in a.values: + # Condense the boolean operations into one line, if possible + c = constantFolding(val) + if type(c) == ast.BoolOp and type(c.op) == type(a.op) and not hasattr(c, "multiComp"): + newValues += c.values + ranks.append(range(count,count+len(c.values))) + count += len(c.values) + else: + newValues.append(c) + ranks.append(count) + count += 1 + + # Or breaks with True, And breaks with False + breaks = (type(a.op) == ast.Or) + + # Remove the opposite values IF removing them won't mess up the type. + i = len(newValues) - 1 + while i > 0: + if (newValues[i] == (not breaks)) and eventualType(newValues[i-1]) == bool: + newValues.pop(i) + i -= 1 + + if len(newValues) == 0: + # There's nothing to evaluate + return (not breaks) + elif len(newValues) == 1: + # If we're down to one value, just return it! + return newValues[0] + elif newValues[0] == breaks: + # If the first value breaks it, done! + return breaks + elif newValues.count(breaks) >= 1: + # We don't need any values that occur after a break + i = newValues.index(breaks) + newValues = newValues[:i+1] + for i in range(len(newValues)): + newValues[i] = astFormat(newValues[i]) + # get the corresponding value + if i in ranks: + transferMetaData(a.values[ranks.index(i)], newValues[i]) + else: # it's in a list + for j in range(len(ranks)): + if type(ranks[j]) == list and i in ranks[j]: + transferMetaData(a.values[j].values[ranks[j].index(i)], newValues[i]) + break + a.values = newValues + return a + elif t == ast.BinOp: + l = constantFolding(a.left) + r = constantFolding(a.right) + # Hack to make hint chaining work- don't constant-fold filler strings! + if containsTokenStepString(l) or containsTokenStepString(r): + a.left = applyTransferLambda(a.left) + a.right = applyTransferLambda(a.right) + return a + if type(l) in builtInTypes and type(r) in builtInTypes: + try: + val = doBinaryOp(a.op, l, r) + if type(val) == float and val % 0.0001 != 0: # don't deal with trailing floats + pass + else: + tmp = astFormat(val) + transferMetaData(a, tmp) + return tmp + except: + # We have some kind of divide-by-zero issue. + # Therefore, don't calculate it! + pass + if type(l) in builtInTypes: + if type(r) == bool: + r = int(r) + # Commutative operations + elif type(r) == ast.BinOp and type(r.op) == type(a.op) and type(a.op) in [ast.Add, ast.Mult, ast.BitOr, ast.BitAnd, ast.BitXor]: + rLeft = constantFolding(r.left) + if type(rLeft) in builtInTypes: + try: + newLeft = astFormat(doBinaryOp(a.op, l, rLeft)) + transferMetaData(r.left, newLeft) + return ast.BinOp(newLeft, a.op, r.right) + except Exception as e: + pass + + # Empty string is often unneccessary + if type(l) == str and l == '': + if type(a.op) == ast.Add and eventualType(r) == str: + return r + elif type(a.op) == ast.Mult and eventualType(r) == int: + return '' + elif type(l) == bool: + l = int(l) + # 0 is often unneccessary + if l == 0 and eventualType(r) in [int, float]: + if type(a.op) in [ast.Add, ast.BitOr]: + # If it won't change the type + if type(l) == int or eventualType(r) == float: + return r + elif type(l) == float: # Cast it + return ast.Call(ast.Name("float", ast.Load(), typeCastFunction=True), [r], []) + elif type(a.op) == ast.Sub: + tmpR = astFormat(r) + transferMetaData(a.right, tmpR) + newR = ast.UnaryOp(ast.USub(addedOtherOp=True), tmpR, addedOther=True) + if type(l) == int or eventualType(r) == float: + return newR + elif type(l) == float: + return ast.Call(ast.Name("float", ast.Load(), typeCastFunction=True), [newR], []) + elif type(a.op) in [ast.Mult, ast.LShift, ast.RShift]: + # If either is a float, it's 0 + return 0.0 if float in [eventualType(r), type(l)] else 0 + elif type(a.op) in [ast.Div, ast.FloorDiv, ast.Mod]: + # Check if the right might be zero + if type(r) in builtInTypes and r != 0: + return 0.0 if float in [eventualType(r), type(l)] else 0 + # Same for 1 + elif l == 1: + if type(a.op) == ast.Mult and eventualType(r) in [int, float]: + if type(l) == int or eventualType(r) == float: + return r + elif type(l) == float: + return ast.Call(ast.Name("float", ast.Load(), typeCastFunction=True), [r], []) + # No reason to make this a float if the other value has already been cast + elif type(l) == float and l == int(l): + if type(a.op) in [ast.Add, ast.Sub, ast.Mult, ast.Div] and eventualType(r) == float: + l = int(l) + # Some of the same operations are done with the right, but not all of them + if type(r) in builtInTypes: + if type(r) == str and r == '': + if type(a.op) == ast.Add and eventualType(l) == str: + return l + elif type(a.op) == ast.Mult and eventualType(l) == int: + return '' + elif type(r) == bool: + r = int(r) + else: + if r == 0 and eventualType(l) in [int, float]: + if type(a.op) in [ast.Add, ast.Sub, ast.LShift, ast.RShift, ast.BitOr]: + if type(r) == int or eventualType(l) == float: + return l + elif type(r) == float: + return ast.Call(ast.Name("float", ast.Load(), typeCastFunction=True), [l], []) + elif type(a.op) == ast.Mult: + return 0.0 if float in [eventualType(l), type(r)] else 0 + elif r == 1: + if type(a.op) in [ast.Mult, ast.Div, ast.Pow] and eventualType(l) in [int, float]: + if type(r) == int or eventualType(l) == float: + return l + elif type(r) == float: + return ast.Call(ast.Name("float", ast.Load(), typeCastFunction=True), [l], []) + elif type(a.op) == ast.FloorDiv and eventualType(l) == int: + if eventualType( r ) == int: + return l + elif eventualType( r ) == float: + return ast.Call(ast.Name("float", ast.Load(), typeCastFunction=True), [l], []) + elif type(r) == float and r == int(r): + if type(a.op) in [ast.Add, ast.Sub, ast.Mult, ast.Div] and eventualType(l) == float: + r = int(r) + a.left = applyTransferLambda(a.left) + a.right = applyTransferLambda(a.right) + return a + elif t == ast.IfExp: + # Sometimes, we can simplify the statement + test = constantFolding(a.test) + b = constantFolding(a.body) + o = constantFolding(a.orelse) + + aTest = astFormat(test) + transferMetaData(a.test, aTest) + aB = astFormat(b) + transferMetaData(a.body, aB) + aO = astFormat(o) + transferMetaData(a.orelse, aO) + + if type(test) == bool: + return aB if test else aO # evaluate the if expression now + elif compareASTs(b, o, checkEquality=True) == 0: + return aB # if they're the same, no reason for the expression + a.test = aTest + a.body = aB + a.orelse = aO + return a + elif t == ast.Compare: + if len(a.ops) == 0 or len(a.comparators) == 0: + return True # No ops? Okay, empty case is true! + op = a.ops[0] + l = constantFolding(a.left) + r = constantFolding(a.comparators[0]) + # Hack to make hint chaining work- don't constant-fold filler strings! + if containsTokenStepString(l) or containsTokenStepString(r): + tmpLeft = astFormat(l) + transferMetaData(a.left, tmpLeft) + a.left = tmpLeft + tmpRight = astFormat(r) + transferMetaData(a.comparators[0], tmpRight) + a.comparators = [tmpRight] + return a + # Check whether the two sides are the same + comp = compareASTs(l, r, checkEquality=True) == 0 + if comp and (not couldCrash(l)) and type(op) in [ast.Lt, ast.Gt, ast.NotEq]: + tmp = ast.NameConstant(False) + transferMetaData(a, tmp) + return tmp + elif comp and (not couldCrash(l)) and type(op) in [ast.Eq, ast.LtE, ast.GtE]: + tmp = ast.NameConstant(True) + transferMetaData(a, tmp) + return tmp + if (type(l) in builtInTypes) and (type(r) in builtInTypes): + try: + result = astFormat(doCompare(op, l, r)) + transferMetaData(a, result) + return result + except: + pass + # Reduce the expressions when possible! + if type(l) == type(r) == ast.BinOp and type(l.op) == type(r.op) and not couldCrash(l) and not couldCrash(r): + if type(l.op) == ast.Add: + # Remove repeated values + unchanged = False + if compareASTs(l.left, r.left, checkEquality=True) == 0: + l = l.right + r = r.right + elif compareASTs(l.right, r.right, checkEquality=True) == 0: + l = l.left + r = r.left + elif compareASTs(l.left, r.right, checkEquality=True) == 0 and eventualType(l) in [int, float]: + l = l.right + r = r.left + elif compareASTs(l.right, r.left, checkEquality=True) == 0 and eventualType(l) in [int, float]: + l = l.left + r = r.right + else: + unchanged = True + if not unchanged: + tmpLeft = astFormat(l) + transferMetaData(a.left, tmpLeft) + a.left = tmpLeft + tmpRight = astFormat(r) + transferMetaData(a.comparators[0], tmpRight) + a.comparators = [tmpRight] + return constantFolding(a) # Repeat this check to see if we can keep reducing it + elif type(l.op) == ast.Sub: + unchanged = False + if compareASTs(l.left, r.left, checkEquality=True) == 0: + l = l.right + r = r.right + elif compareASTs(l.right, r.right, checkEquality=True) == 0: + l = l.left + r = r.left + else: + unchanged = True + if not unchanged: + tmpLeft = astFormat(l) + transferMetaData(a.left, tmpLeft) + a.left = tmpLeft + tmpRight = astFormat(r) + transferMetaData(a.comparators[0], tmpRight) + a.comparators = [tmpRight] + return constantFolding(a) + tmpLeft = astFormat(l) + transferMetaData(a.left, tmpLeft) + a.left = tmpLeft + tmpRight = astFormat(r) + transferMetaData(a.comparators[0], tmpRight) + a.comparators = [tmpRight] + return a + elif t == ast.Call: + # TODO: this can be done much better + a.func = applyTransferLambda(a.func) + + allConstant = True + tmpArgs = [] + for i in range(len(a.args)): + tmpArgs.append(constantFolding(a.args[i])) + if type(tmpArgs[i]) not in [int, float, bool, str]: + allConstant = False + if len(a.keywords) > 0: + allConstant = False + if allConstant and (type(a.func) == ast.Name) and (a.func.id in builtInFunctions.keys()) and \ + (a.func.id not in ["range", "raw_input", "input", "open", "randint", "random", "slice"]): + try: + result = apply(eval(a.func.id), tmpArgs) + transferMetaData(a, astFormat(result)) + return result + except: + # Not gonna happen unless it crashes + #log("transformations\tconstantFolding\tFunction crashed: " + str(a.func.id), "bug") + pass + for i in range(len(a.args)): + tmpArg = astFormat(tmpArgs[i]) + transferMetaData(a.args[i], tmpArg) + a.args[i] = tmpArg + return a + # This needs to be separate because the attribute is a string + elif t == ast.Attribute: + a.value = applyTransferLambda(a.value) + return a + elif t == ast.Slice: + if a.lower != None: + a.lower = applyTransferLambda(a.lower) + if a.upper != None: + a.upper = applyTransferLambda(a.upper) + if a.step != None: + a.step = applyTransferLambda(a.step) + return a + elif t == ast.Num: + return a.n + elif t == ast.Bytes: + return a.s + elif t == ast.Str: + # Don't do things to filler strings + if len(a.s) > 0 and isTokenStepString(a.s): + return a + return a.s + elif t == ast.NameConstant: + if a.value == True: + return True + elif a.value == False: + return False + elif a.value == None: + return None + elif t == ast.Name: + return a + else: # All statements, ast.Lambda, ast.Dict, ast.Set, ast.Repr, ast.Attribute, ast.Subscript, etc. + return applyToChildren(a, applyTransferLambda) + +def isMutatingFunction(a): + """Given a function call, this checks whether it might change the program state when run""" + if type(a) != ast.Call: # Can only call this on Calls! + log("transformations\tisMutatingFunction\tNot a Call: " + str(type(a)), "bug") + return True + + # Map of all static namesets + funMaps = { "math" : mathFunctions, "string" : builtInStringFunctions, + "str" : builtInStringFunctions, "list" : staticListFunctions, + "dict" : staticDictFunctions } + typeMaps = { str : "string", list : "list", dict : "dict" } + if type(a.func) == ast.Name: + funDict = builtInFunctions + funName = a.func.id + elif type(a.func) == ast.Attribute: + if type(a.func.value) == ast.Name and a.func.value.id in funMaps: + funDict = funMaps[a.func.value.id] + funName = a.func.attr + # if the item is calling a function directly + elif eventualType(a.func.value) in typeMaps: + funDict = funMaps[typeMaps[eventualType(a.func.value)]] + funName = a.func.attr + else: + return True + else: + return True # we don't know, so yes + + # TODO: deal with student's functions + return funName not in funDict + +def allVariablesUsed(a): + if not isinstance(a, ast.AST): + return [] + elif type(a) == ast.Name: + return [a] + variables = [] + for child in ast.iter_child_nodes(a): + variables += allVariablesUsed(child) + return variables + +def allVariableNamesUsed(a): + """Gathers all the variable names used in the ast""" + if not isinstance(a, ast.AST): + return [] + elif type(a) == ast.Name: + return [a.id] + elif type(a) == ast.Assign: + """In assignments, ignore all pure names used- they're being assigned to, not used""" + variables = allVariableNamesUsed(a.value) + for target in a.targets: + if type(target) == ast.Name: + pass + elif type(target) in [ast.Tuple, ast.List]: + for elt in target.elts: + if type(elt) != ast.Name: + variables += allVariableNamesUsed(elt) + else: + variables += allVariableNamesUsed(target) + return variables + elif type(a) == ast.AugAssign: + variables = allVariableNamesUsed(a.value) + variables += allVariableNamesUsed(a.target) + return variables + variables = [] + for child in ast.iter_child_nodes(a): + variables += allVariableNamesUsed(child) + return variables + +def addPropTag(a, globalId): + if not isinstance(a, ast.AST): + return a + a.propagatedVariable = True + if hasattr(a, "global_id"): + a.variableGlobalId = globalId + return applyToChildren(a, lambda x : addPropTag(x, globalId)) + +def propagateValues(a, liveVars): + """Propagate the given values through the AST whenever their variables occur""" + if ((not isinstance(a, ast.AST) or len(liveVars.keys()) == 0)): + return a + + if type(a) == ast.Name: + # Propagate the value if we have it! + if a.id in liveVars: + val = copy.deepcopy(liveVars[a.id]) + val.loadedVariable = True + if hasattr(a, "global_id"): + val.variableGlobalId = a.global_id + return applyToChildren(val, lambda x : addPropTag(x, a.global_id)) + else: + return a + elif type(a) == ast.Call: + # If something is mutated, it cannot be propagated anymore + if isMutatingFunction(a): + allVars = allVariablesUsed(a) + for var in allVars: + if (eventualType(var) not in [int, float, bool, str]): + if (var.id in liveVars): + del liveVars[var.id] + currentLiveVars = list(liveVars.keys()) + for liveVar in currentLiveVars: + varsWithin = allVariableNamesUsed(liveVars[liveVar]) + if var.id in varsWithin: + del liveVars[liveVar] + return a + elif type(a.func) == ast.Name and a.func.id in liveVars and \ + eventualType(liveVars[a.func.id]) in [int, float, complex, bytes, bool, type(None)]: + # Special case: don't move a simple value to the front of a Call + # because it will cause a compiler error instead of a runtime error + a.args = propagateValues(a.args, liveVars) + a.keywords = propagateValues(a.keywords, liveVars) + return a + elif type(a) == ast.Attribute: + if type(a.value) == ast.Name and a.value.id in liveVars and \ + eventualType(liveVars[a.value.id]) in [int, float, complex, bytes, bool, type(None)]: + # Don't move for the same reason as above + return a + return applyToChildren(a, lambda x: propagateValues(x, liveVars)) + +def hasMutatingFunction(a): + """Checks to see if the ast has any potentially mutating functions""" + if not isinstance(a, ast.AST): + return False + for node in ast.walk(a): + if type(a) == ast.Call: + if isMutatingFunction(a): + return True + return False + +def clearBlockVars(a, liveVars): + """Clear all the vars set in this block out of the live vars""" + if (not isinstance(a, ast.AST)) or len(liveVars.keys()) == 0: + return + + if type(a) in [ast.Assign, ast.AugAssign]: + if type(a) == ast.Assign: + targets = gatherAssignedVars(a.targets) + else: + targets = gatherAssignedVars([a.target]) + for target in targets: + varId = None + if type(target) == ast.Name: + varId = target.id + elif type(target.value) == ast.Name: + varId = target.value.id + if varId in liveVars: + del liveVars[varId] + + liveKeys = list(liveVars.keys()) + for var in liveKeys: + # Remove the variable and any variables in which it is used + if varId in allVariableNamesUsed(liveVars[var]): + del liveVars[var] + return + elif type(a) == ast.Call: + if hasMutatingFunction(a): + for v in allVariablesUsed(a): + if eventualType(v) not in [int, float, bool, str]: + if v.id in liveVars: + del liveVars[v.id] + liveKeys = list(liveVars.keys()) + for var in liveKeys: + if v.id in allVariableNamesUsed(liveVars[var]): + del liveVars[var] + return + elif type(a) == ast.For: + names = [] + if type(a.target) == ast.Name: + names = [a.target.id] + elif type(a.target) in [ast.Tuple, ast.List]: + for elt in a.target.elts: + if type(elt) == ast.Name: + names.append(elt.id) + elif type(elt) == ast.Subscript: + if type(elt.value) == ast.Name: + names.append(elt.value.id) + else: + log("transformations\tclearBlockVars\tFor target subscript not a name: " + str(type(elt.value)), "bug") + else: + log("transformations\tclearBlockVars\tFor target not a name: " + str(type(elt)), "bug") + elif type(a.target) == ast.Subscript: + if type(a.target.value) == ast.Name: + names.append(a.target.value.id) + else: + log("transformations\tclearBlockVars\tFor target subscript not a name: " + str(type(a.target.value)), "bug") + else: + log("transformations\tclearBlockVars\tFor target not a name: " + str(type(a.target)), "bug") + for name in names: + if name in liveVars: + del liveVars[name] + + liveKeys = list(liveVars.keys()) + for var in liveKeys: + # Remove the variable and any variables in which it is used + if name in allVariableNamesUsed(liveVars[var]): + del liveVars[var] + + for child in ast.iter_child_nodes(a): + clearBlockVars(child, liveVars) + +def copyPropagation(a, liveVars=None, inLoop=False): + """Propagate variables into the tree, when possible""" + if liveVars == None: + liveVars = { } + if type(a) == ast.Module: + a.body = copyPropagation(a.body) + return a + if type(a) == ast.FunctionDef: + a.body = copyPropagation(a.body, liveVars=liveVars) + return a + + if type(a) == list: + i = 0 + while i < len(a): + deleteLine = False + if type(a[i]) == ast.FunctionDef: + a[i].body = copyPropagation(a[i].body, liveVars=copy.deepcopy(liveVars)) + elif type(a[i]) == ast.ClassDef: + # TODO: can we propagate values through everything after here? + for j in range(len(a[i].body)): + if type(a[i].body[j]) == ast.FunctionDef: + a[i].body[j] = copyPropagation(a[i].body[j]) + elif type(a[i]) == ast.Assign: + # In assignments, propagate values into the right side and move the left side into the live vars + a[i].value = propagateValues(a[i].value, liveVars) + target = a[i].targets[0] + + if type(target) in [ast.Name, ast.Subscript, ast.Attribute]: + varId = None + # In plain names, we can update the liveVars + if type(target) == ast.Name: + varId = target.id + if inLoop or couldCrash(a[i].value) or eventualType(a[i].value) not in [bool, int, float, str, tuple]: + # Remove this variable from the live vars + if varId in liveVars: + del liveVars[varId] + else: + liveVars[varId] = a[i].value + # For other values, we can at least clear out liveVars correctly + # TODO: can we expand this? + elif target.value == ast.Name: + varId = target.value.id + + # Now, update the live vars based on anything reset by the new target + liveKeys = list(liveVars.keys()) + for var in liveKeys: + # If the var we're replacing was used elsewhere, that value will no longer be the same + if varId in allVariableNamesUsed(liveVars[var]): + del liveVars[var] + elif type(target) in [ast.Tuple, ast.List]: + # Copy the values, if we can match them + if type(a[i].value) in [ast.Tuple, ast.List] and len(target.elts) == len(a[i].value.elts): + for j in range(len(target.elts)): + if type(target.elts[j]) == ast.Name: + if (not couldCrash(a[i].value.elts[j])): + liveVars[target.elts[j]] = a[i].value.elts[j] + else: + if target.elts[j] in liveVars: + del liveVars[target.elts[j]] + + # Then get rid of any overwrites + for e in target.elts: + if type(e) in [ast.Name, ast.Subscript, ast.Attribute]: + varId = None + if type(e) == ast.Name: + varId = e.id + elif type(e.value) == ast.Name: + varId = e.value.id + + liveKeys = list(liveVars.keys()) + for var in liveKeys: + if varId in allVariableNamesUsed(liveVars[var]): + del liveVars[var] + else: + log("transformations\tcopyPropagation\tWeird assign type: " + str(type(e)), "bug") + elif type(a[i]) == ast.AugAssign: + a[i].value = propagateValues(a[i].value, liveVars) + assns = gatherAssignedVarIds([a[i].target]) + for target in assns: + if target in liveVars: + del liveVars[target] + elif type(a[i]) == ast.For: + # FIRST, propagate values into the iter + if type(a[i].iter) != ast.Name: # if it IS a name, don't replace it! + # Otherwise, we propagate first since this is evaluated once + a[i].iter = propagateValues(a[i].iter, liveVars) + + # We reset the target variable, so reset the live vars + names = [] + if type(a[i].target) == ast.Name: + names = [a[i].target.id] + elif type(a[i].target) in [ast.Tuple, ast.List]: + for elt in a[i].target.elts: + if type(elt) == ast.Name: + names.append(elt.id) + elif type(elt) == ast.Subscript: + if type(elt.value) == ast.Name: + names.append(elt.value.id) + else: + log("transformations\tcopyPropagation\tFor target subscript not a name: " + str(type(elt.value)) + "\t" + printFunction(elt.value), "bug") + else: + log("transformations\tcopyPropagation\tFor target not a name: " + str(type(elt)) + "\t" + printFunction(elt), "bug") + elif type(a[i].target) == ast.Subscript: + if type(a[i].target.value) == ast.Name: + names.append(a[i].target.value.id) + else: + log("transformations\tcopyPropagation\tFor target subscript not a name: " + str(type(a[i].target.value)) + "\t" + printFunction(a[i].target.value), "bug") + else: + log("transformations\tcopyPropagation\tFor target not a name: " + str(type(a[i].target)) + "\t" + printFunction(a[i].target), "bug") + + for name in names: + liveKeys = list(liveVars.keys()) + for var in liveKeys: + if name in allVariableNamesUsed(liveVars[var]): + del liveVars[var] + if name in liveVars: + del liveVars[name] + clearBlockVars(a[i], liveVars) + a[i].body = copyPropagation(a[i].body, copy.deepcopy(liveVars), inLoop=True) + a[i].orelse = copyPropagation(a[i].orelse, copy.deepcopy(liveVars), inLoop=True) + elif type(a[i]) == ast.While: + clearBlockVars(a[i], liveVars) + a[i].test = propagateValues(a[i].test, liveVars) + a[i].body = copyPropagation(a[i].body, copy.deepcopy(liveVars), inLoop=True) + a[i].orelse = copyPropagation(a[i].orelse, copy.deepcopy(liveVars), inLoop=True) + elif type(a[i]) == ast.If: + a[i].test = propagateValues(a[i].test, liveVars) + liveVars1 = copy.deepcopy(liveVars) + liveVars2 = copy.deepcopy(liveVars) + a[i].body = copyPropagation(a[i].body, liveVars1) + a[i].orelse = copyPropagation(a[i].orelse, liveVars2) + liveVars.clear() + # We can keep any values that occur in both + for key in liveVars1: + if key in liveVars2: + if compareASTs(liveVars1[key], liveVars2[key], checkEquality=True) == 0: + liveVars[key] = liveVars1[key] + # TODO: think more deeply about how this should work + elif type(a[i]) == ast.Try: + a[i].body = copyPropagation(a[i].body, liveVars) + for handler in a[i].handlers: + handler.body = copyPropagation(handler.body, liveVars) + a[i].orelse = copyPropagation(a[i].orelse, liveVars) + a[i].finalbody = copyPropagation(a[i].finalbody, liveVars) + elif type(a[i]) == ast.With: + a[i].body = copyPropagation(a[i].body, liveVars) + # With regular statements, just propagate the values + elif type(a[i]) in [ast.Return, ast.Delete, ast.Raise, ast.Assert, ast.Expr]: + propagateValues(a[i], liveVars) + # Breaks and Continues mess everything up + elif type(a[i]) in [ast.Break, ast.Continue]: + break + # These are not affected by this function + elif type(a[i]) in [ast.Import, ast.ImportFrom, ast.Global, ast.Pass]: + pass + else: + log("transformations\tcopyPropagation\tNot implemented: " + str(type(a[i])), "bug") + i += 1 + return a + else: + log("transformations\tcopyPropagation\tNot a list: " + str(type(a)), "bug") + return a + +def deadCodeRemoval(a, liveVars=None, keepPrints=True, inLoop=False): + """Remove any code which will not be reached or used.""" + """LiveVars keeps track of the variables that will be necessary""" + if liveVars == None: + liveVars = set() + if type(a) == ast.Module: + # Remove functions that will be overwritten anyway + namesSeen = [] + i = len(a.body) - 1 + while i >= 0: + if type(a.body[i]) == ast.FunctionDef: + if a.body[i].name in namesSeen: + # SPECIAL CHECK! Actually, the function will cause the code to crash if some of the args have the same name. Don't delete it then. + argNames = [] + for arg in a.body[i].args.args: + if arg.arg in argNames: + break + else: + argNames.append(arg.arg) + else: # only remove this if the args won't break it + a.body.pop(i) + else: + namesSeen.append(a.body[i].name) + elif type(a.body[i]) == ast.Assign: + namesSeen += gatherAssignedVars(a.body[i].targets) + i -= 1 + liveVars |= set(namesSeen) # make sure all global names are used! + + if type(a) in [ast.Module, ast.FunctionDef]: + if type(a) == ast.Module and len(a.body) == 0: + return a # just don't mess with it + gid = a.body[0].global_id if len(a.body) > 0 and hasattr(a.body[0], "global_id") else None + a.body = deadCodeRemoval(a.body, liveVars=liveVars, keepPrints=keepPrints, inLoop=inLoop) + if len(a.body) == 0: + a.body = [ast.Pass(removedLines=True)] if gid == None else [ast.Pass(removedLines=True, global_id=gid)] + return a + + if type(a) == list: + i = len(a) - 1 + while i >= 0 and len(a) > 0: + if i >= len(a): + i = len(a) - 1 # just in case + stmt = a[i] + t = type(stmt) + # TODO: get rid of these if they aren't live + if t in [ast.FunctionDef, ast.ClassDef]: + newLiveVars = set() + gid = a[i].body[0].global_id if len(a[i].body) > 0 and hasattr(a[i].body[0], "global_id") else None + a[i] = deadCodeRemoval(a[i], liveVars=newLiveVars, keepPrints=keepPrints, inLoop=inLoop) + liveVars |= newLiveVars + # Empty functions are useless! + if len(a[i].body) == 0: + a[i].body = [ast.Pass(removedLines=True)] if gid == None else [ast.Pass(removedLines=True, global_id=gid)] + elif t == ast.Return: + # Get rid of everything that happens after this! + a = a[:i+1] + # Replace the variables + liveVars.clear() + liveVars |= set(allVariableNamesUsed(stmt)) + elif t in [ast.Delete, ast.Assert]: + # Just add all variables used + liveVars |= set(allVariableNamesUsed(stmt)) + elif t == ast.Assign: + # Check to see if the names being assigned are in the set of live variables + allDead = True + allTargets = gatherAssignedVars(stmt.targets) + allNamesUsed = allVariableNamesUsed(stmt.value) + for target in allTargets: + if type(target) == ast.Name and (target.id in liveVars or target.id in allNamesUsed): + if target.id in liveVars: + liveVars.remove(target.id) + allDead = False + elif type(target) in [ast.Subscript, ast.Attribute]: + liveVars |= set(allVariableNamesUsed(target)) + allDead = False + # Also, check if the variable itself is contained in the value, because that can crash too + # If none are used, we can delete this line. Otherwise, use the value's vars + if allDead and (not couldCrash(stmt)) and (not containsTokenStepString(stmt)): + a.pop(i) + else: + liveVars |= set(allVariableNamesUsed(stmt.value)) + elif t == ast.AugAssign: + liveVars |= set(allVariableNamesUsed(stmt.target)) + liveVars |= set(allVariableNamesUsed(stmt.value)) + elif t == ast.For: + # If there is no use of break, there's no reason to use else with the loop, + # so move the lines outside and go over them separately + if len(stmt.orelse) > 0 and countOccurances(stmt, ast.Break) == 0: + lines = stmt.orelse + stmt.orelse = [] + a[i:i+1] = [stmt] + lines + i += len(lines) + continue # don't subtract one + + targetNames = [] + if type(a[i].target) == ast.Name: + targetNames = [a[i].target.id] + elif type(a[i].target) in [ast.Tuple, ast.List]: + for elt in a[i].target.elts: + if type(elt) == ast.Name: + targetNames.append(elt.id) + elif type(elt) == ast.Subscript: + if type(elt.value) == ast.Name: + targetNames.append(elt.value.id) + else: + log("transformations\tdeadCodeRemoval\tFor target subscript not a name: " + str(type(elt.value)) + "\t" + printFunction(elt.value), "bug") + else: + log("transformations\tdeadCodeRemoval\tFor target not a name: " + str(type(elt)) + "\t" + printFunction(elt), "bug") + elif type(a[i].target) == ast.Subscript: + if type(a[i].target.value) == ast.Name: + targetNames.append(a[i].target.value.id) + else: + log("transformations\tdeadCodeRemoval\tFor target subscript not a name: " + str(type(a[i].target.value)) + "\t" + printFunction(a[i].target.value), "bug") + else: + log("transformations\tdeadCodeRemoval\tFor target not a name: " + str(type(a[i].target)) + "\t" + printFunction(a[i].target), "bug") + + # We need to make ALL variables in the loop live, since they update continuously + liveVars |= set(allVariableNamesUsed(stmt)) + gid = stmt.body[0].global_id if len(stmt.body) > 0 and hasattr(stmt.body[0], "global_id") else None + stmt.body = deadCodeRemoval(stmt.body, copy.deepcopy(liveVars), keepPrints=keepPrints, inLoop=True) + stmt.orelse = deadCodeRemoval(stmt.orelse, copy.deepcopy(liveVars), keepPrints=keepPrints, inLoop=inLoop) + # If the body is empty and we don't need the target, get rid of it! + if len(stmt.body) == 0: + for name in targetNames: + if name in liveVars: + stmt.body = [ast.Pass(removedLines=True)] if gid == None else [ast.Pass(removedLines=True, global_id=gid)] + break + else: + if couldCrash(stmt.iter) or containsTokenStepString(stmt.iter): + a[i] = ast.Expr(stmt.iter, collapsedExpr=True) + else: + a.pop(i) + if len(stmt.orelse) > 0: + a[i:i+1] = a[i] + stmt.orelse + + # The names are wiped UPDATE - NOPE, what if we never enter the loop? + #for name in targetNames: + # liveVars.remove(name) + elif t == ast.While: + # If there is no use of break, there's no reason to use else with the loop, + # so move the lines outside and go over them separately + if len(stmt.orelse) > 0 and countOccurances(stmt, ast.Break) == 0: + lines = stmt.orelse + stmt.orelse = [] + a[i:i+1] = [stmt] + lines + i += len(lines) + continue + + # We need to make ALL variables in the loop live, since they update continuously + liveVars |= set(allVariableNamesUsed(stmt)) + old_global_id = stmt.body[0].global_id + stmt.body = deadCodeRemoval(stmt.body, copy.deepcopy(liveVars), keepPrints=keepPrints, inLoop=True) + stmt.orelse = deadCodeRemoval(stmt.orelse, copy.deepcopy(liveVars), keepPrints=keepPrints, inLoop=inLoop) + # If the body is empty, get rid of it! + if len(stmt.body) == 0: + stmt.body = [ast.Pass(removedLines=True, global_id=old_global_id)] + elif t == ast.If: + # First, if True/False, just replace it with the lines + test = a[i].test + if type(test) == ast.NameConstant and test.value in [True, False]: + assignedVars = getAllAssignedVars(a[i]) + for var in assignedVars: + # UNLESS we have a weird variable assignment problem + if var.id[0] == "g" and hasattr(var, "originalId"): + log("canonicalize\tdeadCodeRemoval\tWeird global variable: " + printFunction(a[i]), "bug") + break + else: + if test.value == True: + a[i:i+1] = a[i].body + else: + a[i:i+1] = a[i].orelse + continue + # For if statements, see if you can shorten things + liveVars1 = copy.deepcopy(liveVars) + liveVars2 = copy.deepcopy(liveVars) + stmt.body = deadCodeRemoval(stmt.body, liveVars1, keepPrints=keepPrints, inLoop=inLoop) + stmt.orelse = deadCodeRemoval(stmt.orelse, liveVars2, keepPrints=keepPrints, inLoop=inLoop) + liveVars.clear() + allVars = liveVars1 | liveVars2 | set(allVariableNamesUsed(stmt.test)) + liveVars |= allVars + if len(stmt.body) == 0 and len(stmt.orelse) == 0: + # Get rid of the if and keep going + if couldCrash(stmt.test) or containsTokenStepString(stmt.test): + newStmt = ast.Expr(stmt.test, collapsedExpr=True) + transferMetaData(stmt, newStmt) + a[i] = newStmt + else: + a.pop(i) + i -= 1 + continue + if len(stmt.body) == 0: + # If the body is empty, switch it with the else + stmt.test = deMorganize(ast.UnaryOp(ast.Not(addedNotOp=True), stmt.test, addedNot=True)) + (stmt.body, stmt.orelse) = (stmt.orelse, stmt.body) + if len(stmt.orelse) == 0: + # See if we can make the rest of the function the else statement + if type(stmt.body[-1]) == type(a[-1]) == ast.Return: + # If the if is larger than the rest, switch them! + if len(stmt.body) > len(a[i+1:]): + stmt.test = deMorganize(ast.UnaryOp(ast.Not(addedNotOp=True), stmt.test, addedNot=True)) + (a[i+1:], stmt.body) = (stmt.body, a[i+1:]) + else: + # Check to see if we should switch the if and else parts + if len(stmt.body) > len(stmt.orelse): + stmt.test = deMorganize(ast.UnaryOp(ast.Not(addedNotOp=True), stmt.test, addedNot=True)) + (stmt.body, stmt.orelse) = (stmt.orelse, stmt.body) + elif t == ast.Import: + if len(stmt.names) == 0: + a.pop(i) + elif t == ast.Global: + j = 0 + while j < len(a.names): + if a.names[j] not in liveVars: + a.names.pop(j) + else: + j += 1 + elif t == ast.Expr: + # Remove the line if it won't crash things. + if couldCrash(stmt) or containsTokenStepString(stmt): + liveVars |= set(allVariableNamesUsed(stmt)) + else: + # check whether any of these variables might crash the program + # I know, it's weird, but occasionally a student might use a var before defining it + allVars = allVariableNamesUsed(stmt) + for j in range(i): + if type(a[j]) == ast.Assign: + for id in gatherAssignedVarIds(a[j].targets): + if id in allVars: + allVars.remove(id) + if len(allVars) > 0: + liveVars |= set(allVariableNamesUsed(stmt)) + else: + a.pop(i) + # for now, just be careful with these types of statements + elif t in [ast.With, ast.Raise, ast.Try]: + liveVars |= set(allVariableNamesUsed(stmt)) + elif t == ast.Pass: + a.pop(i) # pass does *nothing* + elif t in [ast.Continue, ast.Break]: + if inLoop: # If we're in a loop, nothing that follows matters! Otherwise, leave it alone, this will just crash. + a = a[:i+1] + break + # We don't know what they're doing- leave it alone + elif t in [ast.ImportFrom]: + pass + # Have not yet implemented these + else: + log("transformations\tdeadCodeRemoval\tNot implemented: " + str(type(stmt)), "bug") + i -= 1 + return a + else: + log("transformations\tdeadCodeRemoval\tNot a list: " + str(a), "bug") + return a + +### ORDERING FUNCTIONS ### + +def getKeyDict(d, key): + if key not in d: + d[key] = { "self" : key } + return d[key] + +def traverseTrail(d, trail): + temp = d + for key in trail: + temp = temp[key] + return temp + +def areDisjoint(a, b): + """Are the sets of values that satisfy these two boolean constraints disjoint?""" + # The easiest way to be disjoint is to have comparisons that cover different areas + if type(a) == type(b) == ast.Compare: + aop = a.ops[0] + bop = b.ops[0] + aLeft = a.left + aRight = a.comparators[0] + bLeft = b.left + bRight = b.comparators[0] + alblComp = compareASTs(aLeft, bLeft, checkEquality=True) + albrComp = compareASTs(aLeft, bRight, checkEquality=True) + arblComp = compareASTs(aRight, bLeft, checkEquality=True) + arbrComp = compareASTs(aRight, bRight, checkEquality=True) + altype = type(aLeft) in [ast.Num, ast.Str] + artype = type(aRight) in [ast.Num, ast.Str] + bltype = type(bLeft) in [ast.Num, ast.Str] + brtype = type(bRight) in [ast.Num, ast.Str] + + if (type(aop) == ast.Eq and type(bop) == ast.NotEq) or \ + (type(bop) == ast.Eq and type(aop) == ast.NotEq): + # x == y, x != y + if (alblComp == 0 and arbrComp == 0) or (albrComp == 0 and arblComp == 0): + return True + elif type(aop) == type(bop) == ast.Eq: + if (alblComp == 0 and arbrComp == 0) or (albrComp == 0 and arblComp == 0): + return False + # x = num1, x = num2 + elif alblComp == 0 and artype and brtype: + return True + elif albrComp == 0 and artype and bltype: + return True + elif arblComp == 0 and altype and brtype: + return True + elif arbrComp == 0 and altype and bltype: + return True + elif (type(aop) == ast.Lt and type(bop) == ast.GtE) or \ + (type(aop) == ast.Gt and type(bop) == ast.LtE) or \ + (type(aop) == ast.LtE and type(bop) == ast.Gt) or \ + (type(aop) == ast.GtE and type(bop) == ast.Lt) or \ + (type(aop) == ast.Is and type(bop) == ast.IsNot) or \ + (type(aop) == ast.IsNot and type(bop) == ast.Is) or \ + (type(aop) == ast.In and type(bop) == ast.NotIn) or \ + (type(aop) == ast.NotIn and type(bop) == ast.In): + if alblComp == 0 and arbrComp == 0: + return True + elif (type(aop) == ast.Lt and type(bop) == ast.LtE) or \ + (type(aop) == ast.Gt and type(bop) == ast.GtE) or \ + (type(aop) == ast.LtE and type(bop) == ast.Lt) or \ + (type(aop) == ast.GtE and type(bop) == ast.Gt): + if albrComp == 0 and arblComp == 0: + return True + elif type(a) == type(b) == ast.BoolOp: + return False # for now- TODO: when is this not true? + elif type(a) == ast.UnaryOp and type(a.op) == ast.Not: + if compareASTs(a.operand, b, checkEquality=True) == 0: + return True + elif type(b) == ast.UnaryOp and type(b.op) == ast.Not: + if compareASTs(b.operand, a, checkEquality=True) == 0: + return True + return False + +def crashesOn(a): + """Determines where the expression might crash""" + # TODO: integrate typeCrashes + if not isinstance(a, ast.AST): + return [] + if type(a) == ast.BinOp: + l = eventualType(a.left) + r = eventualType(a.right) + if type(a.op) == ast.Add: + if not ((l == r == str) or (l in [int, float] and r in [int, float])): + return [a] + elif type(a.op) == ast.Mult: + if not ((l == str and r == int) or (l == int and r == str) or \ + (l in [int, float] and r in [int, float])): + return [a] + elif type(a.op) in [ast.Sub, ast.Pow, ast.LShift, ast.RShift, ast.BitOr, ast.BitXor, ast.BitAnd]: + if l not in [int, float] or r not in [int, float]: + return [a] + else: # ast.Div, ast.FloorDiv, ast.Mod + if (type(a.right) != ast.Num or a.right.n == 0) or \ + (l not in [int, float] or r not in [int, float]): + return [a] + elif type(a) == ast.UnaryOp: + if type(a.op) in [ast.UAdd, ast.USub]: + if eventualType(a.operand) not in [int, float]: + return [a] + elif type(a.op) == ast.Invert: + if eventualType(a.operand) != int: + return [a] + elif type(a) == ast.Compare: + if len(a.ops) != len(a.comparators): + return [a] + elif type(a.ops[0]) in [ast.In, ast.NotIn] and not isIterableType(eventualType(a.comparators[0])): + return [a] + elif type(a.ops[0]) in [ast.Lt, ast.LtE, ast.Gt, ast.GtE]: + # In Python3, you can't compare different types. BOOOOOO!! + firstType = eventualType(a.left) + if firstType == None: + return [a] + for comp in a.comparators: + if eventualType(comp) != firstType: + return [a] + elif type(a) == ast.Call: + env = [] # TODO: what if the environments aren't imported? + funMaps = { "math" : mathFunctions, "string" : builtInStringFunctions } + safeFunMaps = { "math" : safeMathFunctions, "string" : safeStringFunctions } + if type(a.func) == ast.Name: + funDict = builtInFunctions + safeFuns = builtInSafeFunctions + funName = a.func.id + elif type(a.func) == ast.Attribute: + if type(a.func.value) == ast.Name and a.func.value.id in funMaps: + funDict = funMaps[a.func.value.id] + safeFuns = safeFunMaps[a.func.value.id] + funName = a.func.attr + elif eventualType(a.func.value) == str: + funDict = funMaps["string"] + safeFuns = safeFunMaps["string"] + funName = a.func.attr + else: # including list and dict + return [a] + else: + return [a] + + runOnce = 0 # So we can break + while (runOnce == 0): + if funName in safeFuns: + argTypes = [] + for i in range(len(a.args)): + eventual = eventualType(a.args[i]) + if eventual == None: + return [a] + argTypes.append(eventual) + + if funName in ["max", "min"]: + break # Special functions + + for key in funDict[funName]: + if len(key) != len(argTypes): + continue + for i in range(len(key)): + if not (key[i] == argTypes[i] or issubclass(argTypes[i], key[i])): + break + else: + break + else: + return [a] + break # found one that works + else: + return [a] + elif type(a) == ast.Subscript: + if eventualType(a.value) not in [str, list, tuple]: + return [a] + elif type(a) == ast.Name: + # If it's an undefined variable, it might crash + if hasattr(a, "randomVar"): + return [a] + elif type(a) == ast.Slice: + if a.lower != None and eventualType(a.lower) != int: + return [a] + if a.upper != None and eventualType(a.upper) != int: + return [a] + if a.step != None and eventualType(a.step) != int: + return [a] + elif type(a) in [ast.Assert, ast.Import, ast.ImportFrom, ast.Attribute, ast.Index]: + return [a] + + allCrashes = [] + for child in ast.iter_child_nodes(a): + allCrashes += crashesOn(child) + return allCrashes + +def isNegation(a, b): + """Is a the negation of b?""" + return compareASTs(deMorganize(ast.UnaryOp(ast.Not(), deepcopy(a))), b, checkEquality=True) == 0 + +def reverse(op): + """Reverse the direction of the comparison for normalization purposes""" + rev = not op.reversed if hasattr(op, "reversed") else True + if type(op) == ast.Gt: + newOp = ast.Lt() + transferMetaData(op, newOp) + newOp.reversed = rev + return newOp + elif type(op) == ast.GtE: + newOp = ast.LtE() + transferMetaData(op, newOp) + newOp.reversed = rev + return newOp + else: + return op # Do not change! + +def orderCommutativeOperations(a): + """Order all expressions that are in commutative operations""" + """TODO: add commutative function lines?""" + if not isinstance(a, ast.AST): + return a + # If branches can be commutative as long as their tests are disjoint + if type(a) == ast.If: + a = applyToChildren(a, orderCommutativeOperations) + # If the else is (strictly) shorter than the body, switch them + if len(a.orelse) != 0 and len(a.body) > len(a.orelse): + newTest = ast.UnaryOp(ast.Not(addedNotOp=True), a.test) + transferMetaData(a.test, newTest) + newTest.negated = True + newTest = deMorganize(newTest) + a.test = newTest + (a.body,a.orelse) = (a.orelse,a.body) + + # Then collect all the branches. The leftover orelse is the final else + branches = [(a.test, a.body, a.global_id)] + orElse = a.orelse + while len(orElse) == 1 and type(orElse[0]) == ast.If: + branches.append((orElse[0].test, orElse[0].body, orElse[0].global_id)) + orElse = orElse[0].orelse + + # If we have branches to order... + if len(branches) != 1: + # Sort the branches based on their tests + # We have to sort carefully because of the possibility for crashing + isSorted = False + while not isSorted: + isSorted = True + for i in range(len(branches)-1): + # First, do we even want to swap these two? + # Branch tests MUST be disjoint to be swapped- otherwise, we break semantics + if areDisjoint(branches[i][0], branches[i+1][0]) and \ + compareASTs(branches[i][0], branches[i+1][0]) > 0: + if not (couldCrash(branches[i][0]) or couldCrash(branches[i+1][0])): + (branches[i],branches[i+1]) = (branches[i+1],branches[i]) + isSorted = False + # Two values can be swapped if they crash on the SAME thing + elif couldCrash(branches[i][0]) and couldCrash(branches[i+1][0]): + # Check to see if they crash on the same things + l1 = sorted(crashesOn(branches[i][0]), key=functools.cmp_to_key(compareASTs)) + l2 = sorted(crashesOn(branches[i+1][0]), key=functools.cmp_to_key(compareASTs)) + if compareASTs(l1, l2, checkEquality=True) == 0: + (branches[i],branches[i+1]) = (branches[i+1],branches[i]) + isSorted = False + # Do our last two branches nicely form an if/else already? + if len(orElse) == 0 and isNegation(branches[-1][0], branches[-2][0]): + starter = branches[-1][1] # skip the if + else: + starter = [ast.If(branches[-1][0], branches[-1][1], orElse, global_id=branches[-1][2])] + # Create the new conditional tree + for i in range(len(branches)-2, -1, -1): + starter = [ast.If(branches[i][0], branches[i][1], starter, global_id=branches[i][2])] + a = starter[0] + return a + elif type(a) == ast.BoolOp: + # If all the values are booleans and won't crash, we can sort them + canSort = True + for i in range(len(a.values)): + a.values[i] = orderCommutativeOperations(a.values[i]) + if couldCrash(a.values[i]) or eventualType(a.values[i]) != bool or containsTokenStepString(a.values[i]): + canSort = False + + if canSort: + a.values = sorted(a.values, key=functools.cmp_to_key(compareASTs)) + else: + # Even if there are some problems, we can partially sort. See above + isSorted = False + while not isSorted: + isSorted = True + for i in range(len(a.values)-1): + if compareASTs(a.values[i], a.values[i+1]) > 0 and \ + eventualType(a.values[i]) == bool and eventualType(a.values[i+1]) == bool: + if not (couldCrash(a.values[i]) or couldCrash(a.values[i+1])): + (a.values[i],a.values[i+1]) = (a.values[i+1],a.values[i]) + isSorted = False + # Two values can also be swapped if they crash on the SAME thing + elif couldCrash(a.values[i]) and couldCrash(a.values[i+1]): + # Check to see if they crash on the same things + l1 = sorted(crashesOn(a.values[i]), key=functools.cmp_to_key(compareASTs)) + l2 = sorted(crashesOn(a.values[i+1]), key=functools.cmp_to_key(compareASTs)) + if compareASTs(l1, l2, checkEquality=True) == 0: + (a.values[i],a.values[i+1]) = (a.values[i+1],a.values[i]) + isSorted = False + return a + elif type(a) == ast.BinOp: + top = type(a.op) + l = a.left = orderCommutativeOperations(a.left) + r = a.right = orderCommutativeOperations(a.right) + + # Don't reorder if we're currently walking through hint steps + if containsTokenStepString(l) or containsTokenStepString(r): + return a + + # TODO: what about possible crashes? + # Certain operands are commutative + if (top in [ast.Mult, ast.BitOr, ast.BitXor, ast.BitAnd]) or \ + ((top == ast.Add) and ((eventualType(l) in [int, float, bool]) or \ + (eventualType(r) in [int, float, bool]))): + # Break the chain of binary operations into a list of the + # operands over the same op, then sort the operands + operands = [[l, a.op], [r, None]] + changeMade = True + i = 0 + while i < len(operands): + [operand, op] = operands[i] + if type(operand) == ast.BinOp and type(operand.op) == top: + operands[i:i+1] = [[operand.left, operand.op], [operand.right, op]] + else: + i += 1 + operands = sorted(operands, key=functools.cmp_to_key(lambda x,y : compareASTs(x[0], y[0]))) + for i in range(len(operands)-1): # push all the ops forward + if operands[i][1] == None: + operands[i][1] = operands[i+1][1] + operands[i+1][1] = None + # Then put them back into a single expression, descending to the left + left = operands[0][0] + for i in range(1, len(operands)): + left = ast.BinOp(left, operands[i-1][1], operands[i][0], orderedBinOp=True) + transferMetaData(a, left) + return left + elif top == ast.Add: + # This might be concatenation, not addition + if type( r ) == ast.BinOp and type(r.op) == top: + # We want the operators to descend to the left + a.left = orderCommutativeOperations(ast.BinOp(l, r.op, r.left, global_id=r.global_id)) + a.right = r.right + return a + elif type(a) == ast.Dict: + for i in range(len(a.keys)): + a.keys[i] = orderCommutativeOperations(a.keys[i]) + a.values[i] = orderCommutativeOperations(a.values[i]) + + pairs = list(zip(a.keys, a.values)) + pairs.sort(key=functools.cmp_to_key(lambda x,y : compareASTs(x[0],y[0]))) # sort by keys + k, v = zip(*pairs) if len(pairs) > 0 else ([], []) + a.keys = list(k) + a.values = list(v) + return a + elif type(a) == ast.Compare: + l = a.left = orderCommutativeOperations(a.left) + r = orderCommutativeOperations(a.comparators[0]) + a.comparators = [r] + + # Don't reorder when we're doing hint steps + if containsTokenStepString(l) or containsTokenStepString(r): + return a + + if (type(a.ops[0]) in [ast.Eq, ast.NotEq]): + # Equals and not-equals are commutative + if compareASTs(l, r) > 0: + a.left, a.comparators[0] = a.comparators[0], a.left + elif (type(a.ops[0]) in [ast.Gt, ast.GtE]): + # We'll always use < and <=, just so everything's the same + a.ops = [reverse(a.ops[0])] + a.left, a.comparators[0] = a.comparators[0], a.left + elif (type(a.ops[0]) in [ast.In, ast.NotIn]): + if type(r) == ast.List: + # If it's a list of items, sort the list + # TODO: should we implement crashable sorting here? + for i in range(len(r.elts)): + if couldCrash(r.elts[i]): + break # don't sort if there'a a crash! + else: + r.elts = sorted(r.elts, key=functools.cmp_to_key(compareASTs)) + # Then remove duplicates + i = 0 + while i < len(r.elts) - 1: + if compareASTs(r.elts[i], r.elts[i+1], checkEquality=True) == 0: + r.elts.pop(i+1) + else: + i += 1 + return a + elif type(a) == ast.Call: + if type(a.func) == ast.Name: + # These functions are commutative and show up a lot + if a.func.id in ["min", "max"]: + crashable = False + for i in range(len(a.args)): + a.args[i] = orderCommutativeOperations(a.args[i]) + if couldCrash(a.args[i]) or containsTokenStepString(a.args[i]): + crashable = True + # TODO: crashable sorting here? + if not crashable: + a.args = sorted(a.args, key=functools.cmp_to_key(compareASTs)) + return a + return applyToChildren(a, orderCommutativeOperations) + +def deMorganize(a): + """Apply De Morgan's law throughout the code in order to canonicalize""" + if not isinstance(a, ast.AST): + return a + # We only care about statements beginning with not + if type(a) == ast.UnaryOp and type(a.op) == ast.Not: + oper = a.operand + top = type(oper) + + # not (blah and gah) == (not blah or not gah) + if top == ast.BoolOp: + oper.op = negate(oper.op) + for i in range(len(oper.values)): + oper.values[i] = deMorganize(negate(oper.values[i])) + oper.negated = not oper.negated if hasattr(oper, "negated") else True + transferMetaData(a, oper) + return oper + # not a < b == a >= b + elif top == ast.Compare: + oper.left = deMorganize(oper.left) + oper.ops = [negate(oper.ops[0])] + oper.comparators = [deMorganize(oper.comparators[0])] + oper.negated = not oper.negated if hasattr(oper, "negated") else True + transferMetaData(a, oper) + return oper + # not not blah == blah + elif top == ast.UnaryOp and type(oper.op) == ast.Not: + oper.operand = deMorganize(oper.operand) + if eventualType(oper.operand) != bool: + return a + oper.operand.negated = not oper.operand.negated if hasattr(oper.operand, "negated") else True + return oper.operand + elif top == ast.NameConstant: + if oper.value in [True, False]: + oper = negate(oper) + transferMetaData(a, oper) + return oper + elif oper.value == None: + tmp = ast.NameConstant(True) + transferMetaData(a, tmp) + tmp.negated = True + return tmp + else: + log("Unknown NameConstant: " + str(oper.value), "bug") + + return applyToChildren(a, deMorganize) + +##### CLEANUP FUNCTIONS ##### + +def cleanupEquals(a): + """Gets rid of silly blah == True statements that students make""" + if not isinstance(a, ast.AST): + return a + if type(a) == ast.Call: + a.func = cleanupEquals(a.func) + for i in range(len(a.args)): + # But test expressions don't carry through to function arguments + a.args[i] = cleanupEquals(a.args[i]) + return a + elif type(a) == ast.Compare and type(a.ops[0]) in [ast.Eq, ast.NotEq]: + l = a.left = cleanupEquals(a.left) + r = cleanupEquals(a.comparators[0]) + a.comparators = [r] + if type(l) == ast.NameConstant and l.value in [True, False]: + (l,r) = (r,l) + # If we have (boolean expression) == True + if type(r) == ast.NameConstant and r.value in [True, False] and (eventualType(l) == bool): + # Matching types + if (type(a.ops[0]) == ast.Eq and r.value == True) or \ + (type(a.ops[0]) == ast.NotEq and r.value == False): + transferMetaData(a, l) # make sure to keep the original location + return l + else: + tmp = ast.UnaryOp(ast.Not(addedNotOp=True), l) + transferMetaData(a, tmp) + return tmp + else: + return a + else: + return applyToChildren(a, cleanupEquals) + +def cleanupBoolOps(a): + """When possible, combine adjacent boolean expressions""" + """Note- we are assuming that all ops are the first op (as is done in the simplify function)""" + if not isinstance(a, ast.AST): + return a + if type(a) == ast.BoolOp: + allTypesWork = True + for i in range(len(a.values)): + a.values[i] = cleanupBoolOps(a.values[i]) + if eventualType(a.values[i]) != bool or hasattr(a.values[i], "multiComp"): + allTypesWork = False + + # We can't reduce if the types aren't all booleans + if not allTypesWork: + return a + + i = 0 + while i < len(a.values) - 1: + current = a.values[i] + next = a.values[i+1] + # (a and b and c and d) or (a and e and d) == a and ((b and c) or e) and d + if type(current) == type(next) == ast.BoolOp: + if type(current.op) == type(next.op): + minlength = min(len(current.values), len(next.values)) # shortest length + + # First, check for all identical values from the front + j = 0 + while j < minlength: + if compareASTs(current.values[j], next.values[j], checkEquality=True) != 0: + break + j += 1 + + # Same values in both, so get rid of the latter line + if j == len(current.values) == len(next.values): + a.values.pop(i+1) + continue + i += 1 + ### If reduced to one item, just return that item + return a.values[0] if (len(a.values) == 1) else a + return applyToChildren(a, cleanupBoolOps) + +def cleanupRanges(a): + """Remove any range shenanigans, because Python lets you include unneccessary values""" + if not isinstance(a, ast.AST): + return a + if type(a) == ast.Call: + if type(a.func) == ast.Name: + if a.func.id in ["range"]: + if len(a.args) == 3: + # The step defaults to 1! + if type(a.args[2]) == ast.Num and a.args[2].n == 1: + a.args = a.args[:-1] + if len(a.args) == 2: + # The start defaults to 0! + if type(a.args[0]) == ast.Num and a.args[0].n == 0: + a.args = a.args[1:] + return applyToChildren(a, cleanupRanges) + +def cleanupSlices(a): + """Remove any slice shenanigans, because Python lets you include unneccessary values""" + if not isinstance(a, ast.AST): + return a + if type(a) == ast.Subscript: + if type(a.slice) == ast.Slice: + # Lower defaults to 0 + if a.slice.lower != None and type(a.slice.lower) == ast.Num and a.slice.lower.n == 0: + a.slice.lower = None + # Upper defaults to len(value) + if a.slice.upper != None and type(a.slice.upper) == ast.Call and \ + type(a.slice.upper.func) == ast.Name and a.slice.upper.func.id == "len": + if compareASTs(a.value, a.slice.upper.args[0], checkEquality=True) == 0: + a.slice.upper = None + # Step defaults to 1 + if a.slice.step != None and type(a.slice.step) == ast.Num and a.slice.step.n == 1: + a.slice.step = None + return applyToChildren(a, cleanupSlices) + +def cleanupTypes(a): + """Remove any unneccessary type mappings""" + if not isinstance(a, ast.AST): + return a + # No need to cast something if it'll be changed anyway by a binary operation + if type(a) == ast.BinOp: + a.left = cleanupTypes(a.left) + a.right = cleanupTypes(a.right) + # Ints become floats naturally + if eventualType(a.left) == eventualType(a.right) == float: + if type(a.right) == ast.Call and type(a.right.func) == ast.Name and \ + a.right.func.id == "float" and len(a.right.args) == 1 and len(a.right.keywords) == 0 and \ + eventualType(a.right.args[0]) in [int, float]: + a.right = a.right.args[0] + elif type(a.left) == ast.Call and type(a.left.func) == ast.Name and \ + a.left.func.id == "float" and len(a.left.args) == 1 and len(a.left.keywords) == 0 and \ + eventualType(a.left.args[0]) in [int, float]: + a.left = a.left.args[0] + return a + elif type(a) == ast.Call and type(a.func) == ast.Name and len(a.args) == 1 and len(a.keywords) == 0: + a.func = cleanupTypes(a.func) + a.args = [cleanupTypes(a.args[0])] + # If the type already matches, no need to cast it + funName = a.func.id + argType = eventualType(a.args[0]) + if type(a.func) == ast.Name: + if (funName == "float" and argType == float) or \ + (funName == "int" and argType == int) or \ + (funName == "bool" and argType == bool) or \ + (funName == "str" and argType == str): + return a.args[0] + return applyToChildren(a, cleanupTypes) + +def turnPositive(a): + """Take a negative number and make it positive""" + if type(a) == ast.UnaryOp and type(a.op) == ast.USub: + return a.operand + elif type(a) == ast.Num and type(a.n) != complex and a.n < 0: + a.n = a.n * -1 + return a + else: + log("transformations\tturnPositive\tFailure: " + str(a), "bug") + return a + +def isNegative(a): + """Is the give number negative?""" + if type(a) == ast.UnaryOp and type(a.op) == ast.USub: + return True + elif type(a) == ast.Num and type(a.n) != complex and a.n < 0: + return True + else: + return False + +def cleanupNegations(a): + """Remove unneccessary negations""" + if not isinstance(a, ast.AST): + return a + elif type(a) == ast.BinOp: + a.left = cleanupNegations(a.left) + a.right = cleanupNegations(a.right) + + if type(a.op) == ast.Add: + # x + (-y) + if isNegative(a.right): + a.right = turnPositive(a.right) + a.op = ast.Sub(global_id=a.op.global_id, num_negated=True) + return a + # (-x) + y + elif isNegative(a.left): + if couldCrash(a.left) and couldCrash(a.right): + return a # can't switch if it'll change the message + else: + (a.left,a.right) = (a.right,turnPositive(a.left)) + a.op = ast.Sub(global_id=a.op.global_id, num_negated=True) + return a + elif type(a.op) == ast.Sub: + # x - (-y) + if isNegative(a.right): + a.right = turnPositive(a.right) + a.op = ast.Add(global_id=a.op.global_id, num_negated=True) + return a + elif type(a.right) == ast.BinOp: + # x - (y + z) = x + (-y - z) + if type(a.right.op) == ast.Add: + a.right.left = cleanupNegations(ast.UnaryOp(ast.USub(addedOtherOp=True), a.right.left, addedOther=True)) + a.right.op = ast.Sub(global_id=a.right.op.global_id, num_negated=True) + a.op = ast.Add(global_id=a.op.global_id, num_negated=True) + return a + # x - (y - z) = x + (-y + z) = x + (z - y) + elif type(a.right.op) == ast.Sub: + if couldCrash(a.right.left) and couldCrash(a.right.right): + a.right.left = cleanupNegations(ast.UnaryOp(ast.USub(addedOtherOp=True), a.right.left, addedOther=True)) + a.right.op = ast.Add(global_id=a.right.op.global_id, num_negated=True) + a.op = ast.Add(global_id=a.op.global_id, num_negated=True) + return a + else: + (a.right.left, a.right.right) = (a.right.right, a.right.left) + a.op = ast.Add(global_id=a.op.global_id, num_negated=True) + return a + # Move negations to the outer part of multiplications + elif type(a.op) == ast.Mult: + # -x * -y + if isNegative(a.left) and isNegative(a.right): + a.left = turnPositive(a.left) + a.right = turnPositive(a.right) + return a + # -x * y = -(x*y) + elif isNegative(a.left): + if eventualType(a.right) in [int, float]: + a.left = turnPositive(a.left) + return cleanupNegations(ast.UnaryOp(ast.USub(addedOtherOp=True), a, addedOther=True)) + # x * -y = -(x*y) + elif isNegative(a.right): + if eventualType(a.left) in [int, float]: + a.right = turnPositive(a.right) + return cleanupNegations(ast.UnaryOp(ast.USub(addedOtherOp=True), a, addedOther=True)) + elif type(a.op) in [ast.Div, ast.FloorDiv]: + if isNegative(a.left) and isNegative(a.right): + a.left = turnPositive(a.left) + a.right = turnPositive(a.right) + return a + return a + elif type(a) == ast.UnaryOp: + a.operand = cleanupNegations(a.operand) + if type(a.op) == ast.USub: + if type(a.operand) == ast.BinOp: + # -(x + y) = -x - y + if type(a.operand.op) == ast.Add: + a.operand.left = cleanupNegations(ast.UnaryOp(ast.USub(addedOtherOp=True), a.operand.left, addedOther=True)) + a.operand.op = ast.Sub(global_id=a.operand.op.global_id, num_negated=True) + transferMetaData(a, a.operand) + return a.operand + # -(x - y) = -x + y = y - x + elif type(a.operand.op) == ast.Sub: + if couldCrash(a.operand.left) and couldCrash(a.operand.right): + a.operand.left = cleanupNegations(ast.UnaryOp(ast.USub(addedOtherOp=True), a.operand.left, addedOther=True)) + a.operand.op = ast.Add(global_id=a.operand.op.global_id, num_negated=True) + transferMetaData(a, a.operand) + return a.operand + else: + (a.operand.left,a.operand.right) = (a.operand.right,a.operand.left) + transferMetaData(a, a.operand) + return a.operand + return a + # Special case for absolute value + elif type(a) == ast.Call and type(a.func) == ast.Name and a.func.id == "abs" and len(a.args) == 1: + a.args[0] = cleanupNegations(a.args[0]) + if type(a.args[0]) == ast.UnaryOp and type(a.args[0].op) == ast.USub: + a.args[0] = a.args[0].operand + elif type(a.args[0]) == ast.BinOp and type(a.args[0].op) == ast.Sub: + if not (couldCrash(a.args[0].left) and couldCrash(a.args[0].right)) and \ + compareASTs(a.args[0].left, a.args[0].right) > 0: + (a.args[0].left,a.args[0].right) = (a.args[0].right,a.args[0].left) + return a + else: + return applyToChildren(a, cleanupNegations) + +### CONDITIONAL TRANSFORMATIONS ### + +def combineConditionals(a): + """When possible, combine conditional branches""" + if not isinstance(a, ast.AST): + return a + elif type(a) == ast.If: + for i in range(len(a.body)): + a.body[i] = combineConditionals(a.body[i]) + for i in range(len(a.orelse)): + a.orelse[i] = combineConditionals(a.orelse[i]) + + # if a: if b: x can be - if a and b: x + if (len(a.orelse) == 0) and (len(a.body) == 1) and \ + (type(a.body[0]) == ast.If) and (len(a.body[0].orelse) == 0): + a.test = ast.BoolOp(ast.And(combinedConditionalOp=True), [a.test, a.body[0].test], combinedConditional=True) + a.body = a.body[0].body + # if a: x elif b: x can be - if a or b: x + elif (len(a.orelse) == 1) and \ + (type(a.orelse[0]) == ast.If) and (len(a.orelse[0].orelse) == 0): + if compareASTs(a.body, a.orelse[0].body, checkEquality=True) == 0: + a.test = ast.BoolOp(ast.Or(combinedConditionalOp=True), [a.test, a.orelse[0].test], combinedConditional=True) + a.orelse = [] + return a + else: + return applyToChildren(a, combineConditionals) + +def staticVars(l, vars): + """Determines whether the given lines change the given variables""" + # First, if one of the variables can be modified, there might be a problem + mutableVars = [] + for var in vars: + if (not (hasattr(var, "type") and (var.type in [int, float, str, bool]))): + mutableVars.append(var) + + for i in range(len(l)): + if type(l[i]) == ast.Assign: + for var in vars: + if var.id in allVariableNamesUsed(l[i].targets[0]): + return False + elif type(l[i]) == ast.AugAssign: + for var in vars: + if var.id in allVariableNamesUsed(l[i].target): + return False + elif type(l[i]) in [ast.If, ast.While]: + if not (staticVars(l[i].body, vars) and staticVars(l[i].orelse, vars)): + return False + elif type(l[i]) == ast.For: + for var in vars: + if var.id in allVariableNamesUsed(l[i].target): + return False + if not (staticVars(l[i].body, vars) and staticVars(l[i].orelse, vars)): + return False + elif type(l[i]) in [ast.FunctionDef, ast.ClassDef, ast.Try, ast.With]: + log("transformations\tstaticVars\tMissing type: " + str(type(l[i])), "bug") + + # If a mutable variable is used, we can't trust it + for var in mutableVars: + if var.id in allVariableNamesUsed(l[i]): + return False + return True + +def getIfBranches(a): + """Gets all the branches of an if statement. Will only work if each else has a single line""" + if type(a) != ast.If: + return None + + if len(a.orelse) == 0: + return [a] + elif len(a.orelse) == 1: + tmp = getIfBranches(a.orelse[0]) + if tmp == None: + return None + return [a] + tmp + else: + return None + +def allVariablesUsed(a): + """Gathers all the variable names used in the ast""" + if type(a) == list: + variables = [] + for x in a: + variables += allVariablesUsed(x) + return variables + + if not isinstance(a, ast.AST): + return [] + elif type(a) == ast.Name: + return [a] + elif type(a) == ast.Assign: + variables = allVariablesUsed(a.value) + for target in a.targets: + if type(target) == ast.Name: + pass + elif type(target) in [ast.Tuple, ast.List]: + for elt in target.elts: + if type(elt) == ast.Name: + pass + else: + variables += allVariablesUsed(elt) + else: + variables += allVariablesUsed(target) + return variables + else: + variables = [] + for child in ast.iter_child_nodes(a): + variables += allVariablesUsed(child) + return variables + +def conditionalRedundancy(a): + """When possible, remove redundant lines from conditionals and combine conditionals.""" + if type(a) == ast.Module: + for i in range(len(a.body)): + if type(a.body[i]) == ast.FunctionDef: + a.body[i] = conditionalRedundancy(a.body[i]) + return a + elif type(a) == ast.FunctionDef: + a.body = conditionalRedundancy(a.body) + return a + + if type(a) == list: + i = 0 + while i < len(a): + stmt = a[i] + if type(stmt) == ast.If: + stmt.body = conditionalRedundancy(stmt.body) + stmt.orelse = conditionalRedundancy(stmt.orelse) + + # If a line appears in both, move it outside the conditionals + if len(stmt.body) > 0 and len(stmt.orelse) > 0 and compareASTs(stmt.body[-1], stmt.orelse[-1], checkEquality=True) == 0: + nextLine = stmt.body[-1] + nextLine.second_global_id = stmt.orelse[-1].global_id + stmt.body = stmt.body[:-1] + stmt.orelse = stmt.orelse[:-1] + stmt.moved_line = nextLine.global_id + # Remove the if statement if both if and else are empty + if len(stmt.body) == 0 and len(stmt.orelse) == 0: + newLine = ast.Expr(stmt.test) + transferMetaData(stmt, newLine) + a[i:i+1] = [newLine, nextLine] + # Switch if and else if if is empty + elif len(stmt.body) == 0: + stmt.test = ast.UnaryOp(ast.Not(addedNotOp=True), stmt.test, addedNot=True) + stmt.body = stmt.orelse + stmt.orelse = [] + a[i:i+1] = [stmt, nextLine] + else: + a[i:i+1] = [stmt, nextLine] + continue # skip incrementing so that we check the conditional again + # Join adjacent, disjoint ifs + elif i+1 < len(a) and type(a[i+1]) == ast.If: + branches1 = getIfBranches(stmt) + branches2 = getIfBranches(a[i+1]) + if branches1 != None and branches2 != None: + # First, check whether any vars used in the second set of branches will be changed by the first set + testVars = sum(map(lambda b : allVariablesUsed(b.test), branches2), []) + for branch in branches1: + if not staticVars(branch.body, testVars): + break + else: + branchCombos = [(x,y) for y in branches2 for x in branches1] + for (branch1,branch2) in branchCombos: + if not areDisjoint(branch1.test, branch2.test): + break + else: + # We know the last else branch is empty- fill it with the next tree! + branches1[-1].orelse = [a[i+1]] + a.pop(i+1) + continue # check this conditional again + elif type(stmt) == ast.FunctionDef: + stmt.body = conditionalRedundancy(stmt.body) + elif type(stmt) in [ast.While, ast.For]: + stmt.body = conditionalRedundancy(stmt.body) + stmt.orelse = conditionalRedundancy(stmt.orelse) + elif type(stmt) == ast.ClassDef: + for x in range(len(stmt.body)): + if type(stmt.body[x]) == ast.FunctionDef: + stmt.body[x] = conditionalRedundancy(stmt.body[x]) + elif type(stmt) == ast.Try: + stmt.body = conditionalRedundancy(stmt.body) + for x in range(len(stmt.handlers)): + stmt.handlers[x].body = conditionalRedundancy(stmt.handlers[x].body) + stmt.orelse = conditionalRedundancy(stmt.orelse) + stmt.finalbody = conditionalRedundancy(stmt.finalbody) + elif type(stmt) == ast.With: + stmt.body = conditionalRedundancy(stmt.body) + else: + pass + i += 1 + return a + else: + log("transformations\tconditionalRedundancy\tStrange type: " + str(type(a)), "bug") + +def collapseConditionals(a): + """When possible, combine adjacent conditionals""" + if type(a) == ast.Module: + for i in range(len(a.body)): + if type(a.body[i]) == ast.FunctionDef: + a.body[i] = collapseConditionals(a.body[i]) + return a + elif type(a) == ast.FunctionDef: + a.body = collapseConditionals(a.body) + return a + + if type(a) == list: + l = a + i = len(l) - 1 + + # Go through the lines backwards, since we're collapsing conditionals upwards + while i >= 0: + stmt = l[i] + if type(stmt) == ast.If: + stmt.body = collapseConditionals(stmt.body) + stmt.orelse = collapseConditionals(stmt.orelse) + + # First, check to see if we can collapse across the if and its else + if len(l[i].body) == 1 and len(l[i].orelse) == 1: + ifLine = l[i].body[0] + elseLine = l[i].orelse[0] + + # This only works for Assign and Return + if type(ifLine) == type(elseLine) == ast.Assign and \ + compareASTs(ifLine.targets, elseLine.targets, checkEquality=True) == 0: + pass + elif type(ifLine) == ast.Return and type(elseLine) == ast.Return: + pass + else: + i -= 1 + continue # skip past this + + if type(ifLine.value) == type(elseLine.value) == ast.Name and \ + ifLine.value.id in ['True', 'False'] and elseLine.value.id in ['True', 'False']: + if ifLine.value.id == elseLine.value.id: + # If they both return the same thing, just replace the if altogether. + # But keep the test in case it crashes- we'll remove it later + ifLine.global_id = None # we're replacing the whole if statement + l[i:i+1] = [ast.Expr(l[i].test, addedOther=True, moved_line=ifLine.global_id), ifLine] + elif eventualType(l[i].test) == bool: + testVal = l[i].test + if ifLine.value.id == 'True': + newVal = testVal + else: + newVal = ast.UnaryOp(ast.Not(addedNotOp=True), testVal, negated=True, collapsedExpr=True) + + if type(ifLine) == ast.Assign: + newLine = ast.Assign(ifLine.targets, newVal) + else: + newLine = ast.Return(newVal) + transferMetaData(l[i], newLine) + l[i] = newLine + # Next, check to see if we can collapse across the if and surrounding lines + elif len(l[i].body) == 1 and len(l[i].orelse) == 0: + ifLine = l[i].body[0] + # First, check to see if the current and prior have the same return bodies + if i != 0 and type(l[i-1]) == ast.If and \ + len(l[i-1].body) == 1 and len(l[i-1].orelse) == 0 and \ + type(ifLine) == ast.Return and compareASTs(ifLine, l[i-1].body[0], checkEquality=True) == 0: + # If they do, combine their tests with an Or and get rid of this line + l[i-1].test = ast.BoolOp(ast.Or(combinedConditionalOp=True), [l[i-1].test, l[i].test], combinedConditional=True) + l[i-1].second_global_id = l[i].global_id + l.pop(i) + # Then, check whether the current and latter lines have the same returns + elif i != len(l) - 1 and type(ifLine) == type(l[i+1]) == ast.Return and \ + type(ifLine.value) == type(l[i+1].value) == ast.Name and \ + ifLine.value.id in ['True', 'False'] and l[i+1].value.id in ['True', 'False']: + if ifLine.value.id == l[i+1].value.id: + # No point in keeping the if line- just use the return + l[i] = ast.Expr(l[i].test, addedOther=True) + else: + if eventualType(l[i].test) == bool: + testVal = l[i].test + if ifLine.value.id == 'True': + newLine = ast.Return(testVal) + else: + newLine = ast.Return(ast.UnaryOp(ast.Not(addedNotOp=True), testVal, negated=True, collapsedExpr=True)) + transferMetaData(l[i], newLine) + l[i] = newLine + l.pop(i+1) # get rid of the extra return + elif type(stmt) == ast.FunctionDef: + stmt.body = collapseConditionals(stmt.body) + elif type(stmt) in [ast.While, ast.For]: + stmt.body = collapseConditionals(stmt.body) + stmt.orelse = collapseConditionals(stmt.orelse) + elif type(stmt) == ast.ClassDef: + for i in range(len(stmt.body)): + if type(stmt.body[i]) == ast.FunctionDef: + stmt.body[i] = collapseConditionals(stmt.body[i]) + elif type(stmt) == ast.Try: + stmt.body = collapseConditionals(stmt.body) + for i in range(len(stmt.handlers)): + stmt.handlers[i].body = collapseConditionals(stmt.handlers[i].body) + stmt.orelse = collapseConditionals(stmt.orelse) + stmt.finalbody = collapseConditionals(stmt.finalbody) + elif type(stmt) == ast.With: + stmt.body = collapseConditionals(stmt.body) + else: + pass + i -= 1 + return l + else: + log("transformations\tcollapseConditionals\tStrange type: " + str(type(a)), "bug") |