summaryrefslogtreecommitdiff
path: root/canonicalize
diff options
context:
space:
mode:
Diffstat (limited to 'canonicalize')
-rw-r--r--canonicalize/COPYING21
-rw-r--r--canonicalize/__init__.py95
-rw-r--r--canonicalize/astTools.py1426
-rw-r--r--canonicalize/display.py570
-rw-r--r--canonicalize/namesets.py463
-rw-r--r--canonicalize/tools.py15
-rw-r--r--canonicalize/transformations.py2775
7 files changed, 5365 insertions, 0 deletions
diff --git a/canonicalize/COPYING b/canonicalize/COPYING
new file mode 100644
index 0000000..2465b95
--- /dev/null
+++ b/canonicalize/COPYING
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2017 Kelly Rivers
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/canonicalize/__init__.py b/canonicalize/__init__.py
new file mode 100644
index 0000000..481b6e2
--- /dev/null
+++ b/canonicalize/__init__.py
@@ -0,0 +1,95 @@
+import ast
+import copy
+import uuid
+
+from .transformations import *
+from .display import printFunction
+from .astTools import deepcopy
+
+def giveIds(a, idCounter=0):
+ if isinstance(a, ast.AST):
+ if type(a) in [ast.Load, ast.Store, ast.Del, ast.AugLoad, ast.AugStore, ast.Param]:
+ return # skip these
+ a.global_id = uuid.uuid1()
+ idCounter += 1
+ for field in a._fields:
+ child = getattr(a, field)
+ if type(child) == list:
+ for i in range(len(child)):
+ # Get rid of aliased items
+ if hasattr(child[i], "global_id"):
+ child[i] = copy.deepcopy(child[i])
+ giveIds(child[i], idCounter)
+ else:
+ # Get rid of aliased items
+ if hasattr(child, "global_id"):
+ child = copy.deepcopy(child)
+ setattr(a, field, child)
+ giveIds(child, idCounter)
+
+def getCanonicalForm(tree, problem_name=None, given_names=None, argTypes=None, imports=None):
+ if imports == None:
+ imports = []
+
+ transformations = [
+ constantFolding,
+
+ cleanupEquals,
+ cleanupBoolOps,
+ cleanupRanges,
+ cleanupSlices,
+ cleanupTypes,
+ cleanupNegations,
+
+ conditionalRedundancy,
+ combineConditionals,
+ collapseConditionals,
+
+ copyPropagation,
+
+ deMorganize,
+ orderCommutativeOperations,
+
+ deadCodeRemoval
+ ]
+
+ varmap = {}
+ tree = propogateMetadata(tree, {'foo': ['int', 'int']}, varmap, [0])
+
+ tree = simplify(tree)
+ tree = anonymizeNames(tree, given_names, imports)
+
+ oldTree = None
+ while compareASTs(oldTree, tree, checkEquality=True) != 0:
+ oldTree = deepcopy(tree)
+ helperFolding(tree, problem_name, imports)
+ for t in transformations:
+ tree = t(tree)
+ return tree
+
+def canonicalize(code):
+ tree = ast.parse(code)
+ giveIds(tree)
+
+ new_tree = getCanonicalForm(tree, problem_name='foo', given_names=['foo'])
+ new_code = printFunction(new_tree)
+
+ return new_code
+
+
+if __name__ == '__main__':
+ code = '''
+def isPositive(x):
+ return x > 0
+
+def bar(x, y):
+ return x - y
+
+def foo(x, y):
+ if isPositive(y):
+ return x*y
+ if isPositive(x):
+ return x
+ return -x
+'''
+ print(canonicalize(code))
diff --git a/canonicalize/astTools.py b/canonicalize/astTools.py
new file mode 100644
index 0000000..b84b2c5
--- /dev/null
+++ b/canonicalize/astTools.py
@@ -0,0 +1,1426 @@
+import ast, copy, pickle
+from .tools import log
+from .namesets import *
+from .display import printFunction
+
+def cmp(a, b):
+ if type(a) == type(b) == complex:
+ return (a.real > b.real) - (a.real < b.real)
+ return (a > b) - (a < b)
+
+def builtInName(id):
+ """Determines whether the given id is a built-in name"""
+ if id in builtInNames + exceptionClasses:
+ return True
+ elif id in builtInFunctions.keys():
+ return True
+ elif id in list(allPythonFunctions.keys()) + supportedLibraries:
+ return False
+
+def importedName(id, importList):
+ for imp in importList:
+ if type(imp) == ast.Import:
+ for name in imp.names:
+ if hasattr(name, "asname") and name.asname != None:
+ if id == name.asname:
+ return True
+ else:
+ if id == name.name:
+ return True
+ elif type(imp) == ast.ImportFrom:
+ if hasattr(imp, "module"):
+ if imp.module in supportedLibraries:
+ libMap = libraryMap[imp.module]
+ for name in imp.names:
+ if hasattr(name, "asname") and name.asname != None:
+ if id == name.asname:
+ return True
+ else:
+ if id == name.name:
+ return True
+ else:
+ log("astTools\timportedName\tUnsupported library: " + printFunction(imp), "bug")
+
+ else:
+ log("astTools\timportedName\tWhy no module? " + printFunction(imp), "bug")
+ return False
+
+def isConstant(x):
+ """Determine whether the provided AST is a constant"""
+ return (type(x) in [ast.Num, ast.Str, ast.Bytes, ast.NameConstant])
+
+def isIterableType(t):
+ """Can the given type be iterated over"""
+ return t in [ dict, list, set, str, bytes, tuple ]
+
+def isStatement(a):
+ """Determine whether the given node is a statement (vs an expression)"""
+ return type(a) in [ ast.Module, ast.Interactive, ast.Expression, ast.Suite,
+ ast.FunctionDef, ast.ClassDef, ast.Return, ast.Delete,
+ ast.Assign, ast.AugAssign, ast.For, ast.While,
+ ast.If, ast.With, ast.Raise, ast.Try,
+ ast.Assert, ast.Import, ast.ImportFrom, ast.Global,
+ ast.Expr, ast.Pass, ast.Break, ast.Continue ]
+
+def codeLength(a):
+ """Returns the number of characters in this AST"""
+ if type(a) == list:
+ return sum([codeLength(x) for x in a])
+ return len(printFunction(a))
+
+def applyToChildren(a, f):
+ """Apply the given function to all the children of a"""
+ if a == None:
+ return a
+ for field in a._fields:
+ child = getattr(a, field)
+ if type(child) == list:
+ i = 0
+ while i < len(child):
+ temp = f(child[i])
+ if type(temp) == list:
+ child = child[:i] + temp + child[i+1:]
+ i += len(temp)
+ else:
+ child[i] = temp
+ i += 1
+ else:
+ child = f(child)
+ setattr(a, field, child)
+ return a
+
+def occursIn(sub, super):
+ """Does the first AST occur as a subtree of the second?"""
+ superStatementTypes = [ ast.Module, ast.Interactive, ast.Suite,
+ ast.FunctionDef, ast.ClassDef, ast.For,
+ ast.While, ast.If, ast.With, ast.Try,
+ ast.ExceptHandler ]
+ if (not isinstance(super, ast.AST)):
+ return False
+ if type(sub) == type(super) and compareASTs(sub, super, checkEquality=True) == 0:
+ return True
+ # we know that a statement can never occur in an expression
+ # (or in a non-statement-holding statement), so cut the search off now to save time.
+ if isStatement(sub) and type(super) not in superStatementTypes:
+ return False
+ for child in ast.iter_child_nodes(super):
+ if occursIn(sub, child):
+ return True
+ return False
+
+def countOccurances(a, value):
+ """How many instances of this node type appear in the AST?"""
+ if type(a) == list:
+ return sum([countOccurances(x, value) for x in a])
+ if not isinstance(a, ast.AST):
+ return 0
+
+ count = 0
+ for node in ast.walk(a):
+ if isinstance(node, value):
+ count += 1
+ return count
+
+def countVariables(a, id):
+ """Count the number of times the given variable appears in the AST"""
+ if type(a) == list:
+ return sum([countVariables(x, id) for x in a])
+ if not isinstance(a, ast.AST):
+ return 0
+
+ count = 0
+ for node in ast.walk(a):
+ if type(node) == ast.Name and node.id == id:
+ count += 1
+ return count
+
+def gatherAllNames(a, keep_orig=True):
+ """Gather all names in the tree (variable or otherwise).
+ Names are returned along with their original names
+ (which are used in variable mapping)"""
+ if type(a) == list:
+ allIds = set()
+ for line in a:
+ allIds |= gatherAllNames(line)
+ return allIds
+ if not isinstance(a, ast.AST):
+ return set()
+
+ allIds = set()
+ for node in ast.walk(a):
+ if type(node) == ast.Name:
+ origName = node.originalId if (keep_orig and hasattr(node, "originalId")) else None
+ allIds |= set([(node.id, origName)])
+ return allIds
+
+def gatherAllVariables(a, keep_orig=True):
+ """Gather all variable names in the tree. Names are returned along
+ with their original names (which are used in variable mapping)"""
+ if type(a) == list:
+ allIds = set()
+ for line in a:
+ allIds |= gatherAllVariables(line)
+ return allIds
+ if not isinstance(a, ast.AST):
+ return set()
+
+ allIds = set()
+ for node in ast.walk(a):
+ if type(node) == ast.Name or type(node) == ast.arg:
+ currentId = node.id if type(node) == ast.Name else node.arg
+ # Only take variables
+ if not (builtInName(currentId) or hasattr(node, "dontChangeName")):
+ origName = node.originalId if (keep_orig and hasattr(node, "originalId")) else None
+ if (currentId, origName) not in allIds:
+ for pair in allIds:
+ if pair[0] == currentId:
+ if pair[1] == None:
+ allIds -= {pair}
+ allIds |= {(currentId, origName)}
+ elif origName == None:
+ pass
+ else:
+ log("astTools\tgatherAllVariables\tConflicting originalIds? " + pair[0] + " : " + pair[1] + " , " + origName + "\n" + printFunction(a), "bug")
+ break
+ else:
+ allIds |= {(currentId, origName)}
+ return allIds
+
+def gatherAllParameters(a, keep_orig=True):
+ """Gather all parameters in the tree. Names are returned along
+ with their original names (which are used in variable mapping)"""
+ if type(a) == list:
+ allIds = set()
+ for line in a:
+ allIds |= gatherAllVariables(line)
+ return allIds
+ if not isinstance(a, ast.AST):
+ return set()
+
+ allIds = set()
+ for node in ast.walk(a):
+ if type(node) == ast.arg:
+ origName = node.originalId if (keep_orig and hasattr(node, "originalId")) else None
+ allIds |= set([(node.arg, origName)])
+ return allIds
+
+def gatherAllHelpers(a, restricted_names):
+ """Gather all helper function names in the tree that have been anonymized"""
+ if type(a) != ast.Module:
+ return set()
+ helpers = set()
+ for item in a.body:
+ if type(item) == ast.FunctionDef:
+ if not hasattr(item, "dontChangeName") and item.name not in restricted_names: # this got anonymized
+ origName = item.originalId if hasattr(item, "originalId") else None
+ helpers |= set([(item.name, origName)])
+ return helpers
+
+def gatherAllFunctionNames(a):
+ """Gather all helper function names in the tree that have been anonymized"""
+ if type(a) != ast.Module:
+ return set()
+ helpers = set()
+ for item in a.body:
+ if type(item) == ast.FunctionDef:
+ origName = item.originalId if hasattr(item, "originalId") else None
+ helpers |= set([(item.name, origName)])
+ return helpers
+
+def gatherAssignedVars(targets):
+ """Take a list of assigned variables and extract the names/subscripts/attributes"""
+ if type(targets) != list:
+ targets = [targets]
+ newTargets = []
+ for target in targets:
+ if type(target) in [ast.Tuple, ast.List]:
+ newTargets += gatherAssignedVars(target.elts)
+ elif type(target) in [ast.Name, ast.Subscript, ast.Attribute]:
+ newTargets.append(target)
+ else:
+ log("astTools\tgatherAssignedVars\tWeird Assign Type: " + str(type(target)),"bug")
+ return newTargets
+
+def gatherAssignedVarIds(targets):
+ """Just get the ids of Names"""
+ vars = gatherAssignedVars(targets)
+ return [y.id for y in filter(lambda x : type(x) == ast.Name, vars)]
+
+def getAllAssignedVarIds(a):
+ if not isinstance(a, ast.AST):
+ return []
+ ids = []
+ for child in ast.walk(a):
+ if type(child) == ast.Assign:
+ ids += gatherAssignedVarIds(child.targets)
+ elif type(child) == ast.AugAssign:
+ ids += gatherAssignedVarIds([child.target])
+ elif type(child) == ast.For:
+ ids += gatherAssignedVarIds([child.target])
+ return ids
+
+def getAllAssignedVars(a):
+ if not isinstance(a, ast.AST):
+ return []
+ vars = []
+ for child in ast.walk(a):
+ if type(child) == ast.Assign:
+ vars += gatherAssignedVars(child.targets)
+ elif type(child) == ast.AugAssign:
+ vars += gatherAssignedVars([child.target])
+ elif type(child) == ast.For:
+ vars += gatherAssignedVars([child.target])
+ return vars
+
+def getAllFunctions(a):
+ """Collects all the functions in the given module"""
+ if not isinstance(a, ast.AST):
+ return []
+ functions = []
+ for child in ast.walk(a):
+ if type(child) == ast.FunctionDef:
+ functions.append(child.name)
+ return functions
+
+def getAllImports(a):
+ """Gather all imported module names"""
+ if not isinstance(a, ast.AST):
+ return []
+ imports = []
+ for child in ast.walk(a):
+ if type(child) == ast.Import:
+ for alias in child.names:
+ if alias.name in supportedLibraries:
+ imports.append(alias.asname if alias.asname != None else alias.name)
+ else:
+ log("astTools\tgetAllImports\tUnknown library: " + alias.name, "bug")
+ elif type(child) == ast.ImportFrom:
+ if child.module in supportedLibraries:
+ for alias in child.names: # these are all functions
+ if alias.name in libraryMap[child.module]:
+ imports.append(alias.asname if alias.asname != None else alias.name)
+ else:
+ log("astTools\tgetAllImports\tUnknown import from name: " + \
+ child.module + "," + alias.name, "bug")
+ else:
+ log("astTools\tgetAllImports\tUnknown library: " + child.module, "bug")
+ return imports
+
+def getAllImportStatements(a):
+ if not isinstance(a, ast.AST):
+ return []
+ imports = []
+ for child in ast.walk(a):
+ if type(child) == ast.Import:
+ imports.append(child)
+ elif type(child) == ast.ImportFrom:
+ imports.append(child)
+ return imports
+
+def getAllGlobalNames(a):
+ # Finds all names that can be accessed at the global level in the AST
+ if type(a) != ast.Module:
+ return []
+ names = []
+ for obj in a.body:
+ if type(obj) in [ast.FunctionDef, ast.ClassDef]:
+ names.append(obj.name)
+ elif type(obj) in [ast.Assign, ast.AugAssign]:
+ targets = obj.targets if type(obj) == ast.Assign else [obj.target]
+ for target in obj.targets:
+ if type(target) == ast.Name:
+ names.append(target.id)
+ elif type(target) in [ast.Tuple, ast.List]:
+ for elt in target.elts:
+ if type(elt) == ast.Name:
+ names.append(elt.id)
+ elif type(obj) in [ast.Import, ast.ImportFrom]:
+ for module in obj.names:
+ names.append(module.asname if module.asname != None else module.name)
+ return names
+
+def doBinaryOp(op, l, r):
+ """Perform the given AST binary operation on the values"""
+ top = type(op)
+ if top == ast.Add:
+ return l + r
+ elif top == ast.Sub:
+ return l - r
+ elif top == ast.Mult:
+ return l * r
+ elif top == ast.Div:
+ # Don't bother if this will be a really long float- it won't work properly!
+ # Also, in Python 3 this is floating division, so perform it accordingly.
+ val = 1.0 * l / r
+ if (val * 1e10 % 1.0) != 0:
+ raise Exception("Repeating Float")
+ return val
+ elif top == ast.Mod:
+ return l % r
+ elif top == ast.Pow:
+ return l ** r
+ elif top == ast.LShift:
+ return l << r
+ elif top == ast.RShift:
+ return l >> r
+ elif top == ast.BitOr:
+ return l | r
+ elif top == ast.BitXor:
+ return l ^ r
+ elif top == ast.BitAnd:
+ return l & r
+ elif top == ast.FloorDiv:
+ return l // r
+
+def doUnaryOp(op, val):
+ """Perform the given AST unary operation on the value"""
+ top = type(op)
+ if top == ast.Invert:
+ return ~ val
+ elif top == ast.Not:
+ return not val
+ elif top == ast.UAdd:
+ return val
+ elif top == ast.USub:
+ return -val
+
+def doCompare(op, left, right):
+ """Perform the given AST comparison on the values"""
+ top = type(op)
+ if top == ast.Eq:
+ return left == right
+ elif top == ast.NotEq:
+ return left != right
+ elif top == ast.Lt:
+ return left < right
+ elif top == ast.LtE:
+ return left <= right
+ elif top == ast.Gt:
+ return left > right
+ elif top == ast.GtE:
+ return left >= right
+ elif top == ast.Is:
+ return left is right
+ elif top == ast.IsNot:
+ return left is not right
+ elif top == ast.In:
+ return left in right
+ elif top == ast.NotIn:
+ return left not in right
+
+def num_negate(op):
+ top = type(op)
+ neg = not op.num_negated if hasattr(op, "num_negated") else True
+ if top == ast.Add:
+ newOp = ast.Sub()
+ elif top == ast.Sub:
+ newOp = ast.Add()
+ elif top in [ast.Mult, ast.Div, ast.Mod, ast.Pow, ast.LShift,
+ ast.RShift, ast.BitOr, ast.BitXor, ast.BitAnd, ast.FloorDiv]:
+ return None # can't negate this
+ elif top in [ast.Num, ast.Name]:
+ # this is a normal value, so put a - in front of it
+ newOp = ast.UnaryOp(ast.USub(addedNeg=True), op)
+ else:
+ log("astTools\tnum_negate\tUnusual type: " + str(top), "bug")
+ transferMetaData(op, newOp)
+ newOp.num_negated = neg
+ return newOp
+
+def negate(op):
+ """Return the negation of the provided operator"""
+ if op == None:
+ return None
+ top = type(op)
+ neg = not op.negated if hasattr(op, "negated") else True
+ if top == ast.And:
+ newOp = ast.Or()
+ elif top == ast.Or:
+ newOp = ast.And()
+ elif top == ast.Eq:
+ newOp = ast.NotEq()
+ elif top == ast.NotEq:
+ newOp = ast.Eq()
+ elif top == ast.Lt:
+ newOp = ast.GtE()
+ elif top == ast.GtE:
+ newOp = ast.Lt()
+ elif top == ast.Gt:
+ newOp = ast.LtE()
+ elif top == ast.LtE:
+ newOp = ast.Gt()
+ elif top == ast.Is:
+ newOp = ast.IsNot()
+ elif top == ast.IsNot:
+ newOp = ast.Is()
+ elif top == ast.In:
+ newOp = ast.NotIn()
+ elif top == ast.NotIn:
+ newOp = ast.In()
+ elif top == ast.NameConstant and op.value in [True, False]:
+ op.value = not op.value
+ op.negated = neg
+ return op
+ elif top == ast.Compare:
+ if len(op.ops) == 1:
+ op.ops[0] = negate(op.ops[0])
+ op.negated = neg
+ return op
+ else:
+ values = []
+ allOperands = [op.left] + op.comparators
+ for i in range(len(op.ops)):
+ values.append(ast.Compare(allOperands[i], [negate(op.ops[i])],
+ [allOperands[i+1]], multiCompPart=True))
+ newOp = ast.BoolOp(ast.Or(multiCompOp=True), values, multiComp=True)
+ elif top == ast.UnaryOp and type(op.op) == ast.Not and \
+ eventualType(op.operand) == bool: # this can mess things up type-wise
+ return op.operand
+ else:
+ # this is a normal value, so put a not around it
+ newOp = ast.UnaryOp(ast.Not(addedNot=True), op)
+ transferMetaData(op, newOp)
+ newOp.negated = neg
+ return newOp
+
+def couldCrash(a):
+ """Determines whether the given AST could possibly crash"""
+ typeCrashes = True # toggle based on whether you care about potential crashes caused by types
+ if not isinstance(a, ast.AST):
+ return False
+
+ if type(a) == ast.Try:
+ for handler in a.handlers:
+ for child in ast.iter_child_nodes(handler):
+ if couldCrash(child):
+ return True
+ for other in a.orelse:
+ for child in ast.iter_child_nodes(other):
+ if couldCrash(child):
+ return True
+ for line in a.finalbody:
+ for child in ast.iter_child_nodes(line):
+ if couldCrash(child):
+ return True
+ return False
+
+ # If any child could crash, this can crash
+ for child in ast.iter_child_nodes(a):
+ if couldCrash(child):
+ return True
+
+ if type(a) == ast.FunctionDef:
+ argNames = []
+ for arg in a.args.args:
+ if arg.arg in argNames: # conflicting arg names!
+ return True
+ else:
+ argNames.append(arg.arg)
+ if type(a) == ast.Assign:
+ for target in a.targets:
+ if type(target) != ast.Name: # can crash if it's a tuple and we can't unpack the value
+ return True
+ elif type(a) in [ast.For, ast.comprehension]: # check if the target or iter will break things
+ if type(a.target) not in [ast.Name, ast.Tuple, ast.List]:
+ return True
+ elif type(a.target) in [ast.Tuple, ast.List]:
+ for x in a.target.elts:
+ if type(x) != ast.Name:
+ return True
+ elif isIterableType(eventualType(a.iter)):
+ return True
+ elif type(a) == ast.Import:
+ for name in a.names:
+ if name not in supportedLibraries:
+ return True
+ elif type(a) == ast.ImportFrom:
+ if a.module not in supportedLibraries:
+ return True
+ if a.level != None:
+ return True
+ for name in a.names:
+ if name not in libraryMap[a.module]:
+ return True
+ elif type(a) == ast.BinOp:
+ l = eventualType(a.left)
+ r = eventualType(a.right)
+ if type(a.op) == ast.Add:
+ if not ((l == r == str) or (l in [int, float] and r in [int, float])):
+ return typeCrashes
+ elif type(a.op) == ast.Mult:
+ if not ((l == str and r == int) or (l == int and r == str) or \
+ (l in [int, float] and r in [int, float])):
+ return typeCrashes
+ elif type(a.op) in [ast.Sub, ast.LShift, ast.RShift, ast.BitOr, ast.BitXor, ast.BitAnd]:
+ if not (l in [int, float] and r in [int, float]):
+ return typeCrashes
+ elif type(a.op) == ast.Pow:
+ if not ((l in [int, float] and r == int) or \
+ (l in [int, float] and type(a.right) == ast.Num and \
+ type(a.right.n) != complex and \
+ (a.right.n >= 1 or a.right.n == 0 or a.right.n <= -1))):
+ return True
+ else: # ast.Div, ast.FloorDiv, ast.Mod
+ if type(a.right) == ast.Num and a.right.n != 0:
+ if l not in [int, float]:
+ return typeCrashes
+ else:
+ return True # Divide by zero error
+ elif type(a) == ast.UnaryOp:
+ if type(a.op) in [ast.UAdd, ast.USub]:
+ if eventualType(a.operand) not in [int, float]:
+ return typeCrashes
+ elif type(a.op) == ast.Invert:
+ if eventualType(a.operand) != int:
+ return typeCrashes
+ elif type(a) == ast.Compare:
+ if len(a.ops) != len(a.comparators):
+ return True
+ elif type(a.ops[0]) in [ast.In, ast.NotIn]:
+ if not isIterableType(eventualType(a.comparators[0])):
+ return True
+ elif eventualType(a.comparators[0]) in [str, bytes] and eventualType(a.left) not in [str, bytes]:
+ return True
+ elif type(a.ops[0]) in [ast.Lt, ast.LtE, ast.Gt, ast.GtE]:
+ # In Python3, you can't compare different types. BOOOOOO!!
+ firstType = eventualType(a.left)
+ if firstType == None:
+ return True
+ for comp in a.comparators:
+ if eventualType(comp) != firstType:
+ return True
+ elif type(a) == ast.Call:
+ env = [] # TODO: what if the environments aren't imported?
+ # First, gather up the needed variables
+ if type(a.func) == ast.Name:
+ funName = a.func.id
+ if funName not in builtInSafeFunctions:
+ return True
+ funDict = builtInFunctions
+ elif type(a.func) == ast.Attribute:
+ if type(a.func.value) == a.Name and \
+ (not hasattr(a.func.value, "varID")) and \
+ a.func.value.id in supportedLibraries:
+ funName = a.func.attr
+ if funName not in safeLibraryMap(a.func.value.id):
+ return True
+ funDict = libraryMap[a.func.value.id]
+ elif eventualType(a.func.value) == str:
+ funName = a.func.attr
+ if funName not in safeStringFunctions:
+ return True
+ funDict = builtInStringFunctions
+ else: # list and dict are definitely crashable
+ return True
+ else:
+ return True
+
+ if funName in ["max", "min"]:
+ return False # Special functions that have infinite args
+
+ # First, load up the arg types
+ argTypes = []
+ for i in range(len(a.args)):
+ eventual = eventualType(a.args[i])
+ if (eventual == None and typeCrashes):
+ return True
+ argTypes.append(eventual)
+
+ if funDict[funName] != None:
+ for argSet in funDict[funName]: # the given possibilities of arg types
+ if len(argSet) != len(argTypes):
+ continue
+ if not typeCrashes: # If we don't care about types, stop now
+ return False
+
+ for i in range(len(argSet)):
+ if not (argSet[i] == argTypes[i] or issubclass(argTypes[i], argSet[i])):
+ break
+ else: # if all types matched
+ return False
+ return True # Didn't fit any of the options
+ elif type(a) == ast.Subscript: # can only get an index from a string or list
+ return eventualType(a.value) not in [str, list, tuple]
+ elif type(a) == ast.Name:
+ # If it's an undefined variable, it might crash
+ if hasattr(a, "randomVar"):
+ return True
+ elif type(a) == ast.Slice:
+ if a.lower != None and eventualType(a.lower) != int:
+ return True
+ if a.upper != None and eventualType(a.upper) != int:
+ return True
+ if a.step != None and eventualType(a.step) != int:
+ return True
+ elif type(a) in [ast.Raise, ast.Assert, ast.Pass, ast.Break, \
+ ast.Continue, ast.Yield, ast.Attribute, ast.ExtSlice, ast.Index, \
+ ast.Starred]:
+ # All of these cases can definitely crash.
+ return True
+ return False
+
+def eventualType(a):
+ """Get the type the expression will eventually be, if possible
+ The expression might also crash! But we don't care about that here,
+ we'll deal with it elsewhere.
+ Returning 'None' means that we cannot say at the moment"""
+ if type(a) in builtInTypes:
+ return type(a)
+ if not isinstance(a, ast.AST):
+ return None
+
+ elif type(a) == ast.BoolOp:
+ # In Python, it's the type of all the values in it
+ # this may work differently in other languages
+ t = eventualType(a.values[0])
+ for i in range(1, len(a.values)):
+ if eventualType(a.values[i]) != t:
+ return None
+ return t
+ elif type(a) == ast.BinOp:
+ l = eventualType(a.left)
+ r = eventualType(a.right)
+ # It is possible to add/multiply sequences
+ if type(a.op) in [ast.Add, ast.Mult]:
+ if isIterableType(l):
+ return l
+ elif isIterableType(r):
+ return r
+ elif l == float or r == float:
+ return float
+ elif l == int and r == int:
+ return int
+ return None
+ elif type(a.op) == ast.Div:
+ return float # always a float now
+ # For others, check if we know whether it's a float or an int
+ elif type(a.op) in [ast.FloorDiv, ast.LShift, ast.RShift, ast.BitOr,
+ ast.BitAnd, ast.BitXor]:
+ return int
+ elif float in [l, r]:
+ return float
+ elif l == int and r == int:
+ return int
+ else:
+ return None # Otherwise, it could be a float- we don't know
+ elif type(a) == ast.UnaryOp:
+ if type(a.op) == ast.Invert:
+ return int
+ elif type(a.op) in [ast.UAdd, ast.USub]:
+ return eventualType(a.operand)
+ else: # Not op
+ return bool
+ elif type(a) == ast.Lambda:
+ return function
+ elif type(a) == ast.IfExp:
+ l = eventualType(a.body)
+ r = eventualType(a.orelse)
+ if l == r:
+ return l
+ else:
+ return None
+ elif type(a) in [ast.Dict, ast.DictComp]:
+ return dict
+ elif type(a) in [ast.Set, ast.SetComp]:
+ return set
+ elif type(a) in [ast.List, ast.ListComp]:
+ return list
+ elif type(a) == ast.GeneratorExp:
+ return None # can't represent a generator
+ elif type(a) == ast.Yield:
+ return None # we don't know
+ elif type(a) == ast.Compare:
+ return bool
+ elif type(a) == ast.Call:
+ # Go through our different sets of known functions to see if we know the type
+ argTypes = [eventualType(x) for x in a.args]
+ if type(a.func) == ast.Name:
+ funDict = builtInFunctions
+ funName = a.func.id
+ elif type(a.func) == ast.Attribute:
+ # TODO: get a better solution than this
+ funName = a.func.attr
+ if type(a.func.value) == ast.Name and \
+ (not hasattr(a.func.value, "varID")) and \
+ a.func.value.id in supportedLibraries:
+ funDict = libraryDictMap[a.func.value.id]
+ if a.func.value.id in ["string", "str", "list", "dict"] and len(argTypes) > 0:
+ argTypes.pop(0) # get rid of the first string arg
+ elif eventualType(a.func.value) == str:
+ funDict = builtInStringFunctions
+ elif eventualType(a.func.value) == list:
+ funDict = builtInListFunctions
+ elif eventualType(a.func.value) == dict:
+ funDict = builtInDictFunctions
+ else:
+ return None
+ else:
+ return None
+
+ if funName in ["max", "min"]:
+ # If all args are the same type, that's our type
+ uniqueTypes = set(argTypes)
+ if len(uniqueTypes) == 1:
+ return uniqueTypes.pop()
+ return None
+
+ if funName in funDict and funDict[funName] != None:
+ possibleTypes = []
+ for argSet in funDict[funName]:
+ if len(argSet) == len(argTypes):
+ # All types must match!
+ for i in range(len(argSet)):
+ if argSet[i] == None or argTypes[i] == None: # We don't know, but that's okay
+ continue
+ if not (argSet[i] == argTypes[i] or (issubclass(argTypes[i], argSet[i]))):
+ break
+ else:
+ possibleTypes.append(funDict[funName][argSet])
+ possibleTypes = set(possibleTypes)
+ if len(possibleTypes) == 1: # If there's only one possibility, that's our type!
+ return possibleTypes.pop()
+ return None
+ elif type(a) in [ast.Str, ast.Bytes]:
+ if containsTokenStepString(a):
+ return None
+ return str
+ elif type(a) == ast.Num:
+ return type(a.n)
+ elif type(a) == ast.Attribute:
+ return None # we have no way of knowing
+ elif type(a) == ast.Subscript:
+ # We're slicing the object, so the type will stay the same
+ t = eventualType(a.value)
+ if t == None:
+ return None
+ elif t == str:
+ return str # indexing a string
+ elif t in [list, tuple]:
+ if type(a.slice) == ast.Slice:
+ return t
+ # Otherwise, we need the types of the elements
+ if type(a.value) in [ast.List, ast.Tuple]:
+ if len(a.value.elts) == 0:
+ return None # We don't know
+ else:
+ eltType = eventualType(a.value.elts[0])
+ for elt in a.value.elts:
+ if eventualType(elt) != eltType:
+ return None # Disagreement!
+ return eltType
+ elif t in [dict, int]:
+ return None
+ else:
+ log("astTools\teventualType\tUnknown type in subscript: " + str(t), "bug")
+ return None # We can't know for now...
+ elif type(a) == ast.NameConstant:
+ if a.value == True or a.value == False:
+ return bool
+ elif a.value == None:
+ return type(None)
+ return None
+ elif type(a) == ast.Name:
+ if hasattr(a, "type"): # If it's a variable we categorized
+ return a.type
+ return None
+ elif type(a) == ast.Tuple:
+ return tuple
+ elif type(a) == ast.Starred:
+ return None # too complicated
+ else:
+ log("astTools\teventualType\tUnimplemented type " + str(type(a)), "bug")
+ return None
+
+def depthOfAST(a):
+ """Determine the depth of the AST"""
+ if not isinstance(a, ast.AST):
+ return 0
+ m = 0
+ for child in ast.iter_child_nodes(a):
+ tmp = depthOfAST(child)
+ if tmp > m:
+ m = tmp
+ return m + 1
+
+def compareASTs(a, b, checkEquality=False):
+ """A comparison function for ASTs"""
+ # None before others
+ if a == b == None:
+ return 0
+ elif a == None or b == None:
+ return -1 if a == None else 1
+
+ if type(a) == type(b) == list:
+ if len(a) != len(b):
+ return len(a) - len(b)
+ for i in range(len(a)):
+ r = compareASTs(a[i], b[i], checkEquality=checkEquality)
+ if r != 0:
+ return r
+ return 0
+
+ # AST before primitive
+ if (not isinstance(a, ast.AST)) and (not isinstance(b, ast.AST)):
+ if type(a) != type(b):
+ builtins = [bool, int, float, str, bytes, complex]
+ if type(a) not in builtins or type(b) not in builtins:
+ log("MISSING BUILT-IN TYPE: " + str(type(a)) + "," + str(type(b)), "bug")
+ return builtins.index(type(a)) - builtins.index(type(b))
+ return cmp(a, b)
+ elif (not isinstance(a, ast.AST)) or (not isinstance(b, ast.AST)):
+ return -1 if isinstance(a, ast.AST) else 1
+
+ # Order by differing types
+ if type(a) != type(b):
+ # Here is a brief ordering of types that we care about
+ blehTypes = [ ast.Load, ast.Store, ast.Del, ast.AugLoad, ast.AugStore, ast.Param ]
+ if type(a) in blehTypes and type(b) in blehTypes:
+ return 0
+ elif type(a) in blehTypes or type(b) in blehTypes:
+ return -1 if type(a) in blehTypes else 1
+
+ types = [ ast.Module, ast.Interactive, ast.Expression, ast.Suite,
+
+ ast.Break, ast.Continue, ast.Pass, ast.Global,
+ ast.Expr, ast.Assign, ast.AugAssign, ast.Return,
+ ast.Assert, ast.Delete, ast.If, ast.For, ast.While,
+ ast.With, ast.Import, ast.ImportFrom, ast.Raise,
+ ast.Try, ast.FunctionDef,
+ ast.ClassDef,
+
+ ast.BinOp, ast.BoolOp, ast.Compare, ast.UnaryOp,
+ ast.DictComp, ast.ListComp, ast.SetComp, ast.GeneratorExp,
+ ast.Yield, ast.Lambda, ast.IfExp, ast.Call, ast.Subscript,
+ ast.Attribute, ast.Dict, ast.List, ast.Tuple,
+ ast.Set, ast.Name, ast.Str, ast.Bytes, ast.Num,
+ ast.NameConstant, ast.Starred,
+
+ ast.Ellipsis, ast.Index, ast.Slice, ast.ExtSlice,
+
+ ast.And, ast.Or, ast.Add, ast.Sub, ast.Mult, ast.Div,
+ ast.Mod, ast.Pow, ast.LShift, ast.RShift, ast.BitOr,
+ ast.BitXor, ast.BitAnd, ast.FloorDiv, ast.Invert, ast.Not,
+ ast.UAdd, ast.USub, ast.Eq, ast.NotEq, ast.Lt, ast.LtE,
+ ast.Gt, ast.GtE, ast.Is, ast.IsNot, ast.In, ast.NotIn,
+
+ ast.alias, ast.keyword, ast.arguments, ast.arg, ast.comprehension,
+ ast.ExceptHandler, ast.withitem
+ ]
+ if (type(a) not in types) or (type(b) not in types):
+ log("astTools\tcompareASTs\tmissing type:" + str(type(a)) + "," + str(type(b)), "bug")
+ return 0
+ return types.index(type(a)) - types.index(type(b))
+
+ # Then, more complex expressions- but don't bother with this if we're just checking equality
+ if not checkEquality:
+ ad = depthOfAST(a)
+ bd = depthOfAST(b)
+ if ad != bd:
+ return bd - ad
+
+ # NameConstants are special
+ if type(a) == ast.NameConstant:
+ if a.value == None or b.value == None:
+ return 1 if a.value != None else (0 if b.value == None else -1) # short and works
+
+ if a.value in [True, False] or b.value in [True, False]:
+ return 1 if a.value not in [True, False] else (cmp(a.value, b.value) if b.value in [True, False] else -1)
+
+ if type(a) == ast.Name:
+ return cmp(a.id, b.id)
+
+ # Operations and attributes are all ok
+ elif type(a) in [ ast.And, ast.Or, ast.Add, ast.Sub, ast.Mult, ast.Div,
+ ast.Mod, ast.Pow, ast.LShift, ast.RShift, ast.BitOr,
+ ast.BitXor, ast.BitAnd, ast.FloorDiv, ast.Invert,
+ ast.Not, ast.UAdd, ast.USub, ast.Eq, ast.NotEq, ast.Lt,
+ ast.LtE, ast.Gt, ast.GtE, ast.Is, ast.IsNot, ast.In,
+ ast.NotIn, ast.Load, ast.Store, ast.Del, ast.AugLoad,
+ ast.AugStore, ast.Param, ast.Ellipsis, ast.Pass,
+ ast.Break, ast.Continue
+ ]:
+ return 0
+
+ # Now compare based on the attributes in the identical types
+ for attr in a._fields:
+ r = compareASTs(getattr(a, attr), getattr(b, attr), checkEquality=checkEquality)
+ if r != 0:
+ return r
+ # If all attributes are identical, they're equal
+ return 0
+
+def deepcopyList(l):
+ """Deepcopy of a list"""
+ if l == None:
+ return None
+ if isinstance(l, ast.AST):
+ return deepcopy(l)
+ if type(l) != list:
+ log("astTools\tdeepcopyList\tNot a list: " + str(type(l)), "bug")
+ return copy.deepcopy(l)
+
+ newList = []
+ for line in l:
+ newList.append(deepcopy(line))
+ return newList
+
+def deepcopy(a):
+ """Let's try to keep this as quick as possible"""
+ if a == None:
+ return None
+ if type(a) == list:
+ return deepcopyList(a)
+ elif type(a) in [int, float, str, bool]:
+ return a
+ if not isinstance(a, ast.AST):
+ log("astTools\tdeepcopy\tNot an AST: " + str(type(a)), "bug")
+ return copy.deepcopy(a)
+
+ g = a.global_id if hasattr(a, "global_id") else None
+ cp = None
+ # Objects without lineno, col_offset
+ if type(a) in [ ast.And, ast.Or, ast.Add, ast.Sub, ast.Mult, ast.Div,
+ ast.Mod, ast.Pow, ast.LShift, ast.RShift, ast.BitOr,
+ ast.BitXor, ast.BitAnd, ast.FloorDiv, ast.Invert,
+ ast.Not, ast.UAdd, ast.USub, ast.Eq, ast.NotEq, ast.Lt,
+ ast.LtE, ast.Gt, ast.GtE, ast.Is, ast.IsNot, ast.In,
+ ast.NotIn, ast.Load, ast.Store, ast.Del, ast.AugLoad,
+ ast.AugStore, ast.Param
+ ]:
+ return a
+ elif type(a) == ast.Module:
+ cp = ast.Module(deepcopyList(a.body))
+ elif type(a) == ast.Interactive:
+ cp = ast.Interactive(deepcopyList(a.body))
+ elif type(a) == ast.Expression:
+ cp = ast.Expression(deepcopy(a.body))
+ elif type(a) == ast.Suite:
+ cp = ast.Suite(deepcopyList(a.body))
+
+ elif type(a) == ast.FunctionDef:
+ cp = ast.FunctionDef(a.name, deepcopy(a.args), deepcopyList(a.body),
+ deepcopyList(a.decorator_list), deepcopy(a.returns))
+ elif type(a) == ast.ClassDef:
+ cp = ast.ClassDef(a.name, deepcopyList(a.bases), deepcopyList(a.keywords), deepcopyList(a.body),
+ deepcopyList(a.decorator_list))
+ elif type(a) == ast.Return:
+ cp = ast.Return(deepcopy(a.value))
+ elif type(a) == ast.Delete:
+ cp = ast.Delete(deepcopyList(a.targets))
+ elif type(a) == ast.Assign:
+ cp = ast.Assign(deepcopyList(a.targets), deepcopy(a.value))
+ elif type(a) == ast.AugAssign:
+ cp = ast.AugAssign(deepcopy(a.target), deepcopy(a.op),
+ deepcopy(a.value))
+ elif type(a) == ast.For:
+ cp = ast.For(deepcopy(a.target), deepcopy(a.iter),
+ deepcopyList(a.body), deepcopyList(a.orelse))
+ elif type(a) == ast.While:
+ cp = ast.While(deepcopy(a.test), deepcopyList(a.body),
+ deepcopyList(a.orelse))
+ elif type(a) == ast.If:
+ cp = ast.If(deepcopy(a.test), deepcopyList(a.body),
+ deepcopyList(a.orelse))
+ elif type(a) == ast.With:
+ cp = ast.With(deepcopyList(a.items),deepcopyList(a.body))
+ elif type(a) == ast.Raise:
+ cp = ast.Raise(deepcopy(a.exc), deepcopy(a.cause))
+ elif type(a) == ast.Try:
+ cp = ast.Try(deepcopyList(a.body), deepcopyList(a.handlers),
+ deepcopyList(a.orelse), deepcopyList(a.finalbody))
+ elif type(a) == ast.Assert:
+ cp = ast.Assert(deepcopy(a.test), deepcopy(a.msg))
+ elif type(a) == ast.Import:
+ cp = ast.Import(deepcopyList(a.names))
+ elif type(a) == ast.ImportFrom:
+ cp = ast.ImportFrom(a.module, deepcopyList(a.names), a.level)
+ elif type(a) == ast.Global:
+ cp = ast.Global(a.names[:])
+ elif type(a) == ast.Expr:
+ cp = ast.Expr(deepcopy(a.value))
+ elif type(a) == ast.Pass:
+ cp = ast.Pass()
+ elif type(a) == ast.Break:
+ cp = ast.Break()
+ elif type(a) == ast.Continue:
+ cp = ast.Continue()
+
+ elif type(a) == ast.BoolOp:
+ cp = ast.BoolOp(a.op, deepcopyList(a.values))
+ elif type(a) == ast.BinOp:
+ cp = ast.BinOp(deepcopy(a.left), a.op, deepcopy(a.right))
+ elif type(a) == ast.UnaryOp:
+ cp = ast.UnaryOp(a.op, deepcopy(a.operand))
+ elif type(a) == ast.Lambda:
+ cp = ast.Lambda(deepcopy(a.args), deepcopy(a.body))
+ elif type(a) == ast.IfExp:
+ cp = ast.IfExp(deepcopy(a.test), deepcopy(a.body), deepcopy(a.orelse))
+ elif type(a) == ast.Dict:
+ cp = ast.Dict(deepcopyList(a.keys), deepcopyList(a.values))
+ elif type(a) == ast.Set:
+ cp = ast.Set(deepcopyList(a.elts))
+ elif type(a) == ast.ListComp:
+ cp = ast.ListComp(deepcopy(a.elt), deepcopyList(a.generators))
+ elif type(a) == ast.SetComp:
+ cp = ast.SetComp(deepcopy(a.elt), deepcopyList(a.generators))
+ elif type(a) == ast.DictComp:
+ cp = ast.DictComp(deepcopy(a.key), deepcopy(a.value),
+ deepcopyList(a.generators))
+ elif type(a) == ast.GeneratorExp:
+ cp = ast.GeneratorExp(deepcopy(a.elt), deepcopyList(a.generators))
+ elif type(a) == ast.Yield:
+ cp = ast.Yield(deepcopy(a.value))
+ elif type(a) == ast.Compare:
+ cp = ast.Compare(deepcopy(a.left), a.ops[:],
+ deepcopyList(a.comparators))
+ elif type(a) == ast.Call:
+ cp = ast.Call(deepcopy(a.func), deepcopyList(a.args), deepcopyList(a.keywords))
+ elif type(a) == ast.Num:
+ cp = ast.Num(a.n)
+ elif type(a) == ast.Str:
+ cp = ast.Str(a.s)
+ elif type(a) == ast.Bytes:
+ cp = ast.Bytes(a.s)
+ elif type(a) == ast.NameConstant:
+ cp = ast.NameConstant(a.value)
+ elif type(a) == ast.Attribute:
+ cp = ast.Attribute(deepcopy(a.value), a.attr, a.ctx)
+ elif type(a) == ast.Subscript:
+ cp = ast.Subscript(deepcopy(a.value), deepcopy(a.slice), a.ctx)
+ elif type(a) == ast.Name:
+ cp = ast.Name(a.id, a.ctx)
+ elif type(a) == ast.List:
+ cp = ast.List(deepcopyList(a.elts), a.ctx)
+ elif type(a) == ast.Tuple:
+ cp = ast.Tuple(deepcopyList(a.elts), a.ctx)
+ elif type(a) == ast.Starred:
+ cp = ast.Starred(deepcopy(a.value), a.ctx)
+
+ elif type(a) == ast.Slice:
+ cp = ast.Slice(deepcopy(a.lower), deepcopy(a.upper), deepcopy(a.step))
+ elif type(a) == ast.ExtSlice:
+ cp = ast.ExtSlice(deepcopyList(a.dims))
+ elif type(a) == ast.Index:
+ cp = ast.Index(deepcopy(a.value))
+
+ elif type(a) == ast.comprehension:
+ cp = ast.comprehension(deepcopy(a.target), deepcopy(a.iter),
+ deepcopyList(a.ifs))
+ elif type(a) == ast.ExceptHandler:
+ cp = ast.ExceptHandler(deepcopy(a.type), a.name, deepcopyList(a.body))
+ elif type(a) == ast.arguments:
+ cp = ast.arguments(deepcopyList(a.args), deepcopy(a.vararg),
+ deepcopyList(a.kwonlyargs), deepcopyList(a.kw_defaults),
+ deepcopy(a.kwarg), deepcopyList(a.defaults))
+ elif type(a) == ast.arg:
+ cp = ast.arg(a.arg, deepcopy(a.annotation))
+ elif type(a) == ast.keyword:
+ cp = ast.keyword(a.arg, deepcopy(a.value))
+ elif type(a) == ast.alias:
+ cp = ast.alias(a.name, a.asname)
+ elif type(a) == ast.withitem:
+ cp = ast.withitem(deepcopy(a.context_expr), deepcopy(a.optional_vars))
+ else:
+ log("astTools\tdeepcopy\tNot implemented: " + str(type(a)), "bug")
+ cp = copy.deepcopy(a)
+
+ transferMetaData(a, cp)
+ return cp
+
+def exportToJson(a):
+ """Export the ast to json format"""
+ if a == None:
+ return "null"
+ elif type(a) in [int, float]:
+ return str(a)
+ elif type(a) == str:
+ return '"' + a + '"'
+ elif not isinstance(a, ast.AST):
+ log("astTools\texportToJson\tMissing type: " + str(type(a)), "bug")
+
+ s = "{\n"
+ if type(a) in astNames:
+ s += '"' + astNames[type(a)] + '": {\n'
+ for field in a._fields:
+ s += '"' + field + '": '
+ value = getattr(a, field)
+ if type(value) == list:
+ s += "["
+ for item in value:
+ s += exportToJson(item) + ", "
+ if len(value) > 0:
+ s = s[:-2]
+ s += "]"
+ else:
+ s += exportToJson(value)
+ s += ", "
+ if len(a._fields) > 0:
+ s = s[:-2]
+ s += "}"
+ else:
+ log("astTools\texportToJson\tMissing AST type: " + str(type(a)), "bug")
+ s += "}"
+ return s
+
+### ITAP/Canonicalization Functions ###
+
+def isTokenStepString(s):
+ """Determine whether this is a placeholder string"""
+ if len(s) < 2:
+ return False
+ return s[0] == "~" and s[-1] == "~"
+
+def getParentFunction(s):
+ underscoreSep = s.split("_")
+ if len(underscoreSep) == 1:
+ return None
+ result = "_".join(underscoreSep[1:])
+ if result == "newvar" or result == "global":
+ return None
+ return result
+
+def isAnonVariable(s):
+ """Specificies whether the given string is an anonymized variable name"""
+ preUnderscore = s.split("_")[0] # the part before the function name
+ return len(preUnderscore) > 1 and \
+ preUnderscore[0] in ["g", "p", "v", "r", "n", "z"] and \
+ preUnderscore[1:].isdigit()
+
+def isDefault(a):
+ """Our programs have a default setting of return 42, so we should detect that"""
+ if type(a) == ast.Module and len(a.body) == 1:
+ a = a.body[0]
+ else:
+ return False
+
+ if type(a) != ast.FunctionDef:
+ return False
+
+ if len(a.body) == 0:
+ return True
+ elif len(a.body) == 1:
+ if type(a.body[0]) == ast.Return:
+ if a.body[0].value == None or \
+ type(a.body[0].value) == ast.Num and a.body[0].value.n == 42:
+ return True
+ return False
+
+def transferMetaData(a, b):
+ """Transfer the metadata of a onto b"""
+ properties = [ "global_id", "second_global_id", "lineno", "col_offset",
+ "originalId", "varID", "variableGlobalId",
+ "randomVar", "propagatedVariable", "loadedVariable", "dontChangeName",
+ "reversed", "negated", "inverted",
+ "augAssignVal", "augAssignBinOp",
+ "combinedConditional", "combinedConditionalOp",
+ "multiComp", "multiCompPart", "multiCompMiddle", "multiCompOp",
+ "addedNot", "addedNotOp", "addedOther", "addedOtherOp", "addedNeg",
+ "collapsedExpr", "removedLines",
+ "helperVar", "helperReturnVal", "helperParamAssign", "helperReturnAssign",
+ "orderedBinOp", "typeCastFunction", "moved_line" ]
+ for prop in properties:
+ if hasattr(a, prop):
+ setattr(b, prop, getattr(a, prop))
+
+def assignPropertyToAll(a, prop):
+ """Assign the provided property to all children"""
+ if type(a) == list:
+ for child in a:
+ assignPropertyToAll(child, prop)
+ elif isinstance(a, ast.AST):
+ for node in ast.walk(a):
+ setattr(node, prop, True)
+
+def removePropertyFromAll(a, prop):
+ if type(a) == list:
+ for child in a:
+ removePropertyFromAll(child, prop)
+ elif isinstance(a, ast.AST):
+ for node in ast.walk(a):
+ if hasattr(node, prop):
+ delattr(node, prop)
+
+def containsTokenStepString(a):
+ """This is used to keep token-level hint chaining from breaking."""
+ if not isinstance(a, ast.AST):
+ return False
+
+ for node in ast.walk(a):
+ if type(node) == ast.Str and isTokenStepString(node.s):
+ return True
+ return False
+
+def applyVariableMap(a, variableMap):
+ if not isinstance(a, ast.AST):
+ return a
+ if type(a) == ast.Name:
+ if a.id in variableMap:
+ a.id = variableMap[a.id]
+ elif type(a) in [ast.FunctionDef, ast.ClassDef]:
+ if a.name in variableMap:
+ a.name = variableMap[a.name]
+ return applyToChildren(a, lambda x : applyVariableMap(x, variableMap))
+
+def applyHelperMap(a, helperMap):
+ if not isinstance(a, ast.AST):
+ return a
+ if type(a) == ast.Name:
+ if a.id in helperMap:
+ a.id = helperMap[a.id]
+ elif type(a) == ast.FunctionDef:
+ if a.name in helperMap:
+ a.name = helperMap[a.name]
+ return applyToChildren(a, lambda x : applyHelperMap(x, helperMap))
+
+
+def astFormat(x, gid=None):
+ """Given a value, turn it into an AST if it's a constant; otherwise, leave it alone."""
+ if type(x) in [int, float, complex]:
+ return ast.Num(x)
+ elif type(x) == bool or x == None:
+ return ast.NameConstant(x)
+ elif type(x) == type:
+ types = { bool : "bool", int : "int", float : "float",
+ complex : "complex", str : "str", bytes : "bytes", unicode : "unicode",
+ list : "list", tuple : "tuple", dict : "dict" }
+ return ast.Name(types[x], ast.Load())
+ elif type(x) == str: # str or unicode
+ return ast.Str(x)
+ elif type(x) == bytes:
+ return ast.Bytes(x)
+ elif type(x) == list:
+ elts = [astFormat(val) for val in x]
+ return ast.List(elts, ast.Load())
+ elif type(x) == dict:
+ keys = []
+ vals = []
+ for key in x:
+ keys.append(astFormat(key))
+ vals.append(astFormat(x[key]))
+ return ast.Dict(keys, vals)
+ elif type(x) == tuple:
+ elts = [astFormat(val) for val in x]
+ return ast.Tuple(elts, ast.Load())
+ elif type(x) == set:
+ elts = [astFormat(val) for val in x]
+ if len(elts) == 0: # needs to be a call instead
+ return ast.Call(ast.Name("set", ast.Load()), [], [])
+ else:
+ return ast.Set(elts)
+ elif type(x) == slice:
+ return ast.Slice(astFormat(x.start), astFormat(x.stop), astFormat(x.step))
+ elif isinstance(x, ast.AST):
+ return x # Do not change if it's not constant!
+ else:
+ log("astTools\tastFormat\t" + str(type(x)) + "," + str(x),"bug")
+ return None
+
+def basicFormat(x):
+ """Given an AST, turn it into its value if it's constant; otherwise, leave it alone"""
+ if type(x) == ast.Num:
+ return x.n
+ elif type(x) == ast.NameConstant:
+ return x.value
+ elif type(x) == ast.Str:
+ return x.s
+ elif type(x) == ast.Bytes:
+ return x.s
+ return x # Do not change if it's not a constant!
+
+def structureTree(a):
+ if type(a) == list:
+ for i in range(len(a)):
+ a[i] = structureTree(a[i])
+ return a
+ elif not isinstance(a, ast.AST):
+ return a
+ else:
+ if type(a) == ast.FunctionDef:
+ a.name = "~name~"
+ a.args = structureTree(a.args)
+ a.body = structureTree(a.body)
+ a.decorator_list = structureTree(a.decorator_list)
+ a.returns = structureTree(a.returns)
+ elif type(a) == ast.ClassDef:
+ a.name = "~name~"
+ a.bases = structureTree(a.bases)
+ a.keywords = structureTree(a.keywords)
+ a.body = structureTree(a.body)
+ a.decorator_list = structureTree(a.decorator_list)
+ elif type(a) == ast.AugAssign:
+ a.target = structureTree(a.target)
+ a.op = ast.Str("~op~")
+ a.value = structureTree(a.value)
+ elif type(a) == ast.Import:
+ a.names = [ast.Str("~module~")]
+ elif type(a) == ast.ImportFrom:
+ a.module = "~module~"
+ a.names = [ast.Str("~names~")]
+ elif type(a) == ast.Global:
+ a.names = ast.Str("~var~")
+ elif type(a) == ast.BoolOp:
+ a.op = ast.Str("~op~")
+ a.values = structureTree(a.values)
+ elif type(a) == ast.BinOp:
+ a.op = ast.Str("~op~")
+ a.left = structureTree(a.left)
+ a.right = structureTree(a.right)
+ elif type(a) == ast.UnaryOp:
+ a.op = ast.Str("~op~")
+ a.operand = structureTree(a.operand)
+ elif type(a) == ast.Dict:
+ return ast.Str("~dictionary~")
+ elif type(a) == ast.Set:
+ return ast.Str("~set~")
+ elif type(a) == ast.Compare:
+ a.ops = [ast.Str("~op~")]*len(a.ops)
+ a.left = structureTree(a.left)
+ a.comparators = structureTree(a.comparators)
+ elif type(a) == ast.Call:
+ # leave the function alone
+ a.args = structureTree(a.args)
+ a.keywords = structureTree(a.keywords)
+ elif type(a) == ast.Num:
+ return ast.Str("~number~")
+ elif type(a) == ast.Str:
+ return ast.Str("~string~")
+ elif type(a) == ast.Bytes:
+ return ast.Str("~bytes~")
+ elif type(a) == ast.Attribute:
+ a.value = structureTree(a.value)
+ elif type(a) == ast.Name:
+ a.id = "~var~"
+ elif type(a) == ast.List:
+ return ast.Str("~list~")
+ elif type(a) == ast.Tuple:
+ return ast.Str("~tuple~")
+ elif type(a) in [ast.And, ast.Or, ast.Add, ast.Sub, ast.Mult, ast.Div,
+ ast.Mod, ast.Pow, ast.LShift, ast.RShift, ast.BitOr,
+ ast.BitXor, ast.BitAnd, ast.FloorDiv, ast.Invert,
+ ast.Not, ast.UAdd, ast.USub, ast.Eq, ast.NotEq,
+ ast.Lt, ast.LtE, ast.Gt, ast.GtE, ast.Is, ast.IsNot,
+ ast.In, ast.NotIn ]:
+ return ast.Str("~op~")
+ elif type(a) == ast.arguments:
+ a.args = structureTree(a.args)
+ a.vararg = ast.Str("~arg~") if a.vararg != None else None
+ a.kwonlyargs = structureTree(a.kwonlyargs)
+ a.kw_defaults = structureTree(a.kw_defaults)
+ a.kwarg = ast.Str("~keyword~") if a.kwarg != None else None
+ a.defaults = structureTree(a.defaults)
+ elif type(a) == ast.arg:
+ a.arg = "~arg~"
+ a.annotation = structureTree(a.annotation)
+ elif type(a) == ast.keyword:
+ a.arg = "~keyword~"
+ a.value = structureTree(a.value)
+ elif type(a) == ast.alias:
+ a.name = "~name~"
+ a.asname = "~asname~" if a.asname != None else None
+ else:
+ for field in a._fields:
+ setattr(a, field, structureTree(getattr(a, field)))
+ return a
+
+
+
diff --git a/canonicalize/display.py b/canonicalize/display.py
new file mode 100644
index 0000000..174304f
--- /dev/null
+++ b/canonicalize/display.py
@@ -0,0 +1,570 @@
+import ast
+from .tools import log
+
+#===============================================================================
+# These functions are used for displaying ASTs. printAst displays the tree,
+# while printFunction displays the syntax
+#===============================================================================
+
+# TODO: add AsyncFunctionDef, AsyncFor, AsyncWith, AnnAssign, Nonlocal, Await, YieldFrom, FormattedValue, JoinedStr, Starred
+
+def printFunction(a, indent=0):
+ s = ""
+ if a == None:
+ return ""
+ if not isinstance(a, ast.AST):
+ log("display\tprintFunction\tNot AST: " + str(type(a)) + "," + str(a), "bug")
+ return str(a)
+
+ t = type(a)
+ if t in [ast.Module, ast.Interactive, ast.Suite]:
+ for line in a.body:
+ s += printFunction(line, indent)
+ elif t == ast.Expression:
+ s += printFunction(a.body, indent)
+ elif t == ast.FunctionDef:
+ for dec in a.decorator_list:
+ s += (indent * 4 * " ") + "@" + printFunction(dec, indent) + "\n"
+ s += (indent * 4 * " ") + "def " + a.name + "(" + \
+ printFunction(a.args, indent) + "):\n"
+ for stmt in a.body:
+ s += printFunction(stmt, indent+1)
+ # TODO: returns
+ elif t == ast.ClassDef:
+ for dec in a.decorator_list:
+ s += (indent * 4 * " ") + "@" + printFunction(dec, indent) + "\n"
+ s += (indent * 4 * " ") + "class " + a.name
+ if len(a.bases) > 0 or len(a.keywords) > 0:
+ s += "("
+ for base in a.bases:
+ s += printFunction(base, indent) + ", "
+ for keyword in a.keywords:
+ s += printFunction(keyword, indent) + ", "
+ s += s[:-2] + ")"
+ s += ":\n"
+ for stmt in a.body:
+ s += printFunction(stmt, indent+1)
+ elif t == ast.Return:
+ s += (indent * 4 * " ") + "return " + \
+ printFunction(a.value, indent) + "\n"
+ elif t == ast.Delete:
+ s += (indent * 4 * " ") + "del "
+ for target in a.targets:
+ s += printFunction(target, indent) + ", "
+ if len(a.targets) >= 1:
+ s = s[:-2]
+ s += "\n"
+ elif t == ast.Assign:
+ s += (indent * 4 * " ")
+ for target in a.targets:
+ s += printFunction(target, indent) + " = "
+ s += printFunction(a.value, indent) + "\n"
+ elif t == ast.AugAssign:
+ s += (indent * 4 * " ")
+ s += printFunction(a.target, indent) + " " + \
+ printFunction(a.op, indent) + "= " + \
+ printFunction(a.value, indent) + "\n"
+ elif t == ast.For:
+ s += (indent * 4 * " ")
+ s += "for " + \
+ printFunction(a.target, indent) + " in " + \
+ printFunction(a.iter, indent) + ":\n"
+ for line in a.body:
+ s += printFunction(line, indent + 1)
+ if len(a.orelse) > 0:
+ s += (indent * 4 * " ")
+ s += "else:\n"
+ for line in a.orelse:
+ s += printFunction(line, indent + 1)
+ elif t == ast.While:
+ s += (indent * 4 * " ")
+ s += "while " + printFunction(a.test, indent) + ":\n"
+ for line in a.body:
+ s += printFunction(line, indent + 1)
+ if len(a.orelse) > 0:
+ s += (indent * 4 * " ")
+ s += "else:\n"
+ for line in a.orelse:
+ s += printFunction(line, indent + 1)
+ elif t == ast.If:
+ s += (indent * 4 * " ")
+ s += "if " + printFunction(a.test, indent) + ":\n"
+ for line in a.body:
+ s += printFunction(line, indent + 1)
+ branch = a.orelse
+ # elifs
+ while len(branch) == 1 and type(branch[0]) == ast.If:
+ s += (indent * 4 * " ")
+ s += "elif " + printFunction(branch[0].test, indent) + ":\n"
+ for line in branch[0].body:
+ s += printFunction(line, indent + 1)
+ branch = branch[0].orelse
+ if len(branch) > 0:
+ s += (indent * 4 * " ")
+ s += "else:\n"
+ for line in branch:
+ s += printFunction(line, indent + 1)
+ elif t == ast.With:
+ s += (indent * 4 * " ")
+ s += "with "
+ for item in a.items:
+ s += printFunction(item, indent) + ", "
+ if len(a.items) > 0:
+ s = s[:-2]
+ s += ":\n"
+ for line in a.body:
+ s += printFunction(line, indent + 1)
+ elif t == ast.Raise:
+ s += (indent * 4 * " ")
+ s += "raise"
+ if a.exc != None:
+ s += " " + printFunction(a.exc, indent)
+ # TODO: what is cause?!?
+ s += "\n"
+ elif type(a) == ast.Try:
+ s += (indent * 4 * " ") + "try:\n"
+ for line in a.body:
+ s += printFunction(line, indent + 1)
+ for handler in a.handlers:
+ s += printFunction(handler, indent)
+ if len(a.orelse) > 0:
+ s += (indent * 4 * " ") + "else:\n"
+ for line in a.orelse:
+ s += printFunction(line, indent + 1)
+ if len(a.finalbody) > 0:
+ s += (indent * 4 * " ") + "finally:\n"
+ for line in a.finalbody:
+ s += printFunction(line, indent + 1)
+ elif t == ast.Assert:
+ s += (indent * 4 * " ")
+ s += "assert " + printFunction(a.test, indent)
+ if a.msg != None:
+ s += ", " + printFunction(a.msg, indent)
+ s += "\n"
+ elif t == ast.Import:
+ s += (indent * 4 * " ") + "import "
+ for n in a.names:
+ s += printFunction(n, indent) + ", "
+ if len(a.names) > 0:
+ s = s[:-2]
+ s += "\n"
+ elif t == ast.ImportFrom:
+ s += (indent * 4 * " ") + "from "
+ s += ("." * a.level if a.level != None else "") + a.module + " import "
+ for name in a.names:
+ s += printFunction(name, indent) + ", "
+ if len(a.names) > 0:
+ s = s[:-2]
+ s += "\n"
+ elif t == ast.Global:
+ s += (indent * 4 * " ") + "global "
+ for name in a.names:
+ s += name + ", "
+ s = s[:-2] + "\n"
+ elif t == ast.Expr:
+ s += (indent * 4 * " ") + printFunction(a.value, indent) + "\n"
+ elif t == ast.Pass:
+ s += (indent * 4 * " ") + "pass\n"
+ elif t == ast.Break:
+ s += (indent * 4 * " ") + "break\n"
+ elif t == ast.Continue:
+ s += (indent * 4 * " ") + "continue\n"
+
+ elif t == ast.BoolOp:
+ s += "(" + printFunction(a.values[0], indent)
+ for i in range(1, len(a.values)):
+ s += " " + printFunction(a.op, indent) + " " + \
+ printFunction(a.values[i], indent)
+ s += ")"
+ elif t == ast.BinOp:
+ s += "(" + printFunction(a.left, indent)
+ s += " " + printFunction(a.op, indent) + " "
+ s += printFunction(a.right, indent) + ")"
+ elif t == ast.UnaryOp:
+ s += "(" + printFunction(a.op, indent) + " "
+ s += printFunction(a.operand, indent) + ")"
+ elif t == ast.Lambda:
+ s += "lambda "
+ s += printFunction(a.arguments, indent) + ": "
+ s += printFunction(a.body, indent)
+ elif t == ast.IfExp:
+ s += "(" + printFunction(a.body, indent)
+ s += " if " + printFunction(a.test, indent)
+ s += " else " + printFunction(a.orelse, indent) + ")"
+ elif t == ast.Dict:
+ s += "{ "
+ for i in range(len(a.keys)):
+ s += printFunction(a.keys[i], indent)
+ s += " : "
+ s += printFunction(a.values[i], indent)
+ s += ", "
+ if len(a.keys) >= 1:
+ s = s[:-2]
+ s += " }"
+ elif t == ast.Set:
+ # Empty sets must be initialized in a special way
+ if len(a.elts) == 0:
+ s += "set()"
+ else:
+ s += "{"
+ for elt in a.elts:
+ s += printFunction(elt, indent) + ", "
+ s = s[:-2]
+ s += "}"
+ elif t == ast.ListComp:
+ s += "["
+ s += printFunction(a.elt, indent) + " "
+ for gen in a.generators:
+ s += printFunction(gen, indent) + " "
+ s = s[:-1]
+ s += "]"
+ elif t == ast.SetComp:
+ s += "{"
+ s += printFunction(a.elt, indent) + " "
+ for gen in a.generators:
+ s += printFunction(gen, indent) + " "
+ s = s[:-1]
+ s += "}"
+ elif t == ast.DictComp:
+ s += "{"
+ s += printFunction(a.key, indent) + " : " + \
+ printFunction(a.value, indent) + " "
+ for gen in a.generators:
+ s += printFunction(gen, indent) + " "
+ s = s[:-1]
+ s += "}"
+ elif t == ast.GeneratorExp:
+ s += "("
+ s += printFunction(a.elt, indent) + " "
+ for gen in a.generators:
+ s += printFunction(gen, indent) + " "
+ s = s[:-1]
+ s += ")"
+ elif t == ast.Yield:
+ s += "yield " + printFunction(a.value, indent)
+ elif t == ast.Compare:
+ s += "(" + printFunction(a.left, indent)
+ for i in range(len(a.ops)):
+ s += " " + printFunction(a.ops[i], indent)
+ if i < len(a.comparators):
+ s += " " + printFunction(a.comparators[i], indent)
+ if len(a.comparators) > len(a.ops):
+ for i in range(len(a.ops), len(a.comparators)):
+ s += " " + printFunction(a.comparators[i], indent)
+ s += ")"
+ elif t == ast.Call:
+ s += printFunction(a.func, indent) + "("
+ for arg in a.args:
+ s += printFunction(arg, indent) + ", "
+ for key in a.keywords:
+ s += printFunction(key, indent) + ", "
+ if len(a.args) + len(a.keywords) >= 1:
+ s = s[:-2]
+ s += ")"
+ elif t == ast.Num:
+ if a.n != None:
+ if (type(a.n) == complex) or (type(a.n) != complex and a.n < 0):
+ s += '(' + str(a.n) + ')'
+ else:
+ s += str(a.n)
+ elif t == ast.Str:
+ if a.s != None:
+ val = repr(a.s)
+ if val[0] == '"': # There must be a single quote in there...
+ val = "'''" + val[1:len(val)-1] + "'''"
+ s += val
+ #s += "'" + a.s.replace("'", "\\'").replace('"', "\\'").replace("\n","\\n") + "'"
+ elif t == ast.Bytes:
+ s += str(a.s)
+ elif t == ast.NameConstant:
+ s += str(a.value)
+ elif t == ast.Attribute:
+ s += printFunction(a.value, indent) + "." + str(a.attr)
+ elif t == ast.Subscript:
+ s += printFunction(a.value, indent) + "[" + printFunction(a.slice, indent) + "]"
+ elif t == ast.Name:
+ s += a.id
+ elif t == ast.List:
+ s += "["
+ for elt in a.elts:
+ s += printFunction(elt, indent) + ", "
+ if len(a.elts) >= 1:
+ s = s[:-2]
+ s += "]"
+ elif t == ast.Tuple:
+ s += "("
+ for elt in a.elts:
+ s += printFunction(elt, indent) + ", "
+ if len(a.elts) > 1:
+ s = s[:-2]
+ elif len(a.elts) == 1:
+ s = s[:-1] # don't get rid of the comma! It clarifies that this is a tuple
+ s += ")"
+ elif t == ast.Starred:
+ s += "*" + printFunction(a.value, indent)
+ elif t == ast.Ellipsis:
+ s += "..."
+ elif t == ast.Slice:
+ if a.lower != None:
+ s += printFunction(a.lower, indent)
+ s += ":"
+ if a.upper != None:
+ s += printFunction(a.upper, indent)
+ if a.step != None:
+ s += ":" + printFunction(a.step, indent)
+ elif t == ast.ExtSlice:
+ for dim in a.dims:
+ s += printFunction(dim, indent) + ", "
+ if len(a.dims) > 0:
+ s = s[:-2]
+ elif t == ast.Index:
+ s += printFunction(a.value, indent)
+
+ elif t == ast.comprehension:
+ s += "for "
+ s += printFunction(a.target, indent) + " "
+ s += "in "
+ s += printFunction(a.iter, indent) + " "
+ for cond in a.ifs:
+ s += "if "
+ s += printFunction(cond, indent) + " "
+ s = s[:-1]
+ elif t == ast.ExceptHandler:
+ s += (indent * 4 * " ") + "except"
+ if a.type != None:
+ s += " " + printFunction(a.type, indent)
+ if a.name != None:
+ s += " as " + a.name
+ s += ":\n"
+ for line in a.body:
+ s += printFunction(line, indent + 1)
+ elif t == ast.arguments:
+ # Defaults are only applied AFTER non-defaults
+ defaultStart = len(a.args) - len(a.defaults)
+ for i in range(len(a.args)):
+ s += printFunction(a.args[i], indent)
+ if i >= defaultStart:
+ s += "=" + printFunction(a.defaults[i - defaultStart], indent)
+ s += ", "
+ if a.vararg != None:
+ s += "*" + printFunction(a.vararg, indent) + ", "
+ if a.kwarg != None:
+ s += "**" + printFunction(a.kwarg, indent) + ", "
+ if a.vararg == None and a.kwarg == None and len(a.kwonlyargs) > 0:
+ s += "*, "
+ if len(a.kwonlyargs) > 0:
+ for i in range(len(a.kwonlyargs)):
+ s += printFunction(a.kwonlyargs[i], indent)
+ s += "=" + printFunction(a.kw_defaults, indent) + ", "
+ if (len(a.args) > 0 or a.vararg != None or a.kwarg != None or len(a.kwonlyargs) > 0):
+ s = s[:-2]
+ elif t == ast.arg:
+ s += a.arg
+ if a.annotation != None:
+ s += ": " + printFunction(a.annotation, indent)
+ elif t == ast.keyword:
+ s += a.arg + "=" + printFunction(a.value, indent)
+ elif t == ast.alias:
+ s += a.name
+ if a.asname != None:
+ s += " as " + a.asname
+ elif t == ast.withitem:
+ s += printFunction(a.context_expr, indent)
+ if a.optional_vars != None:
+ s += " as " + printFunction(a.optional_vars, indent)
+ else:
+ ops = { ast.And : "and", ast.Or : "or",
+ ast.Add : "+", ast.Sub : "-", ast.Mult : "*", ast.Div : "/", ast.Mod : "%",
+ ast.Pow : "**", ast.LShift : "<<", ast.RShift : ">>", ast.BitOr : "|",
+ ast.BitXor : "^", ast.BitAnd : "&", ast.FloorDiv : "//",
+ ast.Invert : "~", ast.Not : "not", ast.UAdd : "+", ast.USub : "-",
+ ast.Eq : "==", ast.NotEq : "!=", ast.Lt : "<", ast.LtE : "<=",
+ ast.Gt : ">", ast.GtE : ">=", ast.Is : "is", ast.IsNot : "is not",
+ ast.In : "in", ast.NotIn : "not in"}
+ if type(a) in ops:
+ return ops[type(a)]
+ if type(a) in [ast.Load, ast.Store, ast.Del, ast.AugLoad, ast.AugStore, ast.Param]:
+ return ""
+ log("display\tMissing type: " + str(t), "bug")
+ return s
+
+def formatContext(trace, verb):
+ traceD = {
+ "value" : { "Return" : ("return statement"),
+ "Assign" : ("right side of the assignment"),
+ "AugAssign" : ("right side of the assignment"),
+ "Expression" : ("expression"),
+ "Dict Comprehension" : ("left value of the dict comprehension"),
+ "Yield" : ("yield expression"),
+ "Repr" : ("repr expression"),
+ "Attribute" : ("attribute value"),
+ "Subscript" : ("outer part of the subscript"),
+ "Index" : ("inner part of the subscript"),
+ "Keyword" : ("right side of the keyword"),
+ "Starred" : ("value of the starred expression"),
+ "Name Constant" : ("constant value") },
+ "values" : { "Print" : ("print statement"),
+ "Boolean Operation" : ("boolean operation"),
+ "Dict" : ("values of the dictionary") },
+ "name" : { "Function Definition" : ("function name"),
+ "Class Definition" : ("class name"),
+ "Except Handler" : ("name of the except statement"),
+ "Alias" : ("alias") },
+ "names" : { "Import" : ("import"),
+ "ImportFrom" : ("import"),
+ "Global" : ("global variables") },
+ "elt" : { "List Comprehension" : ("left element of the list comprehension"),
+ "Set Comprehension" : ("left element of the set comprehension"),
+ "Generator" : ("left element of the generator") },
+ "elts" : { "Set" : ("set"),
+ "List" : ("list"),
+ "Tuple" : ("tuple") },
+ "target" : { "AugAssign" : ("left side of the assignment"),
+ "For" : ("target of the for loop"),
+ "Comprehension" : ("target of the comprehension") },
+ "targets" : { "Delete" : ("delete statement"),
+ "Assign" : ("left side of the assignment") },
+ "op" : { "AugAssign" : ("assignment"),
+ "Boolean Operation" : ("boolean operation"),
+ "Binary Operation" : ("binary operation"),
+ "Unary Operation" : ("unary operation") },
+ "ops" : { "Compare" : ("comparison operation") },
+ "arg" : { "Keyword" : ("left side of the keyword"),
+ "Argument" : ("argument") },
+ "args" : { "Function Definition" : ("function arguments"), # single item
+ "Lambda" : ("lambda arguments"), # single item
+ "Call" : ("arguments of the function call"),
+ "Arguments" : ("function arguments") },
+ "key" : { "Dict Comprehension" : ("left key of the dict comprehension") },
+ "keys" : { "Dict" : ("keys of the dictionary") },
+ "kwarg" : { "Arguments" : ("keyword arg") },
+ "kwargs" : { "Call" : ("keyword args of the function call") }, # single item
+ "body" : { "Module" : ("main codebase"), # list
+ "Interactive" : ("main codebase"), # list
+ "Expression" : ("main codebase"),
+ "Suite" : ("main codebase"), # list
+ "Function Definition" : ("function body"), # list
+ "Class Definition" : ("class body"), # list
+ "For" : ("lines of the for loop"), # list
+ "While" : ("lines of the while loop"), # list
+ "If" : ("main lines of the if statement"), # list
+ "With" : ("lines of the with block"), # list
+ "Try" : ("lines of the try block"), # list
+ "Execute" : ("exec expression"),
+ "Lambda" : ("lambda body"),
+ "Ternary" : ("ternary body"),
+ "Except Handler" : ("lines of the except block") }, # list
+ "orelse" : { "For" : ("else part of the for loop"), # list
+ "While" : ("else part of the while loop"), # list
+ "If" : ("lines of the else statement"), # list
+ "Try" : ("lines of the else statement"), # list
+ "Ternary" : ("ternary else value") },
+ "test" : { "While" : ("test case of the while statement"),
+ "If" : ("test case of the if statement"),
+ "Assert" : ("assert expression"),
+ "Ternary" : ("test case of the ternary expression") },
+ "generators" : { "List Comprehension" : ("list comprehension"),
+ "Set Comprehension" : ("set comprehension"),
+ "Dict Comprehension" : ("dict comprehension"),
+ "Generator" : ("generator") },
+ "decorator_list" : { "Function Definition" : ("function decorators"), # list
+ "Class Definition" : ("class decorators") }, # list
+ "iter" : { "For" : ("iterator of the for loop"),
+ "Comprehension" : ("iterator of the comprehension") },
+ "type" : { "Raise" : ("raised type"),
+ "Except Handler" : ("type of the except statement") },
+ "left" : { "Binary Operation" : ("left side of the binary operation"),
+ "Compare" : ("left side of the comparison") },
+ "bases" : { "Class Definition" : ("class bases") },
+ "dest" : { "Print" : ("print destination") },
+ "nl" : { "Print" : ("comma at the end of the print statement") },
+ "context_expr" : { "With item" : ("context of the with statement") },
+ "optional_vars" : { "With item" : ("context of the with statement") }, # single item
+ "inst" : { "Raise" : ("raise expression") },
+ "tback" : { "Raise" : ("raise expression") },
+ "handlers" : { "Try" : ("except block") },
+ "finalbody" : { "Try" : ("finally block") }, # list
+ "msg" : { "Assert" : ("assert message") },
+ "module" : { "Import From" : ("import module") },
+ "level" : { "Import From" : ("import module") },
+ "globals" : { "Execute" : ("exec global value") }, # single item
+ "locals" : { "Execute" : ("exec local value") }, # single item
+ "right" : { "Binary Operation" : ("right side of the binary operation") },
+ "operand" : { "Unary Operation" : ("value of the unary operation") },
+ "comparators" : { "Compare" : ("right side of the comparison") },
+ "func" : { "Call" : ("function call") },
+ "keywords" : { "Call" : ("keywords of the function call") },
+ "starargs" : { "Call" : ("star args of the function call") }, # single item
+ "attr" : { "Attribute" : ("attribute of the value") },
+ "slice" : { "Subscript" : ("inner part of the subscript") },
+ "lower" : { "Slice" : ("left side of the subscript slice") },
+ "upper" : { "Slice" : ("right side of the subscript slice") },
+ "step" : { "Step" : ("rightmost side of the subscript slice") },
+ "dims" : { "ExtSlice" : ("slice") },
+ "ifs" : { "Comprehension" : ("if part of the comprehension") },
+ "vararg" : { "Arguments" : ("vararg") },
+ "defaults" : { "Arguments" : ("default values of the arguments") },
+ "asname" : { "Alias" : ("new name") },
+ "items" : { "With" : ("context of the with statement") }
+ }
+
+ # Find what type this is by trying to find the closest container in the path
+ i = 0
+ while i < len(trace):
+ if type(trace[i]) == tuple:
+ if trace[i][0] == "value" and trace[i][1] == "Attribute":
+ pass
+ elif trace[i][0] in traceD:
+ break
+ elif trace[i][0] in ["id", "n", "s"]:
+ pass
+ else:
+ log("display\tformatContext\tSkipped field: " + str(trace[i]), "bug")
+ i += 1
+ else:
+ return "" # this is probably covered by the line number
+
+ field,typ = trace[i]
+ if field in traceD and typ in traceD[field]:
+ context = traceD[field][typ]
+ return verb + "the " + context
+ else:
+ log("display\tformatContext\tMissing field: " + str(field) + "," + str(typ), "bug")
+ return ""
+
+def formatList(node, field):
+ if type(node) != list:
+ return None
+ s = ""
+ nameMap = { "body" : "line", "targets" : "value", "values" : "value", "orelse" : "line",
+ "names" : "name", "keys" : "key", "elts" : "value", "ops" : "operator",
+ "comparators" : "value", "args" : "argument", "keywords" : "keyword" }
+
+ # Find what type this is
+ itemType = nameMap[field] if field in nameMap else "line"
+
+ if len(node) > 1:
+ s = "the " + itemType + "s: "
+ for line in node:
+ s += formatNode(line) + ", "
+ elif len(node) == 1:
+ s = "the " + itemType + " "
+ f = formatNode(node[0])
+ if itemType == "line":
+ f = "[" + f + "]"
+ s += f
+ return s
+
+def formatNode(node):
+ """Create a string version of the given node"""
+ if node == None:
+ return ""
+ t = type(node)
+ if t == str:
+ return "'" + node + "'"
+ elif t == int or t == float:
+ return str(node)
+ elif t == list:
+ return formatList(node, None)
+ else:
+ return printFunction(node, 0) \ No newline at end of file
diff --git a/canonicalize/namesets.py b/canonicalize/namesets.py
new file mode 100644
index 0000000..5523f44
--- /dev/null
+++ b/canonicalize/namesets.py
@@ -0,0 +1,463 @@
+import ast, collections
+
+supportedLibraries = [ "string", "math", "random", "__future__", "copy" ]
+
+builtInTypes = [ bool, bytes, complex, dict, float, int, list,
+ set, str, tuple, type ]
+
+staticTypeCastBuiltInFunctions = {
+ "bool" : { (object,) : bool },
+ "bytes" : bytes,
+ "complex" : { (str, int) : complex, (str, float) : complex, (int, int) : complex,
+ (int, float) : complex, (float, int) : complex, (float, float) : complex },
+ "dict" : { (collections.Iterable,) : dict },
+ "enumerate" : { (collections.Iterable,) : enumerate },
+ "float" : { (str,) : float, (int,) : float, (float,) : float },
+ "frozenset" : frozenset,
+ "int" : { (str,) : int, (float,) : int, (int,) : int, (str, int) : int },
+ "list" : { (collections.Iterable,) : list },
+ "memoryview" : None, #TODO
+ "object" : { () : object },
+ "property" : property, #TODO
+ "reversed" : { (str,) : reversed, (list,) : reversed },
+ "set" : { () : None, (collections.Iterable,) : None }, #TODO
+ "slice" : { (int,) : slice, (int, int) : slice, (int, int, int) : slice },
+ "str" : { (object,) : str },
+ "tuple" : { () : tuple, (collections.Iterable,) : tuple },
+ "type" : { (object,) : type },
+ }
+
+mutatingTypeCastBuiltInFunctions = {
+ "bytearray" : { () : list },
+ "classmethod" : None,
+ "file" : None,
+ "staticmethod" : None, #TOOD
+ "super" : None
+ }
+
+builtInNames = [ "None", "True", "False", "NotImplemented", "Ellipsis" ]
+
+staticBuiltInFunctions = {
+ "abs" : { (int,) : int, (float,) : float },
+ "all" : { (collections.Iterable,) : bool },
+ "any" : { (collections.Iterable,) : bool },
+ "bin" : { (int,) : str },
+ "callable" : { (object,) : bool },
+ "chr" : { (int,) : str },
+ "cmp" : { (object, object) : int },
+ "coerce" : tuple, #TODO
+ "compile" : { (str, str, str) : ast, (ast, str, str) : ast },
+ "dir" : { () : list },
+ "divmod" : { (int, int) : tuple, (int, float) : tuple,
+ (float, int) : tuple, (float, float) : tuple },
+ "filter" : { (type(lambda x : x), collections.Iterable) : list },
+ "getattr" : None,
+ "globals" : dict,
+ "hasattr" : bool, #TODO
+ "hash" : int,
+ "hex" : str,
+ "id" : int, #TODO
+ "isinstance" : { (None, None) : bool },
+ "issubclass" : bool, #TODO
+ "iter" : { (collections.Iterable,) : None, (None, object) : None },
+ "len" : { (str,) : int, (tuple,) : int, (list,) : int, (dict,) : int },
+ "locals" : dict,
+ "map" : { (None, collections.Iterable) : list }, #TODO
+ "max" : { (collections.Iterable,) : None },
+ "min" : { (collections.Iterable,) : None },
+ "oct" : { (int,) : str },
+ "ord" : { (str,) : int },
+ "pow" : { (int, int) : int, (int, float) : float,
+ (float, int) : float, (float, float) : float },
+ "print" : None,
+ "range" : { (int,) : list, (int, int) : list, (int, int, int) : list },
+ "repr" : {(object,) : str },
+ "round" : { (int,) : float, (float,) : float, (int, int) : float, (float, int) : float },
+ "sorted" : { (collections.Iterable,) : list },
+ "sum" : { (collections.Iterable,) : None },
+ "vars" : dict, #TODO
+ "zip" : { () : list, (collections.Iterable,) : list}
+ }
+
+mutatingBuiltInFunctions = {
+ "__import__" : None,
+ "apply" : None,
+ "delattr" : { (object, str) : None },
+ "eval" : { (str,) : None },
+ "execfile" : None,
+ "format" : None,
+ "input" : { () : None, (object,) : None },
+ "intern" : str, #TOOD
+ "next" : { () : None, (None,) : None },
+ "open" : None,
+ "raw_input" : { () : str, (object,) : str },
+ "reduce" : None,
+ "reload" : None,
+ "setattr" : None
+ }
+
+builtInSafeFunctions = [
+ "abs", "all", "any", "bin", "bool", "cmp", "len",
+ "list", "max", "min", "pow", "repr", "round", "slice", "str", "type"
+ ]
+
+
+exceptionClasses = [
+ "ArithmeticError",
+ "AssertionError",
+ "AttributeError",
+ "BaseException",
+ "BufferError",
+ "BytesWarning",
+ "DeprecationWarning",
+ "EOFError",
+ "EnvironmentError",
+ "Exception",
+ "FloatingPointError",
+ "FutureWarning",
+ "GeneratorExit",
+ "IOError",
+ "ImportError",
+ "ImportWarning",
+ "IndentationError",
+ "IndexError",
+ "KeyError",
+ "KeyboardInterrupt",
+ "LookupError",
+ "MemoryError",
+ "NameError",
+ "NotImplementedError",
+ "OSError",
+ "OverflowError",
+ "PendingDeprecationWarning",
+ "ReferenceError",
+ "RuntimeError",
+ "RuntimeWarning",
+ "StandardError",
+ "StopIteration",
+ "SyntaxError",
+ "SyntaxWarning",
+ "SystemError",
+ "SystemExit",
+ "TabError",
+ "TypeError",
+ "UnboundLocalError",
+ "UnicodeDecodeError",
+ "UnicodeEncodeError",
+ "UnicodeError",
+ "UnicodeTranslateError",
+ "UnicodeWarning",
+ "UserWarning",
+ "ValueError",
+ "Warning",
+ "ZeroDivisionError",
+
+ "WindowsError", "BlockingIOError", "ChildProcessError",
+ "ConnectionError", "BrokenPipeError", "ConnectionAbortedError",
+ "ConnectionRefusedError", "ConnectionResetError", "FileExistsError",
+ "FileNotFoundError", "InterruptedError", "IsADirectoryError", "NotADirectoryError",
+ "PermissionError", "ProcessLookupError", "TimeoutError",
+ "ResourceWarning", "RecursionError", "StopAsyncIteration" ]
+
+builtInFunctions = dict(list(staticBuiltInFunctions.items()) + \
+ list(mutatingBuiltInFunctions.items()) + \
+ list(staticTypeCastBuiltInFunctions.items()) + \
+ list(mutatingTypeCastBuiltInFunctions.items()))
+
+# All string functions do not mutate the caller, they return copies instead
+builtInStringFunctions = {
+ "capitalize" : { () : str },
+ "center" : { (int,) : str, (int, str) : str },
+ "count" : { (str,) : int, (str, int) : int, (str, int, int) : int },
+ "decode" : { () : str },
+ "encode" : { () : str },
+ "endswith" : { (str,) : bool },
+ "expandtabs" : { () : str, (int,) : str },
+ "find" : { (str,) : int, (str, int) : int, (str, int, int) : int },
+ "format" : { (list, list) : str },
+ "index" : { (str,) : int, (str,int) : int, (str,int,int) : int },
+ "isalnum" : { () : bool },
+ "isalpha" : { () : bool },
+ "isdecimal" : { () : bool },
+ "isdigit" : { () : bool },
+ "islower" : { () : bool },
+ "isnumeric" : { () : bool },
+ "isspace" : { () : bool },
+ "istitle" : { () : bool },
+ "isupper" : { () : bool },
+ "join" : { (collections.Iterable,) : str, (collections.Iterable,str) : str },
+ "ljust" : { (int,) : str },
+ "lower" : { () : str },
+ "lstrip" : { () : str, (str,) : str },
+ "partition" : { (str,) : tuple },
+ "replace" : { (str, str) : str, (str, str, int) : str },
+ "rfind" : { (str,) : int, (str,int) : int, (str,int,int) : int },
+ "rindex" : { (str,) : int },
+ "rjust" : { (int,) : str },
+ "rpartition" : { (str,) : tuple },
+ "rsplit" : { () : list },
+ "rstrip" : { () : str },
+ "split" : { () : list, (str,) : list, (str, int) : list },
+ "splitlines" : { () : list },
+ "startswith" : { (str,) : bool },
+ "strip" : { () : str, (str,) : str },
+ "swapcase" : { () : str },
+ "title" : { () : str },
+ "translate" : { (str,) : str },
+ "upper" : { () : str },
+ "zfill" : { (int,) : str }
+ }
+
+safeStringFunctions = [
+ "capitalize", "center", "count", "endswith", "expandtabs", "find",
+ "isalnum", "isalpha", "isdigit", "islower", "isspace", "istitle",
+ "isupper", "join", "ljust", "lower", "lstrip", "partition", "replace",
+ "rfind", "rjust", "rpartition", "rsplit", "rstrip", "split", "splitlines",
+ "startswith", "strip", "swapcase", "title", "translate", "upper", "zfill",
+ "isdecimal", "isnumeric"]
+
+mutatingListFunctions = {
+ "append" : { (object,) : None },
+ "extend" : { (list,) : None },
+ "insert" : { (int, object) : None },
+ "remove" : { (object,) : None },
+ "pop" : { () : None, (int,) : None },
+ "sort" : { () : None },
+ "reverse" : { () : None }
+ }
+
+staticListFunctions = {
+ "index" : { (object,) : int, (object,int) : int, (object,int,int) : int },
+ "count" : { (object,) : int, (object,int) : int, (object,int,int) : int }
+ }
+
+safeListFunctions = [ "append", "extend", "insert", "count", "sort", "reverse"]
+
+builtInListFunctions = dict(list(mutatingListFunctions.items()) + list(staticListFunctions.items()))
+
+staticDictFunctions = {
+ "get" : { (object,) : object, (object, object) : object },
+ "items" : { () : list }
+ }
+
+builtInDictFunctions = staticDictFunctions
+
+mathFunctions = {
+ "ceil" : { (int,) : float, (float,) : float },
+ "copysign" : { (int, int) : float, (int, float) : float,
+ (float, int) : float, (float, float) : float },
+ "fabs" : { (int,) : float, (float,) : float },
+ "factorial" : { (int,) : int, (float,) : int },
+ "floor" : { (int,) : float, (float,) : float },
+ "fmod" : { (int, int) : float, (int, float) : float,
+ (float, int) : float, (float, float) : float },
+ "frexp" : int,
+ "fsum" : int, #TODO
+ "isinf" : { (int,) : bool, (float,) : bool },
+ "isnan" : { (int,) : bool, (float,) : bool },
+ "ldexp" : int,
+ "modf" : tuple,
+ "trunc" : None, #TODO
+ "exp" : { (int,) : float, (float,) : float },
+ "expm1" : { (int,) : float, (float,) : float },
+ "log" : { (int,) : float, (float,) : float,
+ (int,int) : float, (int,float) : float,
+ (float, int) : float, (float, float) : float },
+ "log1p" : { (int,) : float, (float,) : float },
+ "log10" : { (int,) : float, (float,) : float },
+ "pow" : { (int, int) : float, (int, float) : float,
+ (float, int) : float, (float, float) : float },
+ "sqrt" : { (int,) : float, (float,) : float },
+ "acos" : { (int,) : float, (float,) : float },
+ "asin" : { (int,) : float, (float,) : float },
+ "atan" : { (int,) : float, (float,) : float },
+ "atan2" : { (int,) : float, (float,) : float },
+ "cos" : { (int,) : float, (float,) : float },
+ "hypot" : { (int, int) : float, (int, float) : float,
+ (float, int) : float, (float, float) : float },
+ "sin" : { (int,) : float, (float,) : float },
+ "tan" : { (int,) : float, (float,) : float },
+ "degrees" : { (int,) : float, (float,) : float },
+ "radians" : { (int,) : float, (float,) : float },
+ "acosh" : int,
+ "asinh" : int,
+ "atanh" : int,
+ "cosh" : int,
+ "sinh" : int,
+ "tanh" : int,#TODO
+ "erf" : int,
+ "erfc" : int,
+ "gamma" : int,
+ "lgamma" : int #TODO
+ }
+
+safeMathFunctions = [
+ "ceil", "copysign", "fabs", "floor", "fmod", "isinf",
+ "isnan", "exp", "expm1", "cos", "hypot", "sin", "tan",
+ "degrees", "radians" ]
+
+randomFunctions = {
+ "seed" : { () : None, (collections.Hashable,) : None },
+ "getstate" : { () : object },
+ "setstate" : { (object,) : None },
+ "jumpahead" : { (int,) : None },
+ "getrandbits" : { (int,) : int },
+ "randrange" : { (int,) : int, (int, int) : int, (int, int, int) : int },
+ "randint" : { (int, int) : int },
+ "choice" : { (collections.Iterable,) : object },
+ "shuffle" : { (collections.Iterable,) : None,
+ (collections.Iterable, type(lambda x : x)) : None },
+ "sample" : { (collections.Iterable, int) : list },
+ "random" : { () : float },
+ "uniform" : { (float, float) : float }
+ }
+
+futureFunctions = {
+ "nested_scopes" : None,
+ "generators" : None,
+ "division" : None,
+ "absolute_import" : None,
+ "with_statement" : None,
+ "print_function" : None,
+ "unicode_literals" : None
+ }
+
+copyFunctions = {
+ "copy" : None,
+ "deepcopy" : None
+}
+
+timeFunctions = {
+ "clock" : { () : float },
+ "time" : { () : float }
+ }
+
+errorFunctions = {
+ "AssertionError" : { (str,) : object }
+ }
+
+allStaticFunctions = dict(list(staticBuiltInFunctions.items()) + list(staticTypeCastBuiltInFunctions.items()) + \
+ list(builtInStringFunctions.items()) + list(staticListFunctions.items()) + \
+ list(staticDictFunctions.items()) + list(mathFunctions.items()))
+
+allMutatingFunctions = dict(list(mutatingBuiltInFunctions.items()) + list(mutatingTypeCastBuiltInFunctions.items()) + \
+ list(mutatingListFunctions.items()) + list(randomFunctions.items()) + list(timeFunctions.items()))
+
+allPythonFunctions = dict(list(allStaticFunctions.items()) + list(allMutatingFunctions.items()))
+
+safeLibraryMap = { "string" : [ "ascii_letters", "ascii_lowercase", "ascii_uppercase",
+ "digits", "hexdigits", "letters", "lowercase", "octdigits",
+ "punctuation", "printable", "uppercase", "whitespace",
+ "capitalize", "expandtabs", "find", "rfind", "count",
+ "lower", "split", "rsplit", "splitfields", "join",
+ "joinfields", "lstrip", "rstrip", "strip", "swapcase",
+ "upper", "ljust", "rjust", "center", "zfill", "replace"],
+ "math" : [ "ceil", "copysign", "fabs", "floor", "fmod",
+ "frexp", "fsum", "isinf", "isnan", "ldexp", "modf", "trunc", "exp",
+ "expm1", "log", "log1p", "log10", "sqrt", "acos", "asin",
+ "atan", "atan2", "cos", "hypot", "sin", "tan", "degrees", "radians",
+ "acosh", "asinh", "atanh", "cosh", "sinh", "tanh", "erf", "erfc",
+ "gamma", "lgamma", "pi", "e" ],
+ "random" : [ ],
+ "__future__" : ["nested_scopes", "generators", "division", "absolute_import",
+ "with_statement", "print_function", "unicode_literals"] }
+
+libraryMap = { "string" : [ "ascii_letters", "ascii_lowercase", "ascii_uppercase",
+ "digits", "hexdigits", "letters", "lowercase", "octdigits",
+ "punctuation", "printable", "uppercase", "whitespace",
+ "capwords", "maketrans", "atof", "atoi", "atol", "capitalize",
+ "expandtabs", "find", "rfind", "index", "rindex", "count",
+ "lower", "split", "rsplit", "splitfields", "join",
+ "joinfields", "lstrip", "rstrip", "strip", "swapcase",
+ "translate", "upper", "ljust", "rjust", "center", "zfill",
+ "replace", "Template", "Formatter" ],
+ "math" : [ "ceil", "copysign", "fabs", "factorial", "floor", "fmod",
+ "frexp", "fsum", "isinf", "isnan", "ldexp", "modf", "trunc", "exp",
+ "expm1", "log", "log1p", "log10", "pow", "sqrt", "acos", "asin",
+ "atan", "atan2", "cos", "hypot", "sin", "tan", "degrees", "radians",
+ "acosh", "asinh", "atanh", "cosh", "sinh", "tanh", "erf", "erfc",
+ "gamma", "lgamma", "pi", "e" ],
+ "random" : ["seed", "getstate", "setstate", "jumpahead", "getrandbits",
+ "randrange", "randrange", "randint", "choice", "shuffle", "sample",
+ "random", "uniform", "triangular", "betavariate", "expovariate",
+ "gammavariate", "gauss", "lognormvariate", "normalvariate",
+ "vonmisesvariate", "paretovariate", "weibullvariate", "WichmannHill",
+ "whseed", "SystemRandom" ],
+ "__future__" : ["nested_scopes", "generators", "division", "absolute_import",
+ "with_statement", "print_function", "unicode_literals"],
+ "copy" : ["copy", "deepcopy"] }
+
+libraryDictMap = { "string" : builtInStringFunctions,
+ "math" : mathFunctions,
+ "random" : randomFunctions,
+ "__future__" : futureFunctions,
+ "copy" : copyFunctions }
+
+typeMethodMap = {"string" : ["capitalize", "center", "count", "decode", "encode", "endswith",
+ "expandtabs", "find", "format", "index", "isalnum", "isalpha",
+ "isdigit", "islower", "isspace", "istitle", "isupper", "join",
+ "ljust", "lower", "lstrip", "partition", "replace", "rfind",
+ "rindex", "rjust", "rpartition", "rsplit", "rstrip", "split",
+ "splitlines", "startswith", "strip", "swapcase", "title",
+ "translate", "upper", "zfill"],
+ "list" : [ "append", "extend", "count", "index", "insert", "pop", "remove",
+ "reverse", "sort"],
+ "set" : [ "isdisjoint", "issubset", "issuperset", "union", "intersection",
+ "difference", "symmetric_difference", "update", "intersection_update",
+ "difference_update", "symmetric_difference_update", "add",
+ "remove", "discard", "pop", "clear"],
+ "dict" : [ "iter", "clear", "copy", "fromkeys", "get", "has_key", "items",
+ "iteritems", "iterkeys", "itervalues", "keys", "pop", "popitem",
+ "setdefault", "update", "values", "viewitems", "viewkeys",
+ "viewvalues"] }
+
+astNames = {
+ ast.Module : "Module", ast.Interactive : "Interactive Module",
+ ast.Expression : "Expression Module", ast.Suite : "Suite",
+
+ ast.FunctionDef : "Function Definition",
+ ast.ClassDef : "Class Definition", ast.Return : "Return",
+ ast.Delete : "Delete", ast.Assign : "Assign",
+ ast.AugAssign : "AugAssign", ast.For : "For",
+ ast.While : "While", ast.If : "If", ast.With : "With",
+ ast.Raise : "Raise",
+ ast.Try : "Try", ast.Assert : "Assert",
+ ast.Import : "Import", ast.ImportFrom : "Import From",
+ ast.Global : "Global", ast.Expr : "Expression",
+ ast.Pass : "Pass", ast.Break : "Break", ast.Continue : "Continue",
+
+ ast.BoolOp : "Boolean Operation", ast.BinOp : "Binary Operation",
+ ast.UnaryOp : "Unary Operation", ast.Lambda : "Lambda",
+ ast.IfExp : "Ternary", ast.Dict : "Dictionary", ast.Set : "Set",
+ ast.ListComp : "List Comprehension", ast.SetComp : "Set Comprehension",
+ ast.DictComp : "Dict Comprehension",
+ ast.GeneratorExp : "Generator", ast.Yield : "Yield",
+ ast.Compare : "Compare", ast.Call : "Call",
+ ast.Num : "Number", ast.Str : "String", ast.Bytes : "Bytes",
+ ast.NameConstant : "Name Constant",
+ ast.Attribute : "Attribute",
+ ast.Subscript : "Subscript", ast.Name : "Name", ast.List : "List",
+ ast.Tuple : "Tuple", ast.Starred : "Starred",
+
+ ast.Load : "Load", ast.Store : "Store", ast.Del : "Delete",
+ ast.AugLoad : "AugLoad", ast.AugStore : "AugStore",
+ ast.Param : "Parameter",
+
+ ast.Ellipsis : "Ellipsis", ast.Slice : "Slice",
+ ast.ExtSlice : "ExtSlice", ast.Index : "Index",
+
+ ast.And : "And", ast.Or : "Or", ast.Add : "Add", ast.Sub : "Subtract",
+ ast.Mult : "Multiply", ast.Div : "Divide", ast.Mod : "Modulo",
+ ast.Pow : "Power", ast.LShift : "Left Shift",
+ ast.RShift : "Right Shift", ast.BitOr : "|", ast.BitXor : "^",
+ ast.BitAnd : "&", ast.FloorDiv : "Integer Divide",
+ ast.Invert : "Invert", ast.Not : "Not", ast.UAdd : "Unsigned Add",
+ ast.USub : "Unsigned Subtract", ast.Eq : "==", ast.NotEq : "!=",
+ ast.Lt : "<", ast.LtE : "<=", ast.Gt : ">", ast.GtE : ">=",
+ ast.Is : "Is", ast.IsNot : "Is Not", ast.In : "In",
+ ast.NotIn : "Not In",
+
+ ast.comprehension: "Comprehension",
+ ast.ExceptHandler : "Except Handler", ast.arguments : "Arguments", ast.arg : "Argument",
+ ast.keyword : "Keyword", ast.alias : "Alias", ast.withitem : "With item"
+ } \ No newline at end of file
diff --git a/canonicalize/tools.py b/canonicalize/tools.py
new file mode 100644
index 0000000..1df653a
--- /dev/null
+++ b/canonicalize/tools.py
@@ -0,0 +1,15 @@
+"""This is a file of useful functions used throughout the hint generation program"""
+import time, os.path, ast, json
+
+def log(msg, filename="main", newline=True):
+ return
+ txt = ""
+ if newline:
+ t = time.strftime("%d %b %Y %H:%M:%S")
+ txt += t + "\t"
+ txt += msg
+ if newline:
+ txt += "\n"
+ f = open('log/' + filename + ".log", "a")
+ f.write(txt)
+ f.close()
diff --git a/canonicalize/transformations.py b/canonicalize/transformations.py
new file mode 100644
index 0000000..7ae1e40
--- /dev/null
+++ b/canonicalize/transformations.py
@@ -0,0 +1,2775 @@
+import ast, copy, functools
+
+from .tools import log
+from .namesets import *
+from .astTools import *
+
+### VARIABLE ANONYMIZATION ###
+
+def updateVariableNames(a, varMap, scopeName, randomCounter, imports):
+ if not isinstance(a, ast.AST):
+ return
+
+ if type(a) in [ast.FunctionDef, ast.ClassDef]:
+ if a.name in varMap:
+ if not hasattr(a, "originalId"):
+ a.originalId = a.name
+ a.name = varMap[a.name]
+ anonymizeStatementNames(a, varMap, "_" + a.name, imports)
+ elif type(a) == ast.arg:
+ if a.arg not in varMap and not (builtInName(a.arg) or importedName(a.arg, imports)):
+ log("Can't assign to arg?", "bug")
+ if a.arg in varMap:
+ if not hasattr(a, "originalId"):
+ a.originalId = a.arg
+ if varMap[a.arg][0] == "r":
+ a.randomVar = True # so we know it can crash
+ if a.arg == varMap[a.arg]:
+ # Check whether this is a given name
+ if not isAnonVariable(varMap[a.arg]):
+ a.dontChangeName = True
+ a.arg = varMap[a.arg]
+ elif type(a) == ast.Name:
+ if a.id not in varMap and not (builtInName(a.id) or importedName(a.id, imports)):
+ varMap[a.id] = "r" + str(randomCounter[0]) + scopeName
+ randomCounter[0] += 1
+ if a.id in varMap:
+ if not hasattr(a, "originalId"):
+ a.originalId = a.id
+ if varMap[a.id][0] == "r":
+ a.randomVar = True # so we know it can crash
+ if a.id == varMap[a.id]:
+ # Check whether this is a given name
+ if not isAnonVariable(varMap[a.id]):
+ a.dontChangeName = True
+ a.id = varMap[a.id]
+ else:
+ for child in ast.iter_child_nodes(a):
+ updateVariableNames(child, varMap, scopeName, randomCounter, imports)
+
+def gatherLocalScope(a, globalMap, scopeName, imports, goBackwards=False):
+ localMap = { }
+ varLetter = "g" if type(a) == ast.Module else "v"
+ paramCounter = 0
+ localCounter = 0
+ if type(a) == ast.FunctionDef:
+ for param in a.args.args:
+ if type(param) == ast.arg:
+ if not (builtInName(param.arg) or importedName(param.arg, imports)) and param.arg not in localMap and param.arg not in globalMap:
+ localMap[param.arg] = "p" + str(paramCounter) + scopeName
+ paramCounter += 1
+ else:
+ log("transformations\tanonymizeFunctionNames\tWeird parameter type: " + str(type(param)), "bug")
+ if goBackwards:
+ items = a.body[::-1] # go through this backwards to get the used function names
+ else:
+ items = a.body[:]
+ # First, go through and create the globalMap
+ while len(items) > 0:
+ item = items[0]
+ if type(item) in [ast.FunctionDef, ast.ClassDef]:
+ if not (builtInName(item.name) or importedName(item.name, imports)):
+ if item.name not in localMap and item.name not in globalMap:
+ localMap[item.name] = "helper_" + varLetter + str(localCounter) + scopeName
+ localCounter += 1
+ else:
+ item.dontChangeName = True
+
+ # If there are any variables in this node, find and label them
+ if type(item) in [ ast.Assign, ast.AugAssign, ast.For, ast.With,
+ ast.Lambda, ast.comprehension, ast.ExceptHandler]:
+ assns = []
+ if type(item) == ast.Assign:
+ assns = item.targets
+ elif type(item) == ast.AugAssign:
+ assns = [item.target]
+ elif type(item) == ast.For:
+ assns = [item.target]
+ elif type(item) == ast.With:
+ if item.optional_vars != None:
+ assns = [item.optional_vars]
+ elif type(item) == ast.Lambda:
+ assns = item.args.args
+ elif type(item) == ast.comprehension:
+ assns = [item.target]
+ elif type(item) == ast.ExceptHandler:
+ if item.name != None:
+ assns = [item.name]
+
+ for assn in assns:
+ if type(assn) == ast.Name:
+ if not (builtInName(assn.id) or importedName(assn.id, imports)):
+ if assn.id not in localMap and assn.id not in globalMap:
+ localMap[assn.id] = varLetter + str(localCounter) + scopeName
+ localCounter += 1
+
+ if type(item) in [ast.For, ast.While, ast.If]:
+ items += item.body + item.orelse
+ elif type(item) in [ast.With, ast.ExceptHandler]:
+ items += item.body
+ elif type(item) == ast.Try:
+ items += item.body + item.handlers + item.orelse + item.finalbody
+ items = items[1:]
+ return localMap
+
+def anonymizeStatementNames(a, globalMap, scopeName, imports, goBackwards=False):
+ """Gather the local variables, then update variable names in each line"""
+ localMap = gatherLocalScope(a, globalMap, scopeName, imports, goBackwards=goBackwards)
+ varMap = { }
+ varMap.update(globalMap)
+ varMap.update(localMap)
+ randomCounter = [0]
+ functionsSeen = []
+ if type(a) == ast.FunctionDef:
+ for arg in a.args.args:
+ updateVariableNames(arg, varMap, scopeName, randomCounter, imports)
+ for line in a.body:
+ updateVariableNames(line, varMap, scopeName, randomCounter, imports)
+
+def anonymizeNames(a, namesToKeep, imports):
+ """Anonymize all of variables/names that occur in the given AST"""
+ """If we run this on an anonymized AST, it will fix the names again to get rid of any gaps!"""
+ if type(a) != ast.Module:
+ return a
+ globalMap = { }
+ for var in namesToKeep:
+ globalMap[var] = var
+ anonymizeStatementNames(a, globalMap, "", imports, goBackwards=True)
+ return a
+
+def propogateNameMetadata(a, namesToKeep, imports):
+ """Propogates name metadata through a state. We assume that the names are all properly formatted"""
+ if type(a) == list:
+ for child in a:
+ child = propogateNameMetadata(child, namesToKeep, imports)
+ return a
+ elif not isinstance(a, ast.AST):
+ return a
+ if type(a) == ast.Name:
+ if (builtInName(a.id) or importedName(a.id, imports)):
+ pass
+ elif a.id in namesToKeep:
+ a.dontChangeName = True
+ else:
+ if not hasattr(a, "originalId"):
+ a.originalId = a.id
+ if not isAnonVariable(a.id):
+ a.dontChangeName = True # it's a name we shouldn't mess with
+ elif type(a) == ast.arg:
+ if (builtInName(a.arg) or importedName(a.arg, imports)):
+ pass
+ elif a.arg in namesToKeep:
+ a.dontChangeName = True
+ else:
+ if not hasattr(a, "originalId"):
+ a.originalId = a.arg
+ if not isAnonVariable(a.arg):
+ a.dontChangeName = True # it's a name we shouldn't mess with
+ for child in ast.iter_child_nodes(a):
+ child = propogateNameMetadata(child, namesToKeep, imports)
+ return a
+
+### HELPER FOLDING ###
+
+def findHelperFunction(a, helperId, helperCount):
+ """Finds the first helper function used in the ast"""
+ if not isinstance(a, ast.AST):
+ return None
+
+ # Check all the children, so that we don't end up with a recursive problem
+ for child in ast.iter_child_nodes(a):
+ f = findHelperFunction(child, helperId, helperCount)
+ if f != None:
+ return f
+ # Then check if this is the right call
+ if type(a) == ast.Call:
+ if type(a.func) == ast.Name and a.func.id == helperId:
+ if helperCount[0] > 0:
+ helperCount[0] -= 1
+ else:
+ return a
+ return None
+
+def individualizeVariables(a, variablePairs, idNum, imports):
+ """Replace variable names with new individualized ones (for inlining methods)"""
+ if not isinstance(a, ast.AST):
+ return
+
+ if type(a) == ast.Name:
+ if a.id not in variablePairs and not (builtInName(a.id) or importedName(a.id, imports)):
+ name = "_var_" + a.id + "_" + str(idNum[0])
+ variablePairs[a.id] = name
+ if a.id in variablePairs:
+ a.id = variablePairs[a.id]
+ # Override built-in names when they're assigned to.
+ elif type(a) == ast.Assign and type(a.targets[0]) == ast.Name:
+ if a.targets[0].id not in variablePairs:
+ name = "_var_" + a.targets[0].id + "_" + str(idNum[0])
+ variablePairs[a.targets[0].id] = name
+ elif type(a) == ast.arguments:
+ for arg in a.args:
+ if type(arg) == ast.arg:
+ name = "_arg_" + arg.arg + "_" + str(idNum[0])
+ variablePairs[arg.arg] = name
+ arg.arg = variablePairs[arg.arg]
+ return
+ elif type(a) == ast.Call:
+ if type(a.func) == ast.Name:
+ variablePairs[a.func.id] = a.func.id # save the function name!
+
+ for child in ast.iter_child_nodes(a):
+ individualizeVariables(child, variablePairs, idNum, imports)
+
+def replaceAst(a, old, new, shouldReplace):
+ """Replace the old value with the new one, but only once! That's what the shouldReplace variable is for."""
+ if not isinstance(a, ast.AST) or not shouldReplace[0]:
+ return a
+ elif compareASTs(a, old, checkEquality=True) == 0:
+ shouldReplace[0] = False
+ return new
+ return applyToChildren(a, lambda x: replaceAst(x, old, new, shouldReplace))
+
+def mapHelper(a, helper, idNum, imports):
+ """Map the helper function into this function, if it's used. idNum gives us the current id"""
+ if type(a) in [ast.FunctionDef, ast.ClassDef]:
+ a.body = mapHelper(a.body, helper, idNum, imports)
+ elif type(a) in [ast.For, ast.While, ast.If]:
+ a.body = mapHelper(a.body, helper, idNum, imports)
+ a.orelse = mapHelper(a.orelse, helper, idNum, imports)
+ if type(a) != list:
+ return a # we only deal with lists
+
+ i = 0
+ body = a
+ while i < len(body):
+ if type(body[i]) in [ast.FunctionDef, ast.ClassDef, ast.For, ast.While, ast.If]:
+ body[i] = mapHelper(body[i], helper, idNum, imports) # deal with blocks first
+ # While we still need to replace a function being called
+ helperCount = 0
+ while countVariables(body[i], helper.name) > helperCount:
+ callExpression = findHelperFunction(body[i], helper.name, [helperCount])
+ if callExpression == None:
+ break
+ idNum[0] += 1
+ variablePairs = {}
+
+ # First, update the method's variables
+ methodArgs = deepcopy(helper.args)
+ individualizeVariables(methodArgs, variablePairs, idNum, imports)
+ methodLines = deepcopyList(helper.body)
+ for j in range(len(methodLines)):
+ individualizeVariables(methodLines[j], variablePairs, idNum, imports)
+
+ # Then, move each of the parameters into an assignment
+ argLines = []
+ badArgs = False
+ for j in range(len(callExpression.args)):
+ arg = methodArgs.args[j]
+ value = callExpression.args[j]
+ # If it only occurs once, and in the return statement..
+ if countVariables(methodLines, arg.arg) == 1 and countVariables(methodLines[-1], arg.arg) == 1:
+ var = ast.Name(arg.arg, ast.Load())
+ transferMetaData(arg, var)
+ methodLines[-1] = replaceAst(methodLines[-1], var, value, [True])
+ elif couldCrash(value):
+ badArgs = True
+ break
+ else:
+ if type(arg) == ast.arg:
+ var = ast.Name(arg.arg, ast.Store(), helperVar=True)
+ transferMetaData(arg, var)
+ varAssign = ast.Assign([var], value, global_id=arg.global_id, helperParamAssign=True)
+ argLines.append(varAssign)
+ else: # what other type would it be?
+ log("transformations\tmapHelper\tWeird arg: " + str(type(arg)), "bug")
+ if badArgs:
+ helperCount += 1
+ continue
+
+ # Finally, carry over the final value into the correct position
+ if countOccurances(helper, ast.Return) == 1:
+ returnLine = [replaceAst(body[i], callExpression, methodLines[-1].value, [True])]
+ methodLines.pop(-1)
+ else:
+ comp = compareASTs(body[i].value, callExpression, checkEquality=True) == 0
+ if type(body[i]) == ast.Return and comp:
+ # In the case of a return, we can return None, it's cool
+ callVal = ast.NameConstant(None, helperReturnVal=True)
+ transferMetaData(callExpression, callVal)
+ returnVal = ast.Return(callVal, global_id=body[i].global_id, helperReturnAssign=True)
+ returnLine = [returnVal]
+ elif type(body[i]) == ast.Assign and comp:
+ # Assigns can be assigned to None
+ callVal = ast.NameConstant(None, helperReturnVal=True)
+ transferMetaData(callExpression, callVal)
+ assignVal = ast.Assign(body[i].targets, callVal, global_id=body[i].global_id, helperReturnAssign=True)
+ returnLine = [assignVal]
+ elif type(body[i]) == ast.Expr and comp:
+ # Delete the line if it's by itself
+ returnLine = []
+ else:
+ # Otherwise, what is going on?!?
+ returnLine = []
+ log("transformations\tmapHelper\tWeird removeCall: " + str(type(body[i])), "bug")
+ #log("transformations\tmapHelper\tMapped helper function: " + helper.name, "bug")
+ body[i:i+1] = argLines + methodLines + returnLine
+ i += 1
+ return a
+
+def mapVariable(a, varId, assn):
+ """Map the variable assignment into the function, if it's needed"""
+ if type(a) != ast.FunctionDef:
+ return a
+
+ for arg in a.args.args:
+ if arg.arg == varId:
+ return a # overriden by local variable
+ for i in range(len(a.body)):
+ line = a.body[i]
+ if type(line) == ast.Assign:
+ for target in line.targets:
+ if type(target) == ast.Name and target.id == varId:
+ break
+ elif type(target) in [ast.Tuple, ast.List]:
+ for elt in target.elts:
+ if type(elt) == ast.Name and elt.id == varId:
+ break
+ if countVariables(line, varId) > 0:
+ a.body[i:i+1] = [deepcopy(assn), line]
+ break
+ return a
+
+def helperFolding(a, mainFun, imports):
+ """When possible, fold the functions used in this module into their callers"""
+ if type(a) != ast.Module:
+ return a
+
+ globalCounter = [0]
+ body = a.body
+ i = 0
+ while i < len(body):
+ item = body[i]
+ if type(item) == ast.FunctionDef:
+ if item.name != mainFun:
+ # We only want non-recursive functions that have a single return which occurs at the end
+ # Also, we don't want any functions with parameters that are changed during the function
+ returnOccurs = countOccurances(item, ast.Return)
+ if countVariables(item, item.name) == 0 and returnOccurs <= 1 and type(item.body[-1]) == ast.Return:
+ allArgs = []
+ for arg in item.args.args:
+ allArgs.append(arg.arg)
+ for tmpA in ast.walk(item):
+ if type(tmpA) in [ast.Assign, ast.AugAssign]:
+ allGood = True
+ assignedIds = gatherAssignedVarIds(tmpA.targets if type(tmpA) == ast.Assign else [tmpA.target])
+ for arg in allArgs:
+ if arg in assignedIds:
+ allGood = False
+ break
+ if not allGood:
+ break
+ else:
+ for j in range(len(item.body)-1):
+ if couldCrash(item.body[j]):
+ break
+ else:
+ # If we satisfy these requirements, translate the body of the function into all functions that call it
+ gone = True
+ used = False
+ for j in range(len(body)):
+ if i != j and type(body[j]) == ast.FunctionDef:
+ if countVariables(body[j], item.name) > 0:
+ used = True
+ mapHelper(body[j], item, globalCounter, imports)
+ if countVariables(body[j], item.name) > 0:
+ gone = False
+ if used and gone:
+ body.pop(i)
+ continue
+ elif type(item) == ast.Assign:
+ # Is it ever changed in the global area?
+ if len(item.targets) == 1 and type(item.targets[0]) == ast.Name:
+ if eventualType(item.value) in [int, float, bool, str]: # if it isn't mutable
+ for j in range(i+1, len(body)):
+ line = body[j]
+ if type(line) == ast.FunctionDef:
+ if countOccurances(line, ast.Global) > 0: # TODO: improve this
+ break
+ if item.targets[0].id in getAllAssignedVarIds(line):
+ break
+ else: # if in scope
+ if countVariables(line, item.targets[0].id) > 0:
+ break
+ else:
+ # Variable never appears again- we can map it!
+ for j in range(i+1, len(body)):
+ line = body[j]
+ if type(line) == ast.FunctionDef:
+ mapVariable(line, item.targets[0].id, item)
+ body.pop(i)
+ continue
+ i += 1
+ return a
+
+### AST PREPARATION ###
+
+def listNotEmpty(a):
+ """Determines that the iterable is NOT empty, if we can know that"""
+ """Used for For objects"""
+ if not isinstance(a, ast.AST):
+ return False
+ if type(a) == ast.Call:
+ if type(a.func) == ast.Name and a.func.id in ["range"]:
+ if len(a.args) == 1: # range(x)
+ return type(a.args[0]) == ast.Num and type(a.args[0].n) != complex and a.args[0].n > 0
+ elif len(a.args) == 2: # range(start, x)
+ if type(a.args[0]) == ast.Num and type(a.args[1]) == ast.Num and \
+ type(a.args[0].n) != complex and type(a.args[1].n) != complex and \
+ a.args[0].n < a.args[1].n:
+ return True
+ elif type(a.args[1]) == ast.BinOp and type(a.args[1].op) == ast.Add:
+ if type(a.args[1].right) == ast.Num and type(a.args[1].right) != complex and a.args[1].right.n > 0 and \
+ compareASTs(a.args[0], a.args[1].left, checkEquality=True) == 0:
+ return True
+ elif type(a.args[1].left) == ast.Num and type(a.args[1].left) != complex and a.args[1].left.n > 0 and \
+ compareASTs(a.args[0], a.args[1].right, checkEquality=True) == 0:
+ return True
+ elif type(a) in [ast.List, ast.Tuple]:
+ return len(a.elts) > 0
+ elif type(a) == ast.Str:
+ return len(a.s) > 0
+ return False
+
+def simplifyUpdateId(var, variableMap, idNum):
+ """Update the varID of a new variable"""
+ if type(var) not in [ast.Name, ast.arg]:
+ return var
+ idVar = var.id if type(var) == ast.Name else var.arg
+ if not hasattr(var, "varID"):
+ if idVar in variableMap:
+ var.varID = variableMap[idVar][1]
+ else:
+ var.varID = idNum[0]
+ idNum[0] += 1
+
+def simplify_multicomp(a):
+ if type(a) == ast.Compare and len(a.ops) > 1:
+ # Only do one comparator at a time. If we don't do this, things get messy!
+ comps = [a.left] + a.comparators
+ values = [ ]
+ # Compare each of the pairs
+ for i in range(len(a.ops)):
+ if i > 0:
+ # Label all nodes as middle parts so we can recognize them later
+ assignPropertyToAll(comps[i], "multiCompMiddle")
+ values.append(ast.Compare(comps[i], [a.ops[i]], [deepcopy(comps[i+1])], multiCompPart=True))
+ # Combine comparisons with and operators
+ boolOp = ast.And(multiCompOp=True)
+ boolopVal = ast.BoolOp(boolOp, values, multiComp=True, global_id=a.global_id)
+ return boolopVal
+ return a
+
+def simplify(a):
+ """This function simplifies the usual Python AST to make it usable by our functions."""
+ if not isinstance(a, ast.AST):
+ return a
+ elif type(a) == ast.Assign:
+ if len(a.targets) > 1:
+ # Go through all targets and assign them on separate lines
+ lines = [ast.Assign([a.targets[-1]], a.value, global_id=a.global_id)]
+ for i in range(len(a.targets)-1, 0, -1):
+ t = a.targets[i]
+ if type(t) == ast.Name:
+ loadedTarget = ast.Name(t.id, ast.Load())
+ elif type(t) == ast.Subscript:
+ loadedTarget = ast.Subscript(deepcopy(t.value), deepcopy(t.slice), ast.Load())
+ elif type(t) == ast.Attribute:
+ loadedTarget = ast.Attribute(deepcopy(t.value), t.attr, ast.Load())
+ elif type(t) == ast.Tuple:
+ loadedTarget = ast.Tuple(deepcopy(t.elts), ast.Load())
+ elif type(t) == ast.List:
+ loadedTarget = ast.List(deepcopy(t.elts), ast.Load())
+ else:
+ log("transformations\tsimplify\tOdd loadedTarget: " + str(type(t)), "bug")
+ transferMetaData(t, loadedTarget)
+ loadedTarget.global_id = t.global_id
+
+ lines.append(ast.Assign([a.targets[i-1]], loadedTarget, global_id=a.global_id))
+ else:
+ lines = [a]
+
+ i = 0
+ while i < len(lines):
+ # For each line, figure out type and varID
+ lines[i].value = simplify(lines[i].value)
+ t = lines[i].targets[0]
+ if type(t) in [ast.Tuple, ast.List]:
+ val = lines[i].value
+ # If the items are being assigned separately, with no dependance on each other,
+ # separate out the elements of the tuple
+ if type(val) in [ast.Tuple, ast.List] and len(t.elts) == len(val.elts) and \
+ allVariableNamesUsed(val) == []:
+ listLines = []
+ for j in range(len(t.elts)):
+ assignVal = ast.Assign([t.elts[j]], val.elts[j], global_id = lines[i].global_id)
+ listLines += simplify(assignVal)
+ lines[i:i+1] = listLines
+ i += len(listLines) - 1
+ i += 1
+ return lines
+ elif type(a) == ast.AugAssign:
+ # Turn all AugAssigns into Assigns
+ a.target = simplify(a.target)
+ if eventualType(a.target) not in [bool, int, str, float]:
+ # Can't get rid of AugAssign, in case the += is different
+ a.value = simplify(a.value)
+ return a
+ if type(a.target) == ast.Name:
+ loadedTarget = ast.Name(a.target.id, ast.Load())
+ elif type(a.target) == ast.Subscript:
+ loadedTarget = ast.Subscript(deepcopy(a.target.value), deepcopy(a.target.slice), ast.Load())
+ elif type(a.target) == ast.Attribute:
+ loadedTarget = ast.Attribute(deepcopy(a.target.value), a.target.attr, ast.Load())
+ elif type(a.target) == ast.Tuple:
+ loadedTarget = ast.Tuple(deepcopy(a.target.elts), ast.Load())
+ elif type(a.target) == ast.List:
+ loadedTarget = ast.List(deepcopy(a.target.elts), ast.Load())
+ else:
+ log("transformations\tsimplify\tOdd AugAssign target: " + str(type(a.target)), "bug")
+ transferMetaData(a.target, loadedTarget)
+ loadedTarget.global_id = a.target.global_id
+ a.target.augAssignVal = True # for later recognition
+ loadedTarget.augAssignVal = True
+ assignVal = ast.Assign([a.target], ast.BinOp(loadedTarget, a.op, a.value, augAssignBinOp=True), global_id=a.global_id)
+ return simplify(assignVal)
+ elif type(a) == ast.Compare and len(a.ops) > 1:
+ return simplify(simplify_multicomp(a))
+ return applyToChildren(a, lambda x : simplify(x))
+
+def propogateMetadata(a, argTypes, variableMap, idNum):
+ """This function propogates metadata about type throughout the AST.
+ argTypes lets us institute types for each function's args
+ variableMap maps variable ids to their types and id numbers
+ idNum gives us a global number to use for variable ids"""
+ if not isinstance(a, ast.AST):
+ return a
+ elif type(a) == ast.FunctionDef:
+ if a.name in argTypes:
+ theseTypes = argTypes[a.name]
+ else:
+ theseTypes = []
+ variableMap = copy.deepcopy(variableMap) # variables shouldn't affect variables outside
+ idNum = copy.deepcopy(idNum)
+ for i in range(len(a.args.args)):
+ arg = a.args.args[i]
+ if type(arg) == ast.arg:
+ simplifyUpdateId(arg, variableMap, idNum)
+ # Match the args if possible
+ if len(a.args.args) == len(theseTypes) and \
+ theseTypes[i] in ["int", "str", "float", "bool", "list"]:
+ arg.type = eval(theseTypes[i])
+ variableMap[arg.arg] = (arg.type, arg.varID)
+ else:
+ arg.type = None
+ variableMap[arg.arg] = (None, arg.varID)
+ else:
+ log("transformations\tpropogateMetadata\tWeird type in args: " + str(type(arg)), "bug")
+
+ newBody = []
+ for line in a.body:
+ newLine = propogateMetadata(line, argTypes, variableMap, idNum)
+ if type(newLine) == list:
+ newBody += newLine
+ else:
+ newBody.append(newLine)
+ a.body = newBody
+ return a
+ elif type(a) == ast.Assign:
+ val = a.value = propogateMetadata(a.value, argTypes, variableMap, idNum)
+ if len(a.targets) == 1:
+ t = a.targets[0]
+ if type(t) == ast.Name:
+ simplifyUpdateId(t, variableMap, idNum)
+ varType = eventualType(val)
+ t.type = varType
+ variableMap[t.id] = (varType, t.varID) # update type
+ elif type(t) in [ast.Tuple, ast.List]:
+ # If the items are being assigned separately, with no dependance on each other,
+ # assign the appropriate types
+ getTypes = type(val) in [ast.Tuple, ast.List] and len(t.elts) == len(val.elts) and len(allVariableNamesUsed(val)) == 0
+ for j in range(len(t.elts)):
+ t2 = t.elts[j]
+ givenType = eventualType(val.elts[j]) if getTypes else None
+ if type(t2) == ast.Name:
+ simplifyUpdateId(t2, variableMap, idNum)
+ t2.type = givenType
+ variableMap[t2.id] = (givenType, t2.varID)
+ elif type(t.elts[j]) in [ast.Subscript, ast.Attribute, ast.Tuple, ast.List]:
+ pass
+ else:
+ log("transformations\tpropogateMetadata\tOdd listTarget: " + str(type(t.elts[j])), "bug")
+ return a
+ elif type(a) == ast.AugAssign:
+ # Turn all AugAssigns into Assigns
+ t = a.target = propogateMetadata(a.target, argTypes, variableMap, idNum)
+ a.value = propogateMetadata(a.value, argTypes, variableMap, idNum)
+ if type(t) == ast.Name:
+ if eventualType(a.target) not in [bool, int, str, float]:
+ finalType = None
+ else:
+ if type(a.target) == ast.Name:
+ loadedTarget = ast.Name(a.target.id, ast.Load())
+ elif type(a.target) == ast.Subscript:
+ loadedTarget = ast.Subscript(deepcopy(a.target.value), deepcopy(a.target.slice), ast.Load())
+ elif type(a.target) == ast.Attribute:
+ loadedTarget = ast.Attribute(deepcopy(a.target.value), a.target.attr, ast.Load())
+ elif type(a.target) == ast.Tuple:
+ loadedTarget = ast.Tuple(deepcopy(a.target.elts), ast.Load())
+ elif type(a.target) == ast.List:
+ loadedTarget = ast.List(deepcopy(a.target.elts), ast.Load())
+ else:
+ log("transformations\tsimplify\tOdd AugAssign target: " + str(type(a.target)), "bug")
+ transferMetaData(a.target, loadedTarget)
+ actualValue = ast.BinOp(loadedTarget, a.op, a.value)
+ finalType = eventualType(actualValue)
+
+ simplifyUpdateId(t, variableMap, idNum)
+ varType = finalType
+ t.type = varType
+ variableMap[t.id] = (varType, t.varID) # update type
+ return a
+ elif type(a) == ast.For: # START HERE
+ a.iter = propogateMetadata(a.iter, argTypes, variableMap, idNum)
+ if type(a.target) == ast.Name:
+ simplifyUpdateId(a.target, variableMap, idNum)
+ ev = eventualType(a.iter)
+ if ev == str:
+ a.target.type = str
+ # we know ranges are made of ints
+ elif type(a.iter) == ast.Call and type(a.iter.func) == ast.Name and (a.iter.func.id) == "range":
+ a.target.type = int
+ else:
+ a.target.type = None
+ variableMap[a.target.id] = (a.target.type, a.target.varID)
+ a.target = propogateMetadata(a.target, argTypes, variableMap, idNum)
+
+ # Go through the body and orelse to map out more variables
+ body = []
+ bodyMap = copy.deepcopy(variableMap)
+ for i in range(len(a.body)):
+ result = propogateMetadata(a.body[i], argTypes, bodyMap, idNum)
+ if type(result) == list:
+ body += result
+ else:
+ body.append(result)
+ a.body = body
+
+ orelse = []
+ orelseMap = copy.deepcopy(variableMap)
+ for var in bodyMap:
+ if var not in orelseMap:
+ orelseMap[var] = (42, bodyMap[var][1])
+ for i in range(len(a.orelse)):
+ result = propogateMetadata(a.orelse[i], argTypes, orelseMap, idNum)
+ if type(result) == list:
+ orelse += result
+ else:
+ orelse.append(result)
+ a.orelse = orelse
+
+ keys = list(variableMap.keys())
+ for key in keys: # reset types of changed keys
+ if key not in bodyMap or bodyMap[key] != variableMap[key] or \
+ key not in orelseMap or orelseMap[key] != variableMap[key]:
+ variableMap[key] = (None, variableMap[key][1])
+ if countOccurances(a.body, ast.Break) == 0: # We will definitely enter the else!
+ for key in orelseMap:
+ # If we KNOW it will be this type
+ if key not in bodyMap or bodyMap[key] == orelseMap[key]:
+ variableMap[key] = orelseMap[key]
+
+ # If we KNOW it will run at least once
+ if listNotEmpty(a.iter):
+ for key in bodyMap:
+ if key in variableMap:
+ continue
+ # type might be changed
+ elif key in orelseMap and orelseMap[key] != bodyMap[key]:
+ continue
+ variableMap[key] = bodyMap[key]
+ return a
+ elif type(a) == ast.While:
+ body = []
+ bodyMap = copy.deepcopy(variableMap)
+ for i in range(len(a.body)):
+ result = propogateMetadata(a.body[i], argTypes, bodyMap, idNum)
+ if type(result) == list:
+ body += result
+ else:
+ body.append(result)
+ a.body = body
+
+ orelse = []
+ orelseMap = copy.deepcopy(variableMap)
+ for var in bodyMap:
+ if var not in orelseMap:
+ orelseMap[var] = (None, bodyMap[var][1])
+ for i in range(len(a.orelse)):
+ result = propogateMetadata(a.orelse[i], argTypes, orelseMap, idNum)
+ if type(result) == list:
+ orelse += result
+ else:
+ orelse.append(result)
+ a.orelse = orelse
+
+ keys = list(variableMap.keys())
+ for key in keys:
+ if key not in bodyMap or bodyMap[key] != variableMap[key] or \
+ key not in orelseMap or orelseMap[key] != variableMap[key]:
+ variableMap[key] = (None, variableMap[key][1])
+
+ a.test = propogateMetadata(a.test, argTypes, variableMap, idNum)
+ return a
+ elif type(a) == ast.If:
+ a.test = propogateMetadata(a.test, argTypes, variableMap, idNum)
+ variableMap1 = copy.deepcopy(variableMap)
+ variableMap2 = copy.deepcopy(variableMap)
+
+ body = []
+ for i in range(len(a.body)):
+ result = propogateMetadata(a.body[i], argTypes, variableMap1, idNum)
+ if type(result) == list:
+ body += result
+ else:
+ body.append(result)
+ a.body = body
+
+ for var in variableMap1:
+ if var not in variableMap2:
+ variableMap2[var] = (42, variableMap1[var][1])
+
+ orelse = []
+ for i in range(len(a.orelse)):
+ result = propogateMetadata(a.orelse[i], argTypes, variableMap2, idNum)
+ if type(result) == list:
+ orelse += result
+ else:
+ orelse.append(result)
+ a.orelse = orelse
+
+ variableMap.clear()
+ for key in variableMap1:
+ if key in variableMap2:
+ if variableMap1[key] == variableMap2[key]:
+ variableMap[key] = variableMap1[key]
+ elif variableMap1[key][1] != variableMap2[key][1]:
+ log("transformations\tsimplify\tvarId mismatch", "bug")
+ else: # type mismatch
+ if variableMap2[key][0] == 42:
+ variableMap[key] = variableMap1[key]
+ else:
+ variableMap[key] = (None, variableMap1[key][1])
+ elif len(a.orelse) > 0 and type(a.orelse[-1]) == ast.Return: # if the else exits, it doesn't matter
+ variableMap[key] = variableMap1[key]
+ for key in variableMap2:
+ if key not in variableMap1 and len(a.body) > 0 and type(a.body[-1]) == ast.Return: # if the if exits, it doesn't matter
+ variableMap[key] = variableMap2[key]
+ return a
+ elif type(a) == ast.Name:
+ if a.id in variableMap and variableMap[a.id][0] != 42:
+ a.type = variableMap[a.id][0]
+ a.varID = variableMap[a.id][1]
+ else:
+ if a.id in variableMap:
+ a.varID = variableMap[a.id][1]
+ a.ctx = propogateMetadata(a.ctx, argTypes, variableMap, idNum)
+ return a
+ elif type(a) == ast.AugLoad:
+ return ast.Load()
+ elif type(a) == ast.AugStore:
+ return ast.Store()
+ return applyToChildren(a, lambda x : propogateMetadata(x, argTypes, variableMap, idNum))
+
+### SIMPLIFYING FUNCTIONS ###
+
+def applyTransferLambda(x):
+ """Simplify an expression by applying constant folding, re-formatting to an AST, and then tranferring the metadata appropriately."""
+ if x == None:
+ return x
+ tmp = astFormat(constantFolding(x))
+ if hasattr(tmp, "global_id") and hasattr(x, "global_id") and tmp.global_id != x.global_id:
+ return tmp # don't do the transfer, this already has its own metadata
+ else:
+ transferMetaData(x, tmp)
+ return tmp
+
+def constantFolding(a):
+ """In constant folding, we evaluate all constant expressions instead of doing operations at runtime"""
+ if not isinstance(a, ast.AST):
+ return a
+ t = type(a)
+ if t in [ast.FunctionDef, ast.ClassDef]:
+ for i in range(len(a.body)):
+ a.body[i] = applyTransferLambda(a.body[i])
+ return a
+ elif t in [ast.Import, ast.ImportFrom, ast.Global]:
+ return a
+ elif t == ast.BoolOp:
+ # Condense the boolean's values
+ newValues = []
+ ranks = []
+ count = 0
+ for val in a.values:
+ # Condense the boolean operations into one line, if possible
+ c = constantFolding(val)
+ if type(c) == ast.BoolOp and type(c.op) == type(a.op) and not hasattr(c, "multiComp"):
+ newValues += c.values
+ ranks.append(range(count,count+len(c.values)))
+ count += len(c.values)
+ else:
+ newValues.append(c)
+ ranks.append(count)
+ count += 1
+
+ # Or breaks with True, And breaks with False
+ breaks = (type(a.op) == ast.Or)
+
+ # Remove the opposite values IF removing them won't mess up the type.
+ i = len(newValues) - 1
+ while i > 0:
+ if (newValues[i] == (not breaks)) and eventualType(newValues[i-1]) == bool:
+ newValues.pop(i)
+ i -= 1
+
+ if len(newValues) == 0:
+ # There's nothing to evaluate
+ return (not breaks)
+ elif len(newValues) == 1:
+ # If we're down to one value, just return it!
+ return newValues[0]
+ elif newValues[0] == breaks:
+ # If the first value breaks it, done!
+ return breaks
+ elif newValues.count(breaks) >= 1:
+ # We don't need any values that occur after a break
+ i = newValues.index(breaks)
+ newValues = newValues[:i+1]
+ for i in range(len(newValues)):
+ newValues[i] = astFormat(newValues[i])
+ # get the corresponding value
+ if i in ranks:
+ transferMetaData(a.values[ranks.index(i)], newValues[i])
+ else: # it's in a list
+ for j in range(len(ranks)):
+ if type(ranks[j]) == list and i in ranks[j]:
+ transferMetaData(a.values[j].values[ranks[j].index(i)], newValues[i])
+ break
+ a.values = newValues
+ return a
+ elif t == ast.BinOp:
+ l = constantFolding(a.left)
+ r = constantFolding(a.right)
+ # Hack to make hint chaining work- don't constant-fold filler strings!
+ if containsTokenStepString(l) or containsTokenStepString(r):
+ a.left = applyTransferLambda(a.left)
+ a.right = applyTransferLambda(a.right)
+ return a
+ if type(l) in builtInTypes and type(r) in builtInTypes:
+ try:
+ val = doBinaryOp(a.op, l, r)
+ if type(val) == float and val % 0.0001 != 0: # don't deal with trailing floats
+ pass
+ else:
+ tmp = astFormat(val)
+ transferMetaData(a, tmp)
+ return tmp
+ except:
+ # We have some kind of divide-by-zero issue.
+ # Therefore, don't calculate it!
+ pass
+ if type(l) in builtInTypes:
+ if type(r) == bool:
+ r = int(r)
+ # Commutative operations
+ elif type(r) == ast.BinOp and type(r.op) == type(a.op) and type(a.op) in [ast.Add, ast.Mult, ast.BitOr, ast.BitAnd, ast.BitXor]:
+ rLeft = constantFolding(r.left)
+ if type(rLeft) in builtInTypes:
+ try:
+ newLeft = astFormat(doBinaryOp(a.op, l, rLeft))
+ transferMetaData(r.left, newLeft)
+ return ast.BinOp(newLeft, a.op, r.right)
+ except Exception as e:
+ pass
+
+ # Empty string is often unneccessary
+ if type(l) == str and l == '':
+ if type(a.op) == ast.Add and eventualType(r) == str:
+ return r
+ elif type(a.op) == ast.Mult and eventualType(r) == int:
+ return ''
+ elif type(l) == bool:
+ l = int(l)
+ # 0 is often unneccessary
+ if l == 0 and eventualType(r) in [int, float]:
+ if type(a.op) in [ast.Add, ast.BitOr]:
+ # If it won't change the type
+ if type(l) == int or eventualType(r) == float:
+ return r
+ elif type(l) == float: # Cast it
+ return ast.Call(ast.Name("float", ast.Load(), typeCastFunction=True), [r], [])
+ elif type(a.op) == ast.Sub:
+ tmpR = astFormat(r)
+ transferMetaData(a.right, tmpR)
+ newR = ast.UnaryOp(ast.USub(addedOtherOp=True), tmpR, addedOther=True)
+ if type(l) == int or eventualType(r) == float:
+ return newR
+ elif type(l) == float:
+ return ast.Call(ast.Name("float", ast.Load(), typeCastFunction=True), [newR], [])
+ elif type(a.op) in [ast.Mult, ast.LShift, ast.RShift]:
+ # If either is a float, it's 0
+ return 0.0 if float in [eventualType(r), type(l)] else 0
+ elif type(a.op) in [ast.Div, ast.FloorDiv, ast.Mod]:
+ # Check if the right might be zero
+ if type(r) in builtInTypes and r != 0:
+ return 0.0 if float in [eventualType(r), type(l)] else 0
+ # Same for 1
+ elif l == 1:
+ if type(a.op) == ast.Mult and eventualType(r) in [int, float]:
+ if type(l) == int or eventualType(r) == float:
+ return r
+ elif type(l) == float:
+ return ast.Call(ast.Name("float", ast.Load(), typeCastFunction=True), [r], [])
+ # No reason to make this a float if the other value has already been cast
+ elif type(l) == float and l == int(l):
+ if type(a.op) in [ast.Add, ast.Sub, ast.Mult, ast.Div] and eventualType(r) == float:
+ l = int(l)
+ # Some of the same operations are done with the right, but not all of them
+ if type(r) in builtInTypes:
+ if type(r) == str and r == '':
+ if type(a.op) == ast.Add and eventualType(l) == str:
+ return l
+ elif type(a.op) == ast.Mult and eventualType(l) == int:
+ return ''
+ elif type(r) == bool:
+ r = int(r)
+ else:
+ if r == 0 and eventualType(l) in [int, float]:
+ if type(a.op) in [ast.Add, ast.Sub, ast.LShift, ast.RShift, ast.BitOr]:
+ if type(r) == int or eventualType(l) == float:
+ return l
+ elif type(r) == float:
+ return ast.Call(ast.Name("float", ast.Load(), typeCastFunction=True), [l], [])
+ elif type(a.op) == ast.Mult:
+ return 0.0 if float in [eventualType(l), type(r)] else 0
+ elif r == 1:
+ if type(a.op) in [ast.Mult, ast.Div, ast.Pow] and eventualType(l) in [int, float]:
+ if type(r) == int or eventualType(l) == float:
+ return l
+ elif type(r) == float:
+ return ast.Call(ast.Name("float", ast.Load(), typeCastFunction=True), [l], [])
+ elif type(a.op) == ast.FloorDiv and eventualType(l) == int:
+ if eventualType( r ) == int:
+ return l
+ elif eventualType( r ) == float:
+ return ast.Call(ast.Name("float", ast.Load(), typeCastFunction=True), [l], [])
+ elif type(r) == float and r == int(r):
+ if type(a.op) in [ast.Add, ast.Sub, ast.Mult, ast.Div] and eventualType(l) == float:
+ r = int(r)
+ a.left = applyTransferLambda(a.left)
+ a.right = applyTransferLambda(a.right)
+ return a
+ elif t == ast.IfExp:
+ # Sometimes, we can simplify the statement
+ test = constantFolding(a.test)
+ b = constantFolding(a.body)
+ o = constantFolding(a.orelse)
+
+ aTest = astFormat(test)
+ transferMetaData(a.test, aTest)
+ aB = astFormat(b)
+ transferMetaData(a.body, aB)
+ aO = astFormat(o)
+ transferMetaData(a.orelse, aO)
+
+ if type(test) == bool:
+ return aB if test else aO # evaluate the if expression now
+ elif compareASTs(b, o, checkEquality=True) == 0:
+ return aB # if they're the same, no reason for the expression
+ a.test = aTest
+ a.body = aB
+ a.orelse = aO
+ return a
+ elif t == ast.Compare:
+ if len(a.ops) == 0 or len(a.comparators) == 0:
+ return True # No ops? Okay, empty case is true!
+ op = a.ops[0]
+ l = constantFolding(a.left)
+ r = constantFolding(a.comparators[0])
+ # Hack to make hint chaining work- don't constant-fold filler strings!
+ if containsTokenStepString(l) or containsTokenStepString(r):
+ tmpLeft = astFormat(l)
+ transferMetaData(a.left, tmpLeft)
+ a.left = tmpLeft
+ tmpRight = astFormat(r)
+ transferMetaData(a.comparators[0], tmpRight)
+ a.comparators = [tmpRight]
+ return a
+ # Check whether the two sides are the same
+ comp = compareASTs(l, r, checkEquality=True) == 0
+ if comp and (not couldCrash(l)) and type(op) in [ast.Lt, ast.Gt, ast.NotEq]:
+ tmp = ast.NameConstant(False)
+ transferMetaData(a, tmp)
+ return tmp
+ elif comp and (not couldCrash(l)) and type(op) in [ast.Eq, ast.LtE, ast.GtE]:
+ tmp = ast.NameConstant(True)
+ transferMetaData(a, tmp)
+ return tmp
+ if (type(l) in builtInTypes) and (type(r) in builtInTypes):
+ try:
+ result = astFormat(doCompare(op, l, r))
+ transferMetaData(a, result)
+ return result
+ except:
+ pass
+ # Reduce the expressions when possible!
+ if type(l) == type(r) == ast.BinOp and type(l.op) == type(r.op) and not couldCrash(l) and not couldCrash(r):
+ if type(l.op) == ast.Add:
+ # Remove repeated values
+ unchanged = False
+ if compareASTs(l.left, r.left, checkEquality=True) == 0:
+ l = l.right
+ r = r.right
+ elif compareASTs(l.right, r.right, checkEquality=True) == 0:
+ l = l.left
+ r = r.left
+ elif compareASTs(l.left, r.right, checkEquality=True) == 0 and eventualType(l) in [int, float]:
+ l = l.right
+ r = r.left
+ elif compareASTs(l.right, r.left, checkEquality=True) == 0 and eventualType(l) in [int, float]:
+ l = l.left
+ r = r.right
+ else:
+ unchanged = True
+ if not unchanged:
+ tmpLeft = astFormat(l)
+ transferMetaData(a.left, tmpLeft)
+ a.left = tmpLeft
+ tmpRight = astFormat(r)
+ transferMetaData(a.comparators[0], tmpRight)
+ a.comparators = [tmpRight]
+ return constantFolding(a) # Repeat this check to see if we can keep reducing it
+ elif type(l.op) == ast.Sub:
+ unchanged = False
+ if compareASTs(l.left, r.left, checkEquality=True) == 0:
+ l = l.right
+ r = r.right
+ elif compareASTs(l.right, r.right, checkEquality=True) == 0:
+ l = l.left
+ r = r.left
+ else:
+ unchanged = True
+ if not unchanged:
+ tmpLeft = astFormat(l)
+ transferMetaData(a.left, tmpLeft)
+ a.left = tmpLeft
+ tmpRight = astFormat(r)
+ transferMetaData(a.comparators[0], tmpRight)
+ a.comparators = [tmpRight]
+ return constantFolding(a)
+ tmpLeft = astFormat(l)
+ transferMetaData(a.left, tmpLeft)
+ a.left = tmpLeft
+ tmpRight = astFormat(r)
+ transferMetaData(a.comparators[0], tmpRight)
+ a.comparators = [tmpRight]
+ return a
+ elif t == ast.Call:
+ # TODO: this can be done much better
+ a.func = applyTransferLambda(a.func)
+
+ allConstant = True
+ tmpArgs = []
+ for i in range(len(a.args)):
+ tmpArgs.append(constantFolding(a.args[i]))
+ if type(tmpArgs[i]) not in [int, float, bool, str]:
+ allConstant = False
+ if len(a.keywords) > 0:
+ allConstant = False
+ if allConstant and (type(a.func) == ast.Name) and (a.func.id in builtInFunctions.keys()) and \
+ (a.func.id not in ["range", "raw_input", "input", "open", "randint", "random", "slice"]):
+ try:
+ result = apply(eval(a.func.id), tmpArgs)
+ transferMetaData(a, astFormat(result))
+ return result
+ except:
+ # Not gonna happen unless it crashes
+ #log("transformations\tconstantFolding\tFunction crashed: " + str(a.func.id), "bug")
+ pass
+ for i in range(len(a.args)):
+ tmpArg = astFormat(tmpArgs[i])
+ transferMetaData(a.args[i], tmpArg)
+ a.args[i] = tmpArg
+ return a
+ # This needs to be separate because the attribute is a string
+ elif t == ast.Attribute:
+ a.value = applyTransferLambda(a.value)
+ return a
+ elif t == ast.Slice:
+ if a.lower != None:
+ a.lower = applyTransferLambda(a.lower)
+ if a.upper != None:
+ a.upper = applyTransferLambda(a.upper)
+ if a.step != None:
+ a.step = applyTransferLambda(a.step)
+ return a
+ elif t == ast.Num:
+ return a.n
+ elif t == ast.Bytes:
+ return a.s
+ elif t == ast.Str:
+ # Don't do things to filler strings
+ if len(a.s) > 0 and isTokenStepString(a.s):
+ return a
+ return a.s
+ elif t == ast.NameConstant:
+ if a.value == True:
+ return True
+ elif a.value == False:
+ return False
+ elif a.value == None:
+ return None
+ elif t == ast.Name:
+ return a
+ else: # All statements, ast.Lambda, ast.Dict, ast.Set, ast.Repr, ast.Attribute, ast.Subscript, etc.
+ return applyToChildren(a, applyTransferLambda)
+
+def isMutatingFunction(a):
+ """Given a function call, this checks whether it might change the program state when run"""
+ if type(a) != ast.Call: # Can only call this on Calls!
+ log("transformations\tisMutatingFunction\tNot a Call: " + str(type(a)), "bug")
+ return True
+
+ # Map of all static namesets
+ funMaps = { "math" : mathFunctions, "string" : builtInStringFunctions,
+ "str" : builtInStringFunctions, "list" : staticListFunctions,
+ "dict" : staticDictFunctions }
+ typeMaps = { str : "string", list : "list", dict : "dict" }
+ if type(a.func) == ast.Name:
+ funDict = builtInFunctions
+ funName = a.func.id
+ elif type(a.func) == ast.Attribute:
+ if type(a.func.value) == ast.Name and a.func.value.id in funMaps:
+ funDict = funMaps[a.func.value.id]
+ funName = a.func.attr
+ # if the item is calling a function directly
+ elif eventualType(a.func.value) in typeMaps:
+ funDict = funMaps[typeMaps[eventualType(a.func.value)]]
+ funName = a.func.attr
+ else:
+ return True
+ else:
+ return True # we don't know, so yes
+
+ # TODO: deal with student's functions
+ return funName not in funDict
+
+def allVariablesUsed(a):
+ if not isinstance(a, ast.AST):
+ return []
+ elif type(a) == ast.Name:
+ return [a]
+ variables = []
+ for child in ast.iter_child_nodes(a):
+ variables += allVariablesUsed(child)
+ return variables
+
+def allVariableNamesUsed(a):
+ """Gathers all the variable names used in the ast"""
+ if not isinstance(a, ast.AST):
+ return []
+ elif type(a) == ast.Name:
+ return [a.id]
+ elif type(a) == ast.Assign:
+ """In assignments, ignore all pure names used- they're being assigned to, not used"""
+ variables = allVariableNamesUsed(a.value)
+ for target in a.targets:
+ if type(target) == ast.Name:
+ pass
+ elif type(target) in [ast.Tuple, ast.List]:
+ for elt in target.elts:
+ if type(elt) != ast.Name:
+ variables += allVariableNamesUsed(elt)
+ else:
+ variables += allVariableNamesUsed(target)
+ return variables
+ elif type(a) == ast.AugAssign:
+ variables = allVariableNamesUsed(a.value)
+ variables += allVariableNamesUsed(a.target)
+ return variables
+ variables = []
+ for child in ast.iter_child_nodes(a):
+ variables += allVariableNamesUsed(child)
+ return variables
+
+def addPropTag(a, globalId):
+ if not isinstance(a, ast.AST):
+ return a
+ a.propagatedVariable = True
+ if hasattr(a, "global_id"):
+ a.variableGlobalId = globalId
+ return applyToChildren(a, lambda x : addPropTag(x, globalId))
+
+def propagateValues(a, liveVars):
+ """Propagate the given values through the AST whenever their variables occur"""
+ if ((not isinstance(a, ast.AST) or len(liveVars.keys()) == 0)):
+ return a
+
+ if type(a) == ast.Name:
+ # Propagate the value if we have it!
+ if a.id in liveVars:
+ val = copy.deepcopy(liveVars[a.id])
+ val.loadedVariable = True
+ if hasattr(a, "global_id"):
+ val.variableGlobalId = a.global_id
+ return applyToChildren(val, lambda x : addPropTag(x, a.global_id))
+ else:
+ return a
+ elif type(a) == ast.Call:
+ # If something is mutated, it cannot be propagated anymore
+ if isMutatingFunction(a):
+ allVars = allVariablesUsed(a)
+ for var in allVars:
+ if (eventualType(var) not in [int, float, bool, str]):
+ if (var.id in liveVars):
+ del liveVars[var.id]
+ currentLiveVars = list(liveVars.keys())
+ for liveVar in currentLiveVars:
+ varsWithin = allVariableNamesUsed(liveVars[liveVar])
+ if var.id in varsWithin:
+ del liveVars[liveVar]
+ return a
+ elif type(a.func) == ast.Name and a.func.id in liveVars and \
+ eventualType(liveVars[a.func.id]) in [int, float, complex, bytes, bool, type(None)]:
+ # Special case: don't move a simple value to the front of a Call
+ # because it will cause a compiler error instead of a runtime error
+ a.args = propagateValues(a.args, liveVars)
+ a.keywords = propagateValues(a.keywords, liveVars)
+ return a
+ elif type(a) == ast.Attribute:
+ if type(a.value) == ast.Name and a.value.id in liveVars and \
+ eventualType(liveVars[a.value.id]) in [int, float, complex, bytes, bool, type(None)]:
+ # Don't move for the same reason as above
+ return a
+ return applyToChildren(a, lambda x: propagateValues(x, liveVars))
+
+def hasMutatingFunction(a):
+ """Checks to see if the ast has any potentially mutating functions"""
+ if not isinstance(a, ast.AST):
+ return False
+ for node in ast.walk(a):
+ if type(a) == ast.Call:
+ if isMutatingFunction(a):
+ return True
+ return False
+
+def clearBlockVars(a, liveVars):
+ """Clear all the vars set in this block out of the live vars"""
+ if (not isinstance(a, ast.AST)) or len(liveVars.keys()) == 0:
+ return
+
+ if type(a) in [ast.Assign, ast.AugAssign]:
+ if type(a) == ast.Assign:
+ targets = gatherAssignedVars(a.targets)
+ else:
+ targets = gatherAssignedVars([a.target])
+ for target in targets:
+ varId = None
+ if type(target) == ast.Name:
+ varId = target.id
+ elif type(target.value) == ast.Name:
+ varId = target.value.id
+ if varId in liveVars:
+ del liveVars[varId]
+
+ liveKeys = list(liveVars.keys())
+ for var in liveKeys:
+ # Remove the variable and any variables in which it is used
+ if varId in allVariableNamesUsed(liveVars[var]):
+ del liveVars[var]
+ return
+ elif type(a) == ast.Call:
+ if hasMutatingFunction(a):
+ for v in allVariablesUsed(a):
+ if eventualType(v) not in [int, float, bool, str]:
+ if v.id in liveVars:
+ del liveVars[v.id]
+ liveKeys = list(liveVars.keys())
+ for var in liveKeys:
+ if v.id in allVariableNamesUsed(liveVars[var]):
+ del liveVars[var]
+ return
+ elif type(a) == ast.For:
+ names = []
+ if type(a.target) == ast.Name:
+ names = [a.target.id]
+ elif type(a.target) in [ast.Tuple, ast.List]:
+ for elt in a.target.elts:
+ if type(elt) == ast.Name:
+ names.append(elt.id)
+ elif type(elt) == ast.Subscript:
+ if type(elt.value) == ast.Name:
+ names.append(elt.value.id)
+ else:
+ log("transformations\tclearBlockVars\tFor target subscript not a name: " + str(type(elt.value)), "bug")
+ else:
+ log("transformations\tclearBlockVars\tFor target not a name: " + str(type(elt)), "bug")
+ elif type(a.target) == ast.Subscript:
+ if type(a.target.value) == ast.Name:
+ names.append(a.target.value.id)
+ else:
+ log("transformations\tclearBlockVars\tFor target subscript not a name: " + str(type(a.target.value)), "bug")
+ else:
+ log("transformations\tclearBlockVars\tFor target not a name: " + str(type(a.target)), "bug")
+ for name in names:
+ if name in liveVars:
+ del liveVars[name]
+
+ liveKeys = list(liveVars.keys())
+ for var in liveKeys:
+ # Remove the variable and any variables in which it is used
+ if name in allVariableNamesUsed(liveVars[var]):
+ del liveVars[var]
+
+ for child in ast.iter_child_nodes(a):
+ clearBlockVars(child, liveVars)
+
+def copyPropagation(a, liveVars=None, inLoop=False):
+ """Propagate variables into the tree, when possible"""
+ if liveVars == None:
+ liveVars = { }
+ if type(a) == ast.Module:
+ a.body = copyPropagation(a.body)
+ return a
+ if type(a) == ast.FunctionDef:
+ a.body = copyPropagation(a.body, liveVars=liveVars)
+ return a
+
+ if type(a) == list:
+ i = 0
+ while i < len(a):
+ deleteLine = False
+ if type(a[i]) == ast.FunctionDef:
+ a[i].body = copyPropagation(a[i].body, liveVars=copy.deepcopy(liveVars))
+ elif type(a[i]) == ast.ClassDef:
+ # TODO: can we propagate values through everything after here?
+ for j in range(len(a[i].body)):
+ if type(a[i].body[j]) == ast.FunctionDef:
+ a[i].body[j] = copyPropagation(a[i].body[j])
+ elif type(a[i]) == ast.Assign:
+ # In assignments, propagate values into the right side and move the left side into the live vars
+ a[i].value = propagateValues(a[i].value, liveVars)
+ target = a[i].targets[0]
+
+ if type(target) in [ast.Name, ast.Subscript, ast.Attribute]:
+ varId = None
+ # In plain names, we can update the liveVars
+ if type(target) == ast.Name:
+ varId = target.id
+ if inLoop or couldCrash(a[i].value) or eventualType(a[i].value) not in [bool, int, float, str, tuple]:
+ # Remove this variable from the live vars
+ if varId in liveVars:
+ del liveVars[varId]
+ else:
+ liveVars[varId] = a[i].value
+ # For other values, we can at least clear out liveVars correctly
+ # TODO: can we expand this?
+ elif target.value == ast.Name:
+ varId = target.value.id
+
+ # Now, update the live vars based on anything reset by the new target
+ liveKeys = list(liveVars.keys())
+ for var in liveKeys:
+ # If the var we're replacing was used elsewhere, that value will no longer be the same
+ if varId in allVariableNamesUsed(liveVars[var]):
+ del liveVars[var]
+ elif type(target) in [ast.Tuple, ast.List]:
+ # Copy the values, if we can match them
+ if type(a[i].value) in [ast.Tuple, ast.List] and len(target.elts) == len(a[i].value.elts):
+ for j in range(len(target.elts)):
+ if type(target.elts[j]) == ast.Name:
+ if (not couldCrash(a[i].value.elts[j])):
+ liveVars[target.elts[j]] = a[i].value.elts[j]
+ else:
+ if target.elts[j] in liveVars:
+ del liveVars[target.elts[j]]
+
+ # Then get rid of any overwrites
+ for e in target.elts:
+ if type(e) in [ast.Name, ast.Subscript, ast.Attribute]:
+ varId = None
+ if type(e) == ast.Name:
+ varId = e.id
+ elif type(e.value) == ast.Name:
+ varId = e.value.id
+
+ liveKeys = list(liveVars.keys())
+ for var in liveKeys:
+ if varId in allVariableNamesUsed(liveVars[var]):
+ del liveVars[var]
+ else:
+ log("transformations\tcopyPropagation\tWeird assign type: " + str(type(e)), "bug")
+ elif type(a[i]) == ast.AugAssign:
+ a[i].value = propagateValues(a[i].value, liveVars)
+ assns = gatherAssignedVarIds([a[i].target])
+ for target in assns:
+ if target in liveVars:
+ del liveVars[target]
+ elif type(a[i]) == ast.For:
+ # FIRST, propagate values into the iter
+ if type(a[i].iter) != ast.Name: # if it IS a name, don't replace it!
+ # Otherwise, we propagate first since this is evaluated once
+ a[i].iter = propagateValues(a[i].iter, liveVars)
+
+ # We reset the target variable, so reset the live vars
+ names = []
+ if type(a[i].target) == ast.Name:
+ names = [a[i].target.id]
+ elif type(a[i].target) in [ast.Tuple, ast.List]:
+ for elt in a[i].target.elts:
+ if type(elt) == ast.Name:
+ names.append(elt.id)
+ elif type(elt) == ast.Subscript:
+ if type(elt.value) == ast.Name:
+ names.append(elt.value.id)
+ else:
+ log("transformations\tcopyPropagation\tFor target subscript not a name: " + str(type(elt.value)) + "\t" + printFunction(elt.value), "bug")
+ else:
+ log("transformations\tcopyPropagation\tFor target not a name: " + str(type(elt)) + "\t" + printFunction(elt), "bug")
+ elif type(a[i].target) == ast.Subscript:
+ if type(a[i].target.value) == ast.Name:
+ names.append(a[i].target.value.id)
+ else:
+ log("transformations\tcopyPropagation\tFor target subscript not a name: " + str(type(a[i].target.value)) + "\t" + printFunction(a[i].target.value), "bug")
+ else:
+ log("transformations\tcopyPropagation\tFor target not a name: " + str(type(a[i].target)) + "\t" + printFunction(a[i].target), "bug")
+
+ for name in names:
+ liveKeys = list(liveVars.keys())
+ for var in liveKeys:
+ if name in allVariableNamesUsed(liveVars[var]):
+ del liveVars[var]
+ if name in liveVars:
+ del liveVars[name]
+ clearBlockVars(a[i], liveVars)
+ a[i].body = copyPropagation(a[i].body, copy.deepcopy(liveVars), inLoop=True)
+ a[i].orelse = copyPropagation(a[i].orelse, copy.deepcopy(liveVars), inLoop=True)
+ elif type(a[i]) == ast.While:
+ clearBlockVars(a[i], liveVars)
+ a[i].test = propagateValues(a[i].test, liveVars)
+ a[i].body = copyPropagation(a[i].body, copy.deepcopy(liveVars), inLoop=True)
+ a[i].orelse = copyPropagation(a[i].orelse, copy.deepcopy(liveVars), inLoop=True)
+ elif type(a[i]) == ast.If:
+ a[i].test = propagateValues(a[i].test, liveVars)
+ liveVars1 = copy.deepcopy(liveVars)
+ liveVars2 = copy.deepcopy(liveVars)
+ a[i].body = copyPropagation(a[i].body, liveVars1)
+ a[i].orelse = copyPropagation(a[i].orelse, liveVars2)
+ liveVars.clear()
+ # We can keep any values that occur in both
+ for key in liveVars1:
+ if key in liveVars2:
+ if compareASTs(liveVars1[key], liveVars2[key], checkEquality=True) == 0:
+ liveVars[key] = liveVars1[key]
+ # TODO: think more deeply about how this should work
+ elif type(a[i]) == ast.Try:
+ a[i].body = copyPropagation(a[i].body, liveVars)
+ for handler in a[i].handlers:
+ handler.body = copyPropagation(handler.body, liveVars)
+ a[i].orelse = copyPropagation(a[i].orelse, liveVars)
+ a[i].finalbody = copyPropagation(a[i].finalbody, liveVars)
+ elif type(a[i]) == ast.With:
+ a[i].body = copyPropagation(a[i].body, liveVars)
+ # With regular statements, just propagate the values
+ elif type(a[i]) in [ast.Return, ast.Delete, ast.Raise, ast.Assert, ast.Expr]:
+ propagateValues(a[i], liveVars)
+ # Breaks and Continues mess everything up
+ elif type(a[i]) in [ast.Break, ast.Continue]:
+ break
+ # These are not affected by this function
+ elif type(a[i]) in [ast.Import, ast.ImportFrom, ast.Global, ast.Pass]:
+ pass
+ else:
+ log("transformations\tcopyPropagation\tNot implemented: " + str(type(a[i])), "bug")
+ i += 1
+ return a
+ else:
+ log("transformations\tcopyPropagation\tNot a list: " + str(type(a)), "bug")
+ return a
+
+def deadCodeRemoval(a, liveVars=None, keepPrints=True, inLoop=False):
+ """Remove any code which will not be reached or used."""
+ """LiveVars keeps track of the variables that will be necessary"""
+ if liveVars == None:
+ liveVars = set()
+ if type(a) == ast.Module:
+ # Remove functions that will be overwritten anyway
+ namesSeen = []
+ i = len(a.body) - 1
+ while i >= 0:
+ if type(a.body[i]) == ast.FunctionDef:
+ if a.body[i].name in namesSeen:
+ # SPECIAL CHECK! Actually, the function will cause the code to crash if some of the args have the same name. Don't delete it then.
+ argNames = []
+ for arg in a.body[i].args.args:
+ if arg.arg in argNames:
+ break
+ else:
+ argNames.append(arg.arg)
+ else: # only remove this if the args won't break it
+ a.body.pop(i)
+ else:
+ namesSeen.append(a.body[i].name)
+ elif type(a.body[i]) == ast.Assign:
+ namesSeen += gatherAssignedVars(a.body[i].targets)
+ i -= 1
+ liveVars |= set(namesSeen) # make sure all global names are used!
+
+ if type(a) in [ast.Module, ast.FunctionDef]:
+ if type(a) == ast.Module and len(a.body) == 0:
+ return a # just don't mess with it
+ gid = a.body[0].global_id if len(a.body) > 0 and hasattr(a.body[0], "global_id") else None
+ a.body = deadCodeRemoval(a.body, liveVars=liveVars, keepPrints=keepPrints, inLoop=inLoop)
+ if len(a.body) == 0:
+ a.body = [ast.Pass(removedLines=True)] if gid == None else [ast.Pass(removedLines=True, global_id=gid)]
+ return a
+
+ if type(a) == list:
+ i = len(a) - 1
+ while i >= 0 and len(a) > 0:
+ if i >= len(a):
+ i = len(a) - 1 # just in case
+ stmt = a[i]
+ t = type(stmt)
+ # TODO: get rid of these if they aren't live
+ if t in [ast.FunctionDef, ast.ClassDef]:
+ newLiveVars = set()
+ gid = a[i].body[0].global_id if len(a[i].body) > 0 and hasattr(a[i].body[0], "global_id") else None
+ a[i] = deadCodeRemoval(a[i], liveVars=newLiveVars, keepPrints=keepPrints, inLoop=inLoop)
+ liveVars |= newLiveVars
+ # Empty functions are useless!
+ if len(a[i].body) == 0:
+ a[i].body = [ast.Pass(removedLines=True)] if gid == None else [ast.Pass(removedLines=True, global_id=gid)]
+ elif t == ast.Return:
+ # Get rid of everything that happens after this!
+ a = a[:i+1]
+ # Replace the variables
+ liveVars.clear()
+ liveVars |= set(allVariableNamesUsed(stmt))
+ elif t in [ast.Delete, ast.Assert]:
+ # Just add all variables used
+ liveVars |= set(allVariableNamesUsed(stmt))
+ elif t == ast.Assign:
+ # Check to see if the names being assigned are in the set of live variables
+ allDead = True
+ allTargets = gatherAssignedVars(stmt.targets)
+ allNamesUsed = allVariableNamesUsed(stmt.value)
+ for target in allTargets:
+ if type(target) == ast.Name and (target.id in liveVars or target.id in allNamesUsed):
+ if target.id in liveVars:
+ liveVars.remove(target.id)
+ allDead = False
+ elif type(target) in [ast.Subscript, ast.Attribute]:
+ liveVars |= set(allVariableNamesUsed(target))
+ allDead = False
+ # Also, check if the variable itself is contained in the value, because that can crash too
+ # If none are used, we can delete this line. Otherwise, use the value's vars
+ if allDead and (not couldCrash(stmt)) and (not containsTokenStepString(stmt)):
+ a.pop(i)
+ else:
+ liveVars |= set(allVariableNamesUsed(stmt.value))
+ elif t == ast.AugAssign:
+ liveVars |= set(allVariableNamesUsed(stmt.target))
+ liveVars |= set(allVariableNamesUsed(stmt.value))
+ elif t == ast.For:
+ # If there is no use of break, there's no reason to use else with the loop,
+ # so move the lines outside and go over them separately
+ if len(stmt.orelse) > 0 and countOccurances(stmt, ast.Break) == 0:
+ lines = stmt.orelse
+ stmt.orelse = []
+ a[i:i+1] = [stmt] + lines
+ i += len(lines)
+ continue # don't subtract one
+
+ targetNames = []
+ if type(a[i].target) == ast.Name:
+ targetNames = [a[i].target.id]
+ elif type(a[i].target) in [ast.Tuple, ast.List]:
+ for elt in a[i].target.elts:
+ if type(elt) == ast.Name:
+ targetNames.append(elt.id)
+ elif type(elt) == ast.Subscript:
+ if type(elt.value) == ast.Name:
+ targetNames.append(elt.value.id)
+ else:
+ log("transformations\tdeadCodeRemoval\tFor target subscript not a name: " + str(type(elt.value)) + "\t" + printFunction(elt.value), "bug")
+ else:
+ log("transformations\tdeadCodeRemoval\tFor target not a name: " + str(type(elt)) + "\t" + printFunction(elt), "bug")
+ elif type(a[i].target) == ast.Subscript:
+ if type(a[i].target.value) == ast.Name:
+ targetNames.append(a[i].target.value.id)
+ else:
+ log("transformations\tdeadCodeRemoval\tFor target subscript not a name: " + str(type(a[i].target.value)) + "\t" + printFunction(a[i].target.value), "bug")
+ else:
+ log("transformations\tdeadCodeRemoval\tFor target not a name: " + str(type(a[i].target)) + "\t" + printFunction(a[i].target), "bug")
+
+ # We need to make ALL variables in the loop live, since they update continuously
+ liveVars |= set(allVariableNamesUsed(stmt))
+ gid = stmt.body[0].global_id if len(stmt.body) > 0 and hasattr(stmt.body[0], "global_id") else None
+ stmt.body = deadCodeRemoval(stmt.body, copy.deepcopy(liveVars), keepPrints=keepPrints, inLoop=True)
+ stmt.orelse = deadCodeRemoval(stmt.orelse, copy.deepcopy(liveVars), keepPrints=keepPrints, inLoop=inLoop)
+ # If the body is empty and we don't need the target, get rid of it!
+ if len(stmt.body) == 0:
+ for name in targetNames:
+ if name in liveVars:
+ stmt.body = [ast.Pass(removedLines=True)] if gid == None else [ast.Pass(removedLines=True, global_id=gid)]
+ break
+ else:
+ if couldCrash(stmt.iter) or containsTokenStepString(stmt.iter):
+ a[i] = ast.Expr(stmt.iter, collapsedExpr=True)
+ else:
+ a.pop(i)
+ if len(stmt.orelse) > 0:
+ a[i:i+1] = a[i] + stmt.orelse
+
+ # The names are wiped UPDATE - NOPE, what if we never enter the loop?
+ #for name in targetNames:
+ # liveVars.remove(name)
+ elif t == ast.While:
+ # If there is no use of break, there's no reason to use else with the loop,
+ # so move the lines outside and go over them separately
+ if len(stmt.orelse) > 0 and countOccurances(stmt, ast.Break) == 0:
+ lines = stmt.orelse
+ stmt.orelse = []
+ a[i:i+1] = [stmt] + lines
+ i += len(lines)
+ continue
+
+ # We need to make ALL variables in the loop live, since they update continuously
+ liveVars |= set(allVariableNamesUsed(stmt))
+ old_global_id = stmt.body[0].global_id
+ stmt.body = deadCodeRemoval(stmt.body, copy.deepcopy(liveVars), keepPrints=keepPrints, inLoop=True)
+ stmt.orelse = deadCodeRemoval(stmt.orelse, copy.deepcopy(liveVars), keepPrints=keepPrints, inLoop=inLoop)
+ # If the body is empty, get rid of it!
+ if len(stmt.body) == 0:
+ stmt.body = [ast.Pass(removedLines=True, global_id=old_global_id)]
+ elif t == ast.If:
+ # First, if True/False, just replace it with the lines
+ test = a[i].test
+ if type(test) == ast.NameConstant and test.value in [True, False]:
+ assignedVars = getAllAssignedVars(a[i])
+ for var in assignedVars:
+ # UNLESS we have a weird variable assignment problem
+ if var.id[0] == "g" and hasattr(var, "originalId"):
+ log("canonicalize\tdeadCodeRemoval\tWeird global variable: " + printFunction(a[i]), "bug")
+ break
+ else:
+ if test.value == True:
+ a[i:i+1] = a[i].body
+ else:
+ a[i:i+1] = a[i].orelse
+ continue
+ # For if statements, see if you can shorten things
+ liveVars1 = copy.deepcopy(liveVars)
+ liveVars2 = copy.deepcopy(liveVars)
+ stmt.body = deadCodeRemoval(stmt.body, liveVars1, keepPrints=keepPrints, inLoop=inLoop)
+ stmt.orelse = deadCodeRemoval(stmt.orelse, liveVars2, keepPrints=keepPrints, inLoop=inLoop)
+ liveVars.clear()
+ allVars = liveVars1 | liveVars2 | set(allVariableNamesUsed(stmt.test))
+ liveVars |= allVars
+ if len(stmt.body) == 0 and len(stmt.orelse) == 0:
+ # Get rid of the if and keep going
+ if couldCrash(stmt.test) or containsTokenStepString(stmt.test):
+ newStmt = ast.Expr(stmt.test, collapsedExpr=True)
+ transferMetaData(stmt, newStmt)
+ a[i] = newStmt
+ else:
+ a.pop(i)
+ i -= 1
+ continue
+ if len(stmt.body) == 0:
+ # If the body is empty, switch it with the else
+ stmt.test = deMorganize(ast.UnaryOp(ast.Not(addedNotOp=True), stmt.test, addedNot=True))
+ (stmt.body, stmt.orelse) = (stmt.orelse, stmt.body)
+ if len(stmt.orelse) == 0:
+ # See if we can make the rest of the function the else statement
+ if type(stmt.body[-1]) == type(a[-1]) == ast.Return:
+ # If the if is larger than the rest, switch them!
+ if len(stmt.body) > len(a[i+1:]):
+ stmt.test = deMorganize(ast.UnaryOp(ast.Not(addedNotOp=True), stmt.test, addedNot=True))
+ (a[i+1:], stmt.body) = (stmt.body, a[i+1:])
+ else:
+ # Check to see if we should switch the if and else parts
+ if len(stmt.body) > len(stmt.orelse):
+ stmt.test = deMorganize(ast.UnaryOp(ast.Not(addedNotOp=True), stmt.test, addedNot=True))
+ (stmt.body, stmt.orelse) = (stmt.orelse, stmt.body)
+ elif t == ast.Import:
+ if len(stmt.names) == 0:
+ a.pop(i)
+ elif t == ast.Global:
+ j = 0
+ while j < len(a.names):
+ if a.names[j] not in liveVars:
+ a.names.pop(j)
+ else:
+ j += 1
+ elif t == ast.Expr:
+ # Remove the line if it won't crash things.
+ if couldCrash(stmt) or containsTokenStepString(stmt):
+ liveVars |= set(allVariableNamesUsed(stmt))
+ else:
+ # check whether any of these variables might crash the program
+ # I know, it's weird, but occasionally a student might use a var before defining it
+ allVars = allVariableNamesUsed(stmt)
+ for j in range(i):
+ if type(a[j]) == ast.Assign:
+ for id in gatherAssignedVarIds(a[j].targets):
+ if id in allVars:
+ allVars.remove(id)
+ if len(allVars) > 0:
+ liveVars |= set(allVariableNamesUsed(stmt))
+ else:
+ a.pop(i)
+ # for now, just be careful with these types of statements
+ elif t in [ast.With, ast.Raise, ast.Try]:
+ liveVars |= set(allVariableNamesUsed(stmt))
+ elif t == ast.Pass:
+ a.pop(i) # pass does *nothing*
+ elif t in [ast.Continue, ast.Break]:
+ if inLoop: # If we're in a loop, nothing that follows matters! Otherwise, leave it alone, this will just crash.
+ a = a[:i+1]
+ break
+ # We don't know what they're doing- leave it alone
+ elif t in [ast.ImportFrom]:
+ pass
+ # Have not yet implemented these
+ else:
+ log("transformations\tdeadCodeRemoval\tNot implemented: " + str(type(stmt)), "bug")
+ i -= 1
+ return a
+ else:
+ log("transformations\tdeadCodeRemoval\tNot a list: " + str(a), "bug")
+ return a
+
+### ORDERING FUNCTIONS ###
+
+def getKeyDict(d, key):
+ if key not in d:
+ d[key] = { "self" : key }
+ return d[key]
+
+def traverseTrail(d, trail):
+ temp = d
+ for key in trail:
+ temp = temp[key]
+ return temp
+
+def areDisjoint(a, b):
+ """Are the sets of values that satisfy these two boolean constraints disjoint?"""
+ # The easiest way to be disjoint is to have comparisons that cover different areas
+ if type(a) == type(b) == ast.Compare:
+ aop = a.ops[0]
+ bop = b.ops[0]
+ aLeft = a.left
+ aRight = a.comparators[0]
+ bLeft = b.left
+ bRight = b.comparators[0]
+ alblComp = compareASTs(aLeft, bLeft, checkEquality=True)
+ albrComp = compareASTs(aLeft, bRight, checkEquality=True)
+ arblComp = compareASTs(aRight, bLeft, checkEquality=True)
+ arbrComp = compareASTs(aRight, bRight, checkEquality=True)
+ altype = type(aLeft) in [ast.Num, ast.Str]
+ artype = type(aRight) in [ast.Num, ast.Str]
+ bltype = type(bLeft) in [ast.Num, ast.Str]
+ brtype = type(bRight) in [ast.Num, ast.Str]
+
+ if (type(aop) == ast.Eq and type(bop) == ast.NotEq) or \
+ (type(bop) == ast.Eq and type(aop) == ast.NotEq):
+ # x == y, x != y
+ if (alblComp == 0 and arbrComp == 0) or (albrComp == 0 and arblComp == 0):
+ return True
+ elif type(aop) == type(bop) == ast.Eq:
+ if (alblComp == 0 and arbrComp == 0) or (albrComp == 0 and arblComp == 0):
+ return False
+ # x = num1, x = num2
+ elif alblComp == 0 and artype and brtype:
+ return True
+ elif albrComp == 0 and artype and bltype:
+ return True
+ elif arblComp == 0 and altype and brtype:
+ return True
+ elif arbrComp == 0 and altype and bltype:
+ return True
+ elif (type(aop) == ast.Lt and type(bop) == ast.GtE) or \
+ (type(aop) == ast.Gt and type(bop) == ast.LtE) or \
+ (type(aop) == ast.LtE and type(bop) == ast.Gt) or \
+ (type(aop) == ast.GtE and type(bop) == ast.Lt) or \
+ (type(aop) == ast.Is and type(bop) == ast.IsNot) or \
+ (type(aop) == ast.IsNot and type(bop) == ast.Is) or \
+ (type(aop) == ast.In and type(bop) == ast.NotIn) or \
+ (type(aop) == ast.NotIn and type(bop) == ast.In):
+ if alblComp == 0 and arbrComp == 0:
+ return True
+ elif (type(aop) == ast.Lt and type(bop) == ast.LtE) or \
+ (type(aop) == ast.Gt and type(bop) == ast.GtE) or \
+ (type(aop) == ast.LtE and type(bop) == ast.Lt) or \
+ (type(aop) == ast.GtE and type(bop) == ast.Gt):
+ if albrComp == 0 and arblComp == 0:
+ return True
+ elif type(a) == type(b) == ast.BoolOp:
+ return False # for now- TODO: when is this not true?
+ elif type(a) == ast.UnaryOp and type(a.op) == ast.Not:
+ if compareASTs(a.operand, b, checkEquality=True) == 0:
+ return True
+ elif type(b) == ast.UnaryOp and type(b.op) == ast.Not:
+ if compareASTs(b.operand, a, checkEquality=True) == 0:
+ return True
+ return False
+
+def crashesOn(a):
+ """Determines where the expression might crash"""
+ # TODO: integrate typeCrashes
+ if not isinstance(a, ast.AST):
+ return []
+ if type(a) == ast.BinOp:
+ l = eventualType(a.left)
+ r = eventualType(a.right)
+ if type(a.op) == ast.Add:
+ if not ((l == r == str) or (l in [int, float] and r in [int, float])):
+ return [a]
+ elif type(a.op) == ast.Mult:
+ if not ((l == str and r == int) or (l == int and r == str) or \
+ (l in [int, float] and r in [int, float])):
+ return [a]
+ elif type(a.op) in [ast.Sub, ast.Pow, ast.LShift, ast.RShift, ast.BitOr, ast.BitXor, ast.BitAnd]:
+ if l not in [int, float] or r not in [int, float]:
+ return [a]
+ else: # ast.Div, ast.FloorDiv, ast.Mod
+ if (type(a.right) != ast.Num or a.right.n == 0) or \
+ (l not in [int, float] or r not in [int, float]):
+ return [a]
+ elif type(a) == ast.UnaryOp:
+ if type(a.op) in [ast.UAdd, ast.USub]:
+ if eventualType(a.operand) not in [int, float]:
+ return [a]
+ elif type(a.op) == ast.Invert:
+ if eventualType(a.operand) != int:
+ return [a]
+ elif type(a) == ast.Compare:
+ if len(a.ops) != len(a.comparators):
+ return [a]
+ elif type(a.ops[0]) in [ast.In, ast.NotIn] and not isIterableType(eventualType(a.comparators[0])):
+ return [a]
+ elif type(a.ops[0]) in [ast.Lt, ast.LtE, ast.Gt, ast.GtE]:
+ # In Python3, you can't compare different types. BOOOOOO!!
+ firstType = eventualType(a.left)
+ if firstType == None:
+ return [a]
+ for comp in a.comparators:
+ if eventualType(comp) != firstType:
+ return [a]
+ elif type(a) == ast.Call:
+ env = [] # TODO: what if the environments aren't imported?
+ funMaps = { "math" : mathFunctions, "string" : builtInStringFunctions }
+ safeFunMaps = { "math" : safeMathFunctions, "string" : safeStringFunctions }
+ if type(a.func) == ast.Name:
+ funDict = builtInFunctions
+ safeFuns = builtInSafeFunctions
+ funName = a.func.id
+ elif type(a.func) == ast.Attribute:
+ if type(a.func.value) == ast.Name and a.func.value.id in funMaps:
+ funDict = funMaps[a.func.value.id]
+ safeFuns = safeFunMaps[a.func.value.id]
+ funName = a.func.attr
+ elif eventualType(a.func.value) == str:
+ funDict = funMaps["string"]
+ safeFuns = safeFunMaps["string"]
+ funName = a.func.attr
+ else: # including list and dict
+ return [a]
+ else:
+ return [a]
+
+ runOnce = 0 # So we can break
+ while (runOnce == 0):
+ if funName in safeFuns:
+ argTypes = []
+ for i in range(len(a.args)):
+ eventual = eventualType(a.args[i])
+ if eventual == None:
+ return [a]
+ argTypes.append(eventual)
+
+ if funName in ["max", "min"]:
+ break # Special functions
+
+ for key in funDict[funName]:
+ if len(key) != len(argTypes):
+ continue
+ for i in range(len(key)):
+ if not (key[i] == argTypes[i] or issubclass(argTypes[i], key[i])):
+ break
+ else:
+ break
+ else:
+ return [a]
+ break # found one that works
+ else:
+ return [a]
+ elif type(a) == ast.Subscript:
+ if eventualType(a.value) not in [str, list, tuple]:
+ return [a]
+ elif type(a) == ast.Name:
+ # If it's an undefined variable, it might crash
+ if hasattr(a, "randomVar"):
+ return [a]
+ elif type(a) == ast.Slice:
+ if a.lower != None and eventualType(a.lower) != int:
+ return [a]
+ if a.upper != None and eventualType(a.upper) != int:
+ return [a]
+ if a.step != None and eventualType(a.step) != int:
+ return [a]
+ elif type(a) in [ast.Assert, ast.Import, ast.ImportFrom, ast.Attribute, ast.Index]:
+ return [a]
+
+ allCrashes = []
+ for child in ast.iter_child_nodes(a):
+ allCrashes += crashesOn(child)
+ return allCrashes
+
+def isNegation(a, b):
+ """Is a the negation of b?"""
+ return compareASTs(deMorganize(ast.UnaryOp(ast.Not(), deepcopy(a))), b, checkEquality=True) == 0
+
+def reverse(op):
+ """Reverse the direction of the comparison for normalization purposes"""
+ rev = not op.reversed if hasattr(op, "reversed") else True
+ if type(op) == ast.Gt:
+ newOp = ast.Lt()
+ transferMetaData(op, newOp)
+ newOp.reversed = rev
+ return newOp
+ elif type(op) == ast.GtE:
+ newOp = ast.LtE()
+ transferMetaData(op, newOp)
+ newOp.reversed = rev
+ return newOp
+ else:
+ return op # Do not change!
+
+def orderCommutativeOperations(a):
+ """Order all expressions that are in commutative operations"""
+ """TODO: add commutative function lines?"""
+ if not isinstance(a, ast.AST):
+ return a
+ # If branches can be commutative as long as their tests are disjoint
+ if type(a) == ast.If:
+ a = applyToChildren(a, orderCommutativeOperations)
+ # If the else is (strictly) shorter than the body, switch them
+ if len(a.orelse) != 0 and len(a.body) > len(a.orelse):
+ newTest = ast.UnaryOp(ast.Not(addedNotOp=True), a.test)
+ transferMetaData(a.test, newTest)
+ newTest.negated = True
+ newTest = deMorganize(newTest)
+ a.test = newTest
+ (a.body,a.orelse) = (a.orelse,a.body)
+
+ # Then collect all the branches. The leftover orelse is the final else
+ branches = [(a.test, a.body, a.global_id)]
+ orElse = a.orelse
+ while len(orElse) == 1 and type(orElse[0]) == ast.If:
+ branches.append((orElse[0].test, orElse[0].body, orElse[0].global_id))
+ orElse = orElse[0].orelse
+
+ # If we have branches to order...
+ if len(branches) != 1:
+ # Sort the branches based on their tests
+ # We have to sort carefully because of the possibility for crashing
+ isSorted = False
+ while not isSorted:
+ isSorted = True
+ for i in range(len(branches)-1):
+ # First, do we even want to swap these two?
+ # Branch tests MUST be disjoint to be swapped- otherwise, we break semantics
+ if areDisjoint(branches[i][0], branches[i+1][0]) and \
+ compareASTs(branches[i][0], branches[i+1][0]) > 0:
+ if not (couldCrash(branches[i][0]) or couldCrash(branches[i+1][0])):
+ (branches[i],branches[i+1]) = (branches[i+1],branches[i])
+ isSorted = False
+ # Two values can be swapped if they crash on the SAME thing
+ elif couldCrash(branches[i][0]) and couldCrash(branches[i+1][0]):
+ # Check to see if they crash on the same things
+ l1 = sorted(crashesOn(branches[i][0]), key=functools.cmp_to_key(compareASTs))
+ l2 = sorted(crashesOn(branches[i+1][0]), key=functools.cmp_to_key(compareASTs))
+ if compareASTs(l1, l2, checkEquality=True) == 0:
+ (branches[i],branches[i+1]) = (branches[i+1],branches[i])
+ isSorted = False
+ # Do our last two branches nicely form an if/else already?
+ if len(orElse) == 0 and isNegation(branches[-1][0], branches[-2][0]):
+ starter = branches[-1][1] # skip the if
+ else:
+ starter = [ast.If(branches[-1][0], branches[-1][1], orElse, global_id=branches[-1][2])]
+ # Create the new conditional tree
+ for i in range(len(branches)-2, -1, -1):
+ starter = [ast.If(branches[i][0], branches[i][1], starter, global_id=branches[i][2])]
+ a = starter[0]
+ return a
+ elif type(a) == ast.BoolOp:
+ # If all the values are booleans and won't crash, we can sort them
+ canSort = True
+ for i in range(len(a.values)):
+ a.values[i] = orderCommutativeOperations(a.values[i])
+ if couldCrash(a.values[i]) or eventualType(a.values[i]) != bool or containsTokenStepString(a.values[i]):
+ canSort = False
+
+ if canSort:
+ a.values = sorted(a.values, key=functools.cmp_to_key(compareASTs))
+ else:
+ # Even if there are some problems, we can partially sort. See above
+ isSorted = False
+ while not isSorted:
+ isSorted = True
+ for i in range(len(a.values)-1):
+ if compareASTs(a.values[i], a.values[i+1]) > 0 and \
+ eventualType(a.values[i]) == bool and eventualType(a.values[i+1]) == bool:
+ if not (couldCrash(a.values[i]) or couldCrash(a.values[i+1])):
+ (a.values[i],a.values[i+1]) = (a.values[i+1],a.values[i])
+ isSorted = False
+ # Two values can also be swapped if they crash on the SAME thing
+ elif couldCrash(a.values[i]) and couldCrash(a.values[i+1]):
+ # Check to see if they crash on the same things
+ l1 = sorted(crashesOn(a.values[i]), key=functools.cmp_to_key(compareASTs))
+ l2 = sorted(crashesOn(a.values[i+1]), key=functools.cmp_to_key(compareASTs))
+ if compareASTs(l1, l2, checkEquality=True) == 0:
+ (a.values[i],a.values[i+1]) = (a.values[i+1],a.values[i])
+ isSorted = False
+ return a
+ elif type(a) == ast.BinOp:
+ top = type(a.op)
+ l = a.left = orderCommutativeOperations(a.left)
+ r = a.right = orderCommutativeOperations(a.right)
+
+ # Don't reorder if we're currently walking through hint steps
+ if containsTokenStepString(l) or containsTokenStepString(r):
+ return a
+
+ # TODO: what about possible crashes?
+ # Certain operands are commutative
+ if (top in [ast.Mult, ast.BitOr, ast.BitXor, ast.BitAnd]) or \
+ ((top == ast.Add) and ((eventualType(l) in [int, float, bool]) or \
+ (eventualType(r) in [int, float, bool]))):
+ # Break the chain of binary operations into a list of the
+ # operands over the same op, then sort the operands
+ operands = [[l, a.op], [r, None]]
+ changeMade = True
+ i = 0
+ while i < len(operands):
+ [operand, op] = operands[i]
+ if type(operand) == ast.BinOp and type(operand.op) == top:
+ operands[i:i+1] = [[operand.left, operand.op], [operand.right, op]]
+ else:
+ i += 1
+ operands = sorted(operands, key=functools.cmp_to_key(lambda x,y : compareASTs(x[0], y[0])))
+ for i in range(len(operands)-1): # push all the ops forward
+ if operands[i][1] == None:
+ operands[i][1] = operands[i+1][1]
+ operands[i+1][1] = None
+ # Then put them back into a single expression, descending to the left
+ left = operands[0][0]
+ for i in range(1, len(operands)):
+ left = ast.BinOp(left, operands[i-1][1], operands[i][0], orderedBinOp=True)
+ transferMetaData(a, left)
+ return left
+ elif top == ast.Add:
+ # This might be concatenation, not addition
+ if type( r ) == ast.BinOp and type(r.op) == top:
+ # We want the operators to descend to the left
+ a.left = orderCommutativeOperations(ast.BinOp(l, r.op, r.left, global_id=r.global_id))
+ a.right = r.right
+ return a
+ elif type(a) == ast.Dict:
+ for i in range(len(a.keys)):
+ a.keys[i] = orderCommutativeOperations(a.keys[i])
+ a.values[i] = orderCommutativeOperations(a.values[i])
+
+ pairs = list(zip(a.keys, a.values))
+ pairs.sort(key=functools.cmp_to_key(lambda x,y : compareASTs(x[0],y[0]))) # sort by keys
+ k, v = zip(*pairs) if len(pairs) > 0 else ([], [])
+ a.keys = list(k)
+ a.values = list(v)
+ return a
+ elif type(a) == ast.Compare:
+ l = a.left = orderCommutativeOperations(a.left)
+ r = orderCommutativeOperations(a.comparators[0])
+ a.comparators = [r]
+
+ # Don't reorder when we're doing hint steps
+ if containsTokenStepString(l) or containsTokenStepString(r):
+ return a
+
+ if (type(a.ops[0]) in [ast.Eq, ast.NotEq]):
+ # Equals and not-equals are commutative
+ if compareASTs(l, r) > 0:
+ a.left, a.comparators[0] = a.comparators[0], a.left
+ elif (type(a.ops[0]) in [ast.Gt, ast.GtE]):
+ # We'll always use < and <=, just so everything's the same
+ a.ops = [reverse(a.ops[0])]
+ a.left, a.comparators[0] = a.comparators[0], a.left
+ elif (type(a.ops[0]) in [ast.In, ast.NotIn]):
+ if type(r) == ast.List:
+ # If it's a list of items, sort the list
+ # TODO: should we implement crashable sorting here?
+ for i in range(len(r.elts)):
+ if couldCrash(r.elts[i]):
+ break # don't sort if there'a a crash!
+ else:
+ r.elts = sorted(r.elts, key=functools.cmp_to_key(compareASTs))
+ # Then remove duplicates
+ i = 0
+ while i < len(r.elts) - 1:
+ if compareASTs(r.elts[i], r.elts[i+1], checkEquality=True) == 0:
+ r.elts.pop(i+1)
+ else:
+ i += 1
+ return a
+ elif type(a) == ast.Call:
+ if type(a.func) == ast.Name:
+ # These functions are commutative and show up a lot
+ if a.func.id in ["min", "max"]:
+ crashable = False
+ for i in range(len(a.args)):
+ a.args[i] = orderCommutativeOperations(a.args[i])
+ if couldCrash(a.args[i]) or containsTokenStepString(a.args[i]):
+ crashable = True
+ # TODO: crashable sorting here?
+ if not crashable:
+ a.args = sorted(a.args, key=functools.cmp_to_key(compareASTs))
+ return a
+ return applyToChildren(a, orderCommutativeOperations)
+
+def deMorganize(a):
+ """Apply De Morgan's law throughout the code in order to canonicalize"""
+ if not isinstance(a, ast.AST):
+ return a
+ # We only care about statements beginning with not
+ if type(a) == ast.UnaryOp and type(a.op) == ast.Not:
+ oper = a.operand
+ top = type(oper)
+
+ # not (blah and gah) == (not blah or not gah)
+ if top == ast.BoolOp:
+ oper.op = negate(oper.op)
+ for i in range(len(oper.values)):
+ oper.values[i] = deMorganize(negate(oper.values[i]))
+ oper.negated = not oper.negated if hasattr(oper, "negated") else True
+ transferMetaData(a, oper)
+ return oper
+ # not a < b == a >= b
+ elif top == ast.Compare:
+ oper.left = deMorganize(oper.left)
+ oper.ops = [negate(oper.ops[0])]
+ oper.comparators = [deMorganize(oper.comparators[0])]
+ oper.negated = not oper.negated if hasattr(oper, "negated") else True
+ transferMetaData(a, oper)
+ return oper
+ # not not blah == blah
+ elif top == ast.UnaryOp and type(oper.op) == ast.Not:
+ oper.operand = deMorganize(oper.operand)
+ if eventualType(oper.operand) != bool:
+ return a
+ oper.operand.negated = not oper.operand.negated if hasattr(oper.operand, "negated") else True
+ return oper.operand
+ elif top == ast.NameConstant:
+ if oper.value in [True, False]:
+ oper = negate(oper)
+ transferMetaData(a, oper)
+ return oper
+ elif oper.value == None:
+ tmp = ast.NameConstant(True)
+ transferMetaData(a, tmp)
+ tmp.negated = True
+ return tmp
+ else:
+ log("Unknown NameConstant: " + str(oper.value), "bug")
+
+ return applyToChildren(a, deMorganize)
+
+##### CLEANUP FUNCTIONS #####
+
+def cleanupEquals(a):
+ """Gets rid of silly blah == True statements that students make"""
+ if not isinstance(a, ast.AST):
+ return a
+ if type(a) == ast.Call:
+ a.func = cleanupEquals(a.func)
+ for i in range(len(a.args)):
+ # But test expressions don't carry through to function arguments
+ a.args[i] = cleanupEquals(a.args[i])
+ return a
+ elif type(a) == ast.Compare and type(a.ops[0]) in [ast.Eq, ast.NotEq]:
+ l = a.left = cleanupEquals(a.left)
+ r = cleanupEquals(a.comparators[0])
+ a.comparators = [r]
+ if type(l) == ast.NameConstant and l.value in [True, False]:
+ (l,r) = (r,l)
+ # If we have (boolean expression) == True
+ if type(r) == ast.NameConstant and r.value in [True, False] and (eventualType(l) == bool):
+ # Matching types
+ if (type(a.ops[0]) == ast.Eq and r.value == True) or \
+ (type(a.ops[0]) == ast.NotEq and r.value == False):
+ transferMetaData(a, l) # make sure to keep the original location
+ return l
+ else:
+ tmp = ast.UnaryOp(ast.Not(addedNotOp=True), l)
+ transferMetaData(a, tmp)
+ return tmp
+ else:
+ return a
+ else:
+ return applyToChildren(a, cleanupEquals)
+
+def cleanupBoolOps(a):
+ """When possible, combine adjacent boolean expressions"""
+ """Note- we are assuming that all ops are the first op (as is done in the simplify function)"""
+ if not isinstance(a, ast.AST):
+ return a
+ if type(a) == ast.BoolOp:
+ allTypesWork = True
+ for i in range(len(a.values)):
+ a.values[i] = cleanupBoolOps(a.values[i])
+ if eventualType(a.values[i]) != bool or hasattr(a.values[i], "multiComp"):
+ allTypesWork = False
+
+ # We can't reduce if the types aren't all booleans
+ if not allTypesWork:
+ return a
+
+ i = 0
+ while i < len(a.values) - 1:
+ current = a.values[i]
+ next = a.values[i+1]
+ # (a and b and c and d) or (a and e and d) == a and ((b and c) or e) and d
+ if type(current) == type(next) == ast.BoolOp:
+ if type(current.op) == type(next.op):
+ minlength = min(len(current.values), len(next.values)) # shortest length
+
+ # First, check for all identical values from the front
+ j = 0
+ while j < minlength:
+ if compareASTs(current.values[j], next.values[j], checkEquality=True) != 0:
+ break
+ j += 1
+
+ # Same values in both, so get rid of the latter line
+ if j == len(current.values) == len(next.values):
+ a.values.pop(i+1)
+ continue
+ i += 1
+ ### If reduced to one item, just return that item
+ return a.values[0] if (len(a.values) == 1) else a
+ return applyToChildren(a, cleanupBoolOps)
+
+def cleanupRanges(a):
+ """Remove any range shenanigans, because Python lets you include unneccessary values"""
+ if not isinstance(a, ast.AST):
+ return a
+ if type(a) == ast.Call:
+ if type(a.func) == ast.Name:
+ if a.func.id in ["range"]:
+ if len(a.args) == 3:
+ # The step defaults to 1!
+ if type(a.args[2]) == ast.Num and a.args[2].n == 1:
+ a.args = a.args[:-1]
+ if len(a.args) == 2:
+ # The start defaults to 0!
+ if type(a.args[0]) == ast.Num and a.args[0].n == 0:
+ a.args = a.args[1:]
+ return applyToChildren(a, cleanupRanges)
+
+def cleanupSlices(a):
+ """Remove any slice shenanigans, because Python lets you include unneccessary values"""
+ if not isinstance(a, ast.AST):
+ return a
+ if type(a) == ast.Subscript:
+ if type(a.slice) == ast.Slice:
+ # Lower defaults to 0
+ if a.slice.lower != None and type(a.slice.lower) == ast.Num and a.slice.lower.n == 0:
+ a.slice.lower = None
+ # Upper defaults to len(value)
+ if a.slice.upper != None and type(a.slice.upper) == ast.Call and \
+ type(a.slice.upper.func) == ast.Name and a.slice.upper.func.id == "len":
+ if compareASTs(a.value, a.slice.upper.args[0], checkEquality=True) == 0:
+ a.slice.upper = None
+ # Step defaults to 1
+ if a.slice.step != None and type(a.slice.step) == ast.Num and a.slice.step.n == 1:
+ a.slice.step = None
+ return applyToChildren(a, cleanupSlices)
+
+def cleanupTypes(a):
+ """Remove any unneccessary type mappings"""
+ if not isinstance(a, ast.AST):
+ return a
+ # No need to cast something if it'll be changed anyway by a binary operation
+ if type(a) == ast.BinOp:
+ a.left = cleanupTypes(a.left)
+ a.right = cleanupTypes(a.right)
+ # Ints become floats naturally
+ if eventualType(a.left) == eventualType(a.right) == float:
+ if type(a.right) == ast.Call and type(a.right.func) == ast.Name and \
+ a.right.func.id == "float" and len(a.right.args) == 1 and len(a.right.keywords) == 0 and \
+ eventualType(a.right.args[0]) in [int, float]:
+ a.right = a.right.args[0]
+ elif type(a.left) == ast.Call and type(a.left.func) == ast.Name and \
+ a.left.func.id == "float" and len(a.left.args) == 1 and len(a.left.keywords) == 0 and \
+ eventualType(a.left.args[0]) in [int, float]:
+ a.left = a.left.args[0]
+ return a
+ elif type(a) == ast.Call and type(a.func) == ast.Name and len(a.args) == 1 and len(a.keywords) == 0:
+ a.func = cleanupTypes(a.func)
+ a.args = [cleanupTypes(a.args[0])]
+ # If the type already matches, no need to cast it
+ funName = a.func.id
+ argType = eventualType(a.args[0])
+ if type(a.func) == ast.Name:
+ if (funName == "float" and argType == float) or \
+ (funName == "int" and argType == int) or \
+ (funName == "bool" and argType == bool) or \
+ (funName == "str" and argType == str):
+ return a.args[0]
+ return applyToChildren(a, cleanupTypes)
+
+def turnPositive(a):
+ """Take a negative number and make it positive"""
+ if type(a) == ast.UnaryOp and type(a.op) == ast.USub:
+ return a.operand
+ elif type(a) == ast.Num and type(a.n) != complex and a.n < 0:
+ a.n = a.n * -1
+ return a
+ else:
+ log("transformations\tturnPositive\tFailure: " + str(a), "bug")
+ return a
+
+def isNegative(a):
+ """Is the give number negative?"""
+ if type(a) == ast.UnaryOp and type(a.op) == ast.USub:
+ return True
+ elif type(a) == ast.Num and type(a.n) != complex and a.n < 0:
+ return True
+ else:
+ return False
+
+def cleanupNegations(a):
+ """Remove unneccessary negations"""
+ if not isinstance(a, ast.AST):
+ return a
+ elif type(a) == ast.BinOp:
+ a.left = cleanupNegations(a.left)
+ a.right = cleanupNegations(a.right)
+
+ if type(a.op) == ast.Add:
+ # x + (-y)
+ if isNegative(a.right):
+ a.right = turnPositive(a.right)
+ a.op = ast.Sub(global_id=a.op.global_id, num_negated=True)
+ return a
+ # (-x) + y
+ elif isNegative(a.left):
+ if couldCrash(a.left) and couldCrash(a.right):
+ return a # can't switch if it'll change the message
+ else:
+ (a.left,a.right) = (a.right,turnPositive(a.left))
+ a.op = ast.Sub(global_id=a.op.global_id, num_negated=True)
+ return a
+ elif type(a.op) == ast.Sub:
+ # x - (-y)
+ if isNegative(a.right):
+ a.right = turnPositive(a.right)
+ a.op = ast.Add(global_id=a.op.global_id, num_negated=True)
+ return a
+ elif type(a.right) == ast.BinOp:
+ # x - (y + z) = x + (-y - z)
+ if type(a.right.op) == ast.Add:
+ a.right.left = cleanupNegations(ast.UnaryOp(ast.USub(addedOtherOp=True), a.right.left, addedOther=True))
+ a.right.op = ast.Sub(global_id=a.right.op.global_id, num_negated=True)
+ a.op = ast.Add(global_id=a.op.global_id, num_negated=True)
+ return a
+ # x - (y - z) = x + (-y + z) = x + (z - y)
+ elif type(a.right.op) == ast.Sub:
+ if couldCrash(a.right.left) and couldCrash(a.right.right):
+ a.right.left = cleanupNegations(ast.UnaryOp(ast.USub(addedOtherOp=True), a.right.left, addedOther=True))
+ a.right.op = ast.Add(global_id=a.right.op.global_id, num_negated=True)
+ a.op = ast.Add(global_id=a.op.global_id, num_negated=True)
+ return a
+ else:
+ (a.right.left, a.right.right) = (a.right.right, a.right.left)
+ a.op = ast.Add(global_id=a.op.global_id, num_negated=True)
+ return a
+ # Move negations to the outer part of multiplications
+ elif type(a.op) == ast.Mult:
+ # -x * -y
+ if isNegative(a.left) and isNegative(a.right):
+ a.left = turnPositive(a.left)
+ a.right = turnPositive(a.right)
+ return a
+ # -x * y = -(x*y)
+ elif isNegative(a.left):
+ if eventualType(a.right) in [int, float]:
+ a.left = turnPositive(a.left)
+ return cleanupNegations(ast.UnaryOp(ast.USub(addedOtherOp=True), a, addedOther=True))
+ # x * -y = -(x*y)
+ elif isNegative(a.right):
+ if eventualType(a.left) in [int, float]:
+ a.right = turnPositive(a.right)
+ return cleanupNegations(ast.UnaryOp(ast.USub(addedOtherOp=True), a, addedOther=True))
+ elif type(a.op) in [ast.Div, ast.FloorDiv]:
+ if isNegative(a.left) and isNegative(a.right):
+ a.left = turnPositive(a.left)
+ a.right = turnPositive(a.right)
+ return a
+ return a
+ elif type(a) == ast.UnaryOp:
+ a.operand = cleanupNegations(a.operand)
+ if type(a.op) == ast.USub:
+ if type(a.operand) == ast.BinOp:
+ # -(x + y) = -x - y
+ if type(a.operand.op) == ast.Add:
+ a.operand.left = cleanupNegations(ast.UnaryOp(ast.USub(addedOtherOp=True), a.operand.left, addedOther=True))
+ a.operand.op = ast.Sub(global_id=a.operand.op.global_id, num_negated=True)
+ transferMetaData(a, a.operand)
+ return a.operand
+ # -(x - y) = -x + y = y - x
+ elif type(a.operand.op) == ast.Sub:
+ if couldCrash(a.operand.left) and couldCrash(a.operand.right):
+ a.operand.left = cleanupNegations(ast.UnaryOp(ast.USub(addedOtherOp=True), a.operand.left, addedOther=True))
+ a.operand.op = ast.Add(global_id=a.operand.op.global_id, num_negated=True)
+ transferMetaData(a, a.operand)
+ return a.operand
+ else:
+ (a.operand.left,a.operand.right) = (a.operand.right,a.operand.left)
+ transferMetaData(a, a.operand)
+ return a.operand
+ return a
+ # Special case for absolute value
+ elif type(a) == ast.Call and type(a.func) == ast.Name and a.func.id == "abs" and len(a.args) == 1:
+ a.args[0] = cleanupNegations(a.args[0])
+ if type(a.args[0]) == ast.UnaryOp and type(a.args[0].op) == ast.USub:
+ a.args[0] = a.args[0].operand
+ elif type(a.args[0]) == ast.BinOp and type(a.args[0].op) == ast.Sub:
+ if not (couldCrash(a.args[0].left) and couldCrash(a.args[0].right)) and \
+ compareASTs(a.args[0].left, a.args[0].right) > 0:
+ (a.args[0].left,a.args[0].right) = (a.args[0].right,a.args[0].left)
+ return a
+ else:
+ return applyToChildren(a, cleanupNegations)
+
+### CONDITIONAL TRANSFORMATIONS ###
+
+def combineConditionals(a):
+ """When possible, combine conditional branches"""
+ if not isinstance(a, ast.AST):
+ return a
+ elif type(a) == ast.If:
+ for i in range(len(a.body)):
+ a.body[i] = combineConditionals(a.body[i])
+ for i in range(len(a.orelse)):
+ a.orelse[i] = combineConditionals(a.orelse[i])
+
+ # if a: if b: x can be - if a and b: x
+ if (len(a.orelse) == 0) and (len(a.body) == 1) and \
+ (type(a.body[0]) == ast.If) and (len(a.body[0].orelse) == 0):
+ a.test = ast.BoolOp(ast.And(combinedConditionalOp=True), [a.test, a.body[0].test], combinedConditional=True)
+ a.body = a.body[0].body
+ # if a: x elif b: x can be - if a or b: x
+ elif (len(a.orelse) == 1) and \
+ (type(a.orelse[0]) == ast.If) and (len(a.orelse[0].orelse) == 0):
+ if compareASTs(a.body, a.orelse[0].body, checkEquality=True) == 0:
+ a.test = ast.BoolOp(ast.Or(combinedConditionalOp=True), [a.test, a.orelse[0].test], combinedConditional=True)
+ a.orelse = []
+ return a
+ else:
+ return applyToChildren(a, combineConditionals)
+
+def staticVars(l, vars):
+ """Determines whether the given lines change the given variables"""
+ # First, if one of the variables can be modified, there might be a problem
+ mutableVars = []
+ for var in vars:
+ if (not (hasattr(var, "type") and (var.type in [int, float, str, bool]))):
+ mutableVars.append(var)
+
+ for i in range(len(l)):
+ if type(l[i]) == ast.Assign:
+ for var in vars:
+ if var.id in allVariableNamesUsed(l[i].targets[0]):
+ return False
+ elif type(l[i]) == ast.AugAssign:
+ for var in vars:
+ if var.id in allVariableNamesUsed(l[i].target):
+ return False
+ elif type(l[i]) in [ast.If, ast.While]:
+ if not (staticVars(l[i].body, vars) and staticVars(l[i].orelse, vars)):
+ return False
+ elif type(l[i]) == ast.For:
+ for var in vars:
+ if var.id in allVariableNamesUsed(l[i].target):
+ return False
+ if not (staticVars(l[i].body, vars) and staticVars(l[i].orelse, vars)):
+ return False
+ elif type(l[i]) in [ast.FunctionDef, ast.ClassDef, ast.Try, ast.With]:
+ log("transformations\tstaticVars\tMissing type: " + str(type(l[i])), "bug")
+
+ # If a mutable variable is used, we can't trust it
+ for var in mutableVars:
+ if var.id in allVariableNamesUsed(l[i]):
+ return False
+ return True
+
+def getIfBranches(a):
+ """Gets all the branches of an if statement. Will only work if each else has a single line"""
+ if type(a) != ast.If:
+ return None
+
+ if len(a.orelse) == 0:
+ return [a]
+ elif len(a.orelse) == 1:
+ tmp = getIfBranches(a.orelse[0])
+ if tmp == None:
+ return None
+ return [a] + tmp
+ else:
+ return None
+
+def allVariablesUsed(a):
+ """Gathers all the variable names used in the ast"""
+ if type(a) == list:
+ variables = []
+ for x in a:
+ variables += allVariablesUsed(x)
+ return variables
+
+ if not isinstance(a, ast.AST):
+ return []
+ elif type(a) == ast.Name:
+ return [a]
+ elif type(a) == ast.Assign:
+ variables = allVariablesUsed(a.value)
+ for target in a.targets:
+ if type(target) == ast.Name:
+ pass
+ elif type(target) in [ast.Tuple, ast.List]:
+ for elt in target.elts:
+ if type(elt) == ast.Name:
+ pass
+ else:
+ variables += allVariablesUsed(elt)
+ else:
+ variables += allVariablesUsed(target)
+ return variables
+ else:
+ variables = []
+ for child in ast.iter_child_nodes(a):
+ variables += allVariablesUsed(child)
+ return variables
+
+def conditionalRedundancy(a):
+ """When possible, remove redundant lines from conditionals and combine conditionals."""
+ if type(a) == ast.Module:
+ for i in range(len(a.body)):
+ if type(a.body[i]) == ast.FunctionDef:
+ a.body[i] = conditionalRedundancy(a.body[i])
+ return a
+ elif type(a) == ast.FunctionDef:
+ a.body = conditionalRedundancy(a.body)
+ return a
+
+ if type(a) == list:
+ i = 0
+ while i < len(a):
+ stmt = a[i]
+ if type(stmt) == ast.If:
+ stmt.body = conditionalRedundancy(stmt.body)
+ stmt.orelse = conditionalRedundancy(stmt.orelse)
+
+ # If a line appears in both, move it outside the conditionals
+ if len(stmt.body) > 0 and len(stmt.orelse) > 0 and compareASTs(stmt.body[-1], stmt.orelse[-1], checkEquality=True) == 0:
+ nextLine = stmt.body[-1]
+ nextLine.second_global_id = stmt.orelse[-1].global_id
+ stmt.body = stmt.body[:-1]
+ stmt.orelse = stmt.orelse[:-1]
+ stmt.moved_line = nextLine.global_id
+ # Remove the if statement if both if and else are empty
+ if len(stmt.body) == 0 and len(stmt.orelse) == 0:
+ newLine = ast.Expr(stmt.test)
+ transferMetaData(stmt, newLine)
+ a[i:i+1] = [newLine, nextLine]
+ # Switch if and else if if is empty
+ elif len(stmt.body) == 0:
+ stmt.test = ast.UnaryOp(ast.Not(addedNotOp=True), stmt.test, addedNot=True)
+ stmt.body = stmt.orelse
+ stmt.orelse = []
+ a[i:i+1] = [stmt, nextLine]
+ else:
+ a[i:i+1] = [stmt, nextLine]
+ continue # skip incrementing so that we check the conditional again
+ # Join adjacent, disjoint ifs
+ elif i+1 < len(a) and type(a[i+1]) == ast.If:
+ branches1 = getIfBranches(stmt)
+ branches2 = getIfBranches(a[i+1])
+ if branches1 != None and branches2 != None:
+ # First, check whether any vars used in the second set of branches will be changed by the first set
+ testVars = sum(map(lambda b : allVariablesUsed(b.test), branches2), [])
+ for branch in branches1:
+ if not staticVars(branch.body, testVars):
+ break
+ else:
+ branchCombos = [(x,y) for y in branches2 for x in branches1]
+ for (branch1,branch2) in branchCombos:
+ if not areDisjoint(branch1.test, branch2.test):
+ break
+ else:
+ # We know the last else branch is empty- fill it with the next tree!
+ branches1[-1].orelse = [a[i+1]]
+ a.pop(i+1)
+ continue # check this conditional again
+ elif type(stmt) == ast.FunctionDef:
+ stmt.body = conditionalRedundancy(stmt.body)
+ elif type(stmt) in [ast.While, ast.For]:
+ stmt.body = conditionalRedundancy(stmt.body)
+ stmt.orelse = conditionalRedundancy(stmt.orelse)
+ elif type(stmt) == ast.ClassDef:
+ for x in range(len(stmt.body)):
+ if type(stmt.body[x]) == ast.FunctionDef:
+ stmt.body[x] = conditionalRedundancy(stmt.body[x])
+ elif type(stmt) == ast.Try:
+ stmt.body = conditionalRedundancy(stmt.body)
+ for x in range(len(stmt.handlers)):
+ stmt.handlers[x].body = conditionalRedundancy(stmt.handlers[x].body)
+ stmt.orelse = conditionalRedundancy(stmt.orelse)
+ stmt.finalbody = conditionalRedundancy(stmt.finalbody)
+ elif type(stmt) == ast.With:
+ stmt.body = conditionalRedundancy(stmt.body)
+ else:
+ pass
+ i += 1
+ return a
+ else:
+ log("transformations\tconditionalRedundancy\tStrange type: " + str(type(a)), "bug")
+
+def collapseConditionals(a):
+ """When possible, combine adjacent conditionals"""
+ if type(a) == ast.Module:
+ for i in range(len(a.body)):
+ if type(a.body[i]) == ast.FunctionDef:
+ a.body[i] = collapseConditionals(a.body[i])
+ return a
+ elif type(a) == ast.FunctionDef:
+ a.body = collapseConditionals(a.body)
+ return a
+
+ if type(a) == list:
+ l = a
+ i = len(l) - 1
+
+ # Go through the lines backwards, since we're collapsing conditionals upwards
+ while i >= 0:
+ stmt = l[i]
+ if type(stmt) == ast.If:
+ stmt.body = collapseConditionals(stmt.body)
+ stmt.orelse = collapseConditionals(stmt.orelse)
+
+ # First, check to see if we can collapse across the if and its else
+ if len(l[i].body) == 1 and len(l[i].orelse) == 1:
+ ifLine = l[i].body[0]
+ elseLine = l[i].orelse[0]
+
+ # This only works for Assign and Return
+ if type(ifLine) == type(elseLine) == ast.Assign and \
+ compareASTs(ifLine.targets, elseLine.targets, checkEquality=True) == 0:
+ pass
+ elif type(ifLine) == ast.Return and type(elseLine) == ast.Return:
+ pass
+ else:
+ i -= 1
+ continue # skip past this
+
+ if type(ifLine.value) == type(elseLine.value) == ast.Name and \
+ ifLine.value.id in ['True', 'False'] and elseLine.value.id in ['True', 'False']:
+ if ifLine.value.id == elseLine.value.id:
+ # If they both return the same thing, just replace the if altogether.
+ # But keep the test in case it crashes- we'll remove it later
+ ifLine.global_id = None # we're replacing the whole if statement
+ l[i:i+1] = [ast.Expr(l[i].test, addedOther=True, moved_line=ifLine.global_id), ifLine]
+ elif eventualType(l[i].test) == bool:
+ testVal = l[i].test
+ if ifLine.value.id == 'True':
+ newVal = testVal
+ else:
+ newVal = ast.UnaryOp(ast.Not(addedNotOp=True), testVal, negated=True, collapsedExpr=True)
+
+ if type(ifLine) == ast.Assign:
+ newLine = ast.Assign(ifLine.targets, newVal)
+ else:
+ newLine = ast.Return(newVal)
+ transferMetaData(l[i], newLine)
+ l[i] = newLine
+ # Next, check to see if we can collapse across the if and surrounding lines
+ elif len(l[i].body) == 1 and len(l[i].orelse) == 0:
+ ifLine = l[i].body[0]
+ # First, check to see if the current and prior have the same return bodies
+ if i != 0 and type(l[i-1]) == ast.If and \
+ len(l[i-1].body) == 1 and len(l[i-1].orelse) == 0 and \
+ type(ifLine) == ast.Return and compareASTs(ifLine, l[i-1].body[0], checkEquality=True) == 0:
+ # If they do, combine their tests with an Or and get rid of this line
+ l[i-1].test = ast.BoolOp(ast.Or(combinedConditionalOp=True), [l[i-1].test, l[i].test], combinedConditional=True)
+ l[i-1].second_global_id = l[i].global_id
+ l.pop(i)
+ # Then, check whether the current and latter lines have the same returns
+ elif i != len(l) - 1 and type(ifLine) == type(l[i+1]) == ast.Return and \
+ type(ifLine.value) == type(l[i+1].value) == ast.Name and \
+ ifLine.value.id in ['True', 'False'] and l[i+1].value.id in ['True', 'False']:
+ if ifLine.value.id == l[i+1].value.id:
+ # No point in keeping the if line- just use the return
+ l[i] = ast.Expr(l[i].test, addedOther=True)
+ else:
+ if eventualType(l[i].test) == bool:
+ testVal = l[i].test
+ if ifLine.value.id == 'True':
+ newLine = ast.Return(testVal)
+ else:
+ newLine = ast.Return(ast.UnaryOp(ast.Not(addedNotOp=True), testVal, negated=True, collapsedExpr=True))
+ transferMetaData(l[i], newLine)
+ l[i] = newLine
+ l.pop(i+1) # get rid of the extra return
+ elif type(stmt) == ast.FunctionDef:
+ stmt.body = collapseConditionals(stmt.body)
+ elif type(stmt) in [ast.While, ast.For]:
+ stmt.body = collapseConditionals(stmt.body)
+ stmt.orelse = collapseConditionals(stmt.orelse)
+ elif type(stmt) == ast.ClassDef:
+ for i in range(len(stmt.body)):
+ if type(stmt.body[i]) == ast.FunctionDef:
+ stmt.body[i] = collapseConditionals(stmt.body[i])
+ elif type(stmt) == ast.Try:
+ stmt.body = collapseConditionals(stmt.body)
+ for i in range(len(stmt.handlers)):
+ stmt.handlers[i].body = collapseConditionals(stmt.handlers[i].body)
+ stmt.orelse = collapseConditionals(stmt.orelse)
+ stmt.finalbody = collapseConditionals(stmt.finalbody)
+ elif type(stmt) == ast.With:
+ stmt.body = collapseConditionals(stmt.body)
+ else:
+ pass
+ i -= 1
+ return l
+ else:
+ log("transformations\tcollapseConditionals\tStrange type: " + str(type(a)), "bug")