summaryrefslogtreecommitdiff
path: root/canonicalize/astTools.py
diff options
context:
space:
mode:
Diffstat (limited to 'canonicalize/astTools.py')
-rw-r--r--canonicalize/astTools.py1426
1 files changed, 1426 insertions, 0 deletions
diff --git a/canonicalize/astTools.py b/canonicalize/astTools.py
new file mode 100644
index 0000000..b84b2c5
--- /dev/null
+++ b/canonicalize/astTools.py
@@ -0,0 +1,1426 @@
+import ast, copy, pickle
+from .tools import log
+from .namesets import *
+from .display import printFunction
+
def cmp(a, b):
    """Three-way comparison in the style of Python 2's built-in cmp.

    Returns 1 when a > b, -1 when a < b and 0 otherwise. When both
    arguments are exactly of type complex, only the real parts are compared.
    """
    if type(a) == type(b) == complex:
        # complex numbers have no ordering, so fall back to the real parts
        a, b = a.real, b.real
    isGreater = a > b
    isLess = a < b
    return isGreater - isLess
+
def builtInName(id):
    """Determines whether the given id is a built-in name.

    Returns True for ids found in builtInNames, exceptionClasses or the keys
    of builtInFunctions; False for ids found among the keys of
    allPythonFunctions or in supportedLibraries; and None (falsy) for
    anything else.
    """
    if id in builtInNames + exceptionClasses or id in builtInFunctions.keys():
        return True
    if id in list(allPythonFunctions.keys()) + supportedLibraries:
        return False
    # unknown ids fall through with an implicit (falsy) None
    return None
+
def importedName(id, importList):
    """Determine whether id is a name bound by one of the given import statements.

    importList is an iterable of AST nodes; only ast.Import and ast.ImportFrom
    nodes are examined. An alias is matched by its `asname` when it has one,
    otherwise by its `name`. ImportFrom statements whose module is missing or
    not in supportedLibraries are logged as bugs and skipped.

    Returns True if any alias binds id, otherwise False.
    """
    for imp in importList:
        if type(imp) == ast.Import:
            aliases = imp.names
        elif type(imp) == ast.ImportFrom:
            if not hasattr(imp, "module"):
                log("astTools\timportedName\tWhy no module? " + printFunction(imp), "bug")
                continue
            if imp.module not in supportedLibraries:
                log("astTools\timportedName\tUnsupported library: " + printFunction(imp), "bug")
                continue
            aliases = imp.names
        else:
            continue

        for alias in aliases:
            asname = getattr(alias, "asname", None)
            # when the import is renamed, only the new name is visible
            boundName = asname if asname is not None else alias.name
            if id == boundName:
                return True
    return False
+
def isConstant(x):
    """Report whether x is exactly one of the literal node types
    (ast.Num, ast.Str, ast.Bytes or ast.NameConstant)."""
    constantTypes = [ast.Num, ast.Str, ast.Bytes, ast.NameConstant]
    return type(x) in constantTypes
+
def isIterableType(t):
    """Report whether t is one of the iterable built-in types this module
    knows about (dict, list, set, str, bytes, tuple)."""
    iterableTypes = (dict, list, set, str, bytes, tuple)
    return t in iterableTypes
+
def isStatement(a):
    """Report whether the node's exact type is one of the module-level or
    statement node types listed below (as opposed to an expression)."""
    statementTypes = (
        # module-level containers
        ast.Module, ast.Interactive, ast.Expression, ast.Suite,
        # definitions and simple statements
        ast.FunctionDef, ast.ClassDef, ast.Return, ast.Delete,
        ast.Assign, ast.AugAssign,
        # compound statements
        ast.For, ast.While, ast.If, ast.With, ast.Raise, ast.Try,
        # everything else
        ast.Assert, ast.Import, ast.ImportFrom, ast.Global,
        ast.Expr, ast.Pass, ast.Break, ast.Continue,
    )
    return type(a) in statementTypes
+
def codeLength(a):
    """Length, in characters, of the text printFunction produces for this AST.

    A list is measured as the sum of the lengths of its items.
    """
    if type(a) != list:
        return len(printFunction(a))
    total = 0
    for item in a:
        total += codeLength(item)
    return total
+
def applyToChildren(a, f):
    """Replace every child of node a with the result of f(child), in place.

    Each field in a._fields is processed. For list-valued fields f is applied
    to each element; when f returns a list, that list is spliced in where the
    element was (and its items are not visited again). Returns a itself, or
    a unchanged when it is None.
    """
    if a is None:
        return a
    for name in a._fields:
        value = getattr(a, name)
        if type(value) is not list:
            setattr(a, name, f(value))
            continue

        pos = 0
        while pos < len(value):
            result = f(value[pos])
            if type(result) is list:
                # splice the replacement items in and skip over them
                value = value[:pos] + result + value[pos + 1:]
                pos += len(result)
            else:
                value[pos] = result
                pos += 1
        setattr(a, name, value)
    return a
+
def occursIn(sub, super):
    """Report whether the AST sub appears anywhere inside the AST super
    (including super itself), using compareASTs in equality mode."""
    if not isinstance(super, ast.AST):
        return False
    if type(sub) == type(super) and compareASTs(sub, super, checkEquality=True) == 0:
        return True

    # Only these node types can hold statements, so a statement cannot be
    # found below anything else; stop early to save time.
    statementHolders = (ast.Module, ast.Interactive, ast.Suite,
                        ast.FunctionDef, ast.ClassDef, ast.For,
                        ast.While, ast.If, ast.With, ast.Try,
                        ast.ExceptHandler)
    if isStatement(sub) and type(super) not in statementHolders:
        return False

    return any(occursIn(sub, kid) for kid in ast.iter_child_nodes(super))
+
def countOccurances(a, value):
    """Count the nodes in the AST (or list of ASTs) that are instances of
    the node type given as value. Non-AST input counts as zero."""
    if type(a) == list:
        return sum(countOccurances(item, value) for item in a)
    if not isinstance(a, ast.AST):
        return 0
    return sum(1 for node in ast.walk(a) if isinstance(node, value))
+
def countVariables(a, id):
    """Count the ast.Name nodes in the AST (or list of ASTs) whose id
    equals the given id. Non-AST input counts as zero."""
    if type(a) == list:
        return sum(countVariables(item, id) for item in a)
    if not isinstance(a, ast.AST):
        return 0
    matches = [node for node in ast.walk(a)
               if type(node) == ast.Name and node.id == id]
    return len(matches)
+
def gatherAllNames(a, keep_orig=True):
    """Gather all names in the tree (variable or otherwise).

    Returns a set of (id, originalId) pairs, one per distinct ast.Name found
    in the AST or list of ASTs. originalId is the node's `originalId`
    attribute (used in variable mapping) when keep_orig is true and the
    attribute exists, otherwise None. Non-AST input yields an empty set.
    """
    if type(a) == list:
        allIds = set()
        for line in a:
            # propagate keep_orig so lists honour the caller's choice
            allIds |= gatherAllNames(line, keep_orig)
        return allIds
    if not isinstance(a, ast.AST):
        return set()

    allIds = set()
    for node in ast.walk(a):
        if type(node) == ast.Name:
            origName = node.originalId if (keep_orig and hasattr(node, "originalId")) else None
            allIds.add((node.id, origName))
    return allIds
+
def gatherAllVariables(a, keep_orig=True):
    """Gather all variable names in the tree.

    Looks at ast.Name and ast.arg nodes, skipping any whose name satisfies
    builtInName or that carry a `dontChangeName` attribute. Returns a set of
    (id, originalId) pairs; originalId (used in variable mapping) is taken
    from the node's `originalId` attribute when keep_orig is true and the
    attribute exists, otherwise None.

    Within a single AST, an id already recorded with originalId None is
    upgraded when a later node supplies one; two different non-None
    originalIds for the same id are logged as a bug. Results from list items
    are simply unioned. Non-AST input yields an empty set.
    """
    if type(a) == list:
        allIds = set()
        for line in a:
            # propagate keep_orig so lists honour the caller's choice
            allIds |= gatherAllVariables(line, keep_orig)
        return allIds
    if not isinstance(a, ast.AST):
        return set()

    allIds = set()
    for node in ast.walk(a):
        if type(node) == ast.Name or type(node) == ast.arg:
            currentId = node.id if type(node) == ast.Name else node.arg
            # Only take variables
            if not (builtInName(currentId) or hasattr(node, "dontChangeName")):
                origName = node.originalId if (keep_orig and hasattr(node, "originalId")) else None
                if (currentId, origName) not in allIds:
                    for pair in allIds:
                        if pair[0] == currentId:
                            if pair[1] is None:
                                # replace the anonymous entry with the better-informed one
                                allIds -= {pair}
                                allIds |= {(currentId, origName)}
                            elif origName is None:
                                pass
                            else:
                                log("astTools\tgatherAllVariables\tConflicting originalIds? " + pair[0] + " : " + pair[1] + " , " + origName + "\n" + printFunction(a), "bug")
                            # the set was (possibly) mutated, so stop iterating it
                            break
                    else:
                        allIds |= {(currentId, origName)}
    return allIds
+
def gatherAllParameters(a, keep_orig=True):
    """Gather all parameters (ast.arg nodes) in the tree.

    Returns a set of (arg, originalId) pairs; originalId (used in variable
    mapping) is taken from the node's `originalId` attribute when keep_orig
    is true and the attribute exists, otherwise None. Accepts a single AST or
    a list of ASTs; non-AST input yields an empty set.
    """
    if type(a) == list:
        allIds = set()
        for line in a:
            # recurse into this function (not gatherAllVariables) so that
            # lists yield parameters only, with the caller's keep_orig
            allIds |= gatherAllParameters(line, keep_orig)
        return allIds
    if not isinstance(a, ast.AST):
        return set()

    allIds = set()
    for node in ast.walk(a):
        if type(node) == ast.arg:
            origName = node.originalId if (keep_orig and hasattr(node, "originalId")) else None
            allIds.add((node.arg, origName))
    return allIds
+
def gatherAllHelpers(a, restricted_names):
    """Gather the top-level helper functions of a module that were anonymized.

    A top-level FunctionDef counts when it has no `dontChangeName` attribute
    and its name is not in restricted_names. Returns a set of
    (name, originalId) pairs, with originalId None when the node has no such
    attribute. Anything other than an ast.Module yields an empty set.
    """
    if type(a) != ast.Module:
        return set()
    found = set()
    for stmt in a.body:
        if type(stmt) != ast.FunctionDef:
            continue
        if hasattr(stmt, "dontChangeName") or stmt.name in restricted_names:
            continue  # this one kept its name
        found.add((stmt.name, getattr(stmt, "originalId", None)))
    return found
+
def gatherAllFunctionNames(a):
    """Gather the names of all top-level functions defined in a module.

    Returns a set of (name, originalId) pairs, with originalId None when the
    FunctionDef has no such attribute. Anything other than an ast.Module
    yields an empty set.
    """
    if type(a) != ast.Module:
        return set()
    return {(stmt.name, getattr(stmt, "originalId", None))
            for stmt in a.body if type(stmt) == ast.FunctionDef}
+
def gatherAssignedVars(targets):
    """Flatten assignment targets into the Name/Subscript/Attribute nodes
    they contain.

    Accepts a single target or a list of targets. Tuple and List targets are
    expanded recursively; any other node type is logged as a bug and dropped.
    """
    if type(targets) != list:
        targets = [targets]
    flattened = []
    for node in targets:
        kind = type(node)
        if kind in (ast.Tuple, ast.List):
            flattened.extend(gatherAssignedVars(node.elts))
        elif kind in (ast.Name, ast.Subscript, ast.Attribute):
            flattened.append(node)
        else:
            log("astTools\tgatherAssignedVars\tWeird Assign Type: " + str(type(node)),"bug")
    return flattened
+
def gatherAssignedVarIds(targets):
    """Return the ids of the plain Name nodes among the assigned targets
    (subscripts and attributes are ignored)."""
    return [node.id for node in gatherAssignedVars(targets)
            if type(node) == ast.Name]
+
def getAllAssignedVarIds(a):
    """Collect the ids of all names assigned anywhere in the AST by Assign,
    AugAssign and For nodes (in ast.walk order, duplicates kept).
    Non-AST input yields an empty list."""
    if not isinstance(a, ast.AST):
        return []
    found = []
    for node in ast.walk(a):
        kind = type(node)
        if kind == ast.Assign:
            found += gatherAssignedVarIds(node.targets)
        elif kind == ast.AugAssign or kind == ast.For:
            found += gatherAssignedVarIds([node.target])
    return found
+
def getAllAssignedVars(a):
    """Collect the target nodes (Name/Subscript/Attribute) of every Assign,
    AugAssign and For found in the AST (in ast.walk order).
    Non-AST input yields an empty list."""
    if not isinstance(a, ast.AST):
        return []
    found = []
    for node in ast.walk(a):
        kind = type(node)
        if kind == ast.Assign:
            found += gatherAssignedVars(node.targets)
        elif kind == ast.AugAssign or kind == ast.For:
            found += gatherAssignedVars([node.target])
    return found
+
def getAllFunctions(a):
    """Return the names of every FunctionDef found anywhere in the AST
    (nested ones included), in ast.walk order. Non-AST input yields []."""
    if not isinstance(a, ast.AST):
        return []
    return [node.name for node in ast.walk(a) if type(node) == ast.FunctionDef]
+
def getAllImports(a):
    """Gather the names bound by imports of supported libraries.

    For `import lib` the library must be in supportedLibraries; for
    `from lib import name` the library must be supported and the name must
    be in libraryMap[lib]. The bound name is the alias's asname when present,
    otherwise its name. Anything unrecognised is logged as a bug and skipped.
    Non-AST input yields an empty list.
    """
    if not isinstance(a, ast.AST):
        return []
    bound = []
    for node in ast.walk(a):
        if type(node) == ast.Import:
            for alias in node.names:
                if alias.name not in supportedLibraries:
                    log("astTools\tgetAllImports\tUnknown library: " + alias.name, "bug")
                    continue
                bound.append(alias.asname if alias.asname != None else alias.name)
        elif type(node) == ast.ImportFrom:
            if node.module not in supportedLibraries:
                log("astTools\tgetAllImports\tUnknown library: " + node.module, "bug")
                continue
            for alias in node.names:  # these are all functions
                if alias.name not in libraryMap[node.module]:
                    log("astTools\tgetAllImports\tUnknown import from name: " +
                        node.module + "," + alias.name, "bug")
                    continue
                bound.append(alias.asname if alias.asname != None else alias.name)
    return bound
+
def getAllImportStatements(a):
    """Return every Import and ImportFrom node found in the AST, in
    ast.walk order. Non-AST input yields an empty list."""
    if not isinstance(a, ast.AST):
        return []
    return [node for node in ast.walk(a)
            if type(node) == ast.Import or type(node) == ast.ImportFrom]
+
def getAllGlobalNames(a):
    """Find all names that can be accessed at the global level of a module.

    Only the module's direct body is inspected. Collected, in source order:
    function and class names; Name targets of Assign/AugAssign statements
    (including Names directly inside a Tuple or List target); and the names
    bound by Import/ImportFrom (asname when given, otherwise name).
    Anything other than an ast.Module yields an empty list.
    """
    if type(a) != ast.Module:
        return []
    names = []
    for obj in a.body:
        if type(obj) in [ast.FunctionDef, ast.ClassDef]:
            names.append(obj.name)
        elif type(obj) in [ast.Assign, ast.AugAssign]:
            # AugAssign has a single `target` rather than a `targets` list
            targets = obj.targets if type(obj) == ast.Assign else [obj.target]
            for target in targets:
                if type(target) == ast.Name:
                    names.append(target.id)
                elif type(target) in [ast.Tuple, ast.List]:
                    for elt in target.elts:
                        if type(elt) == ast.Name:
                            names.append(elt.id)
        elif type(obj) in [ast.Import, ast.ImportFrom]:
            for module in obj.names:
                names.append(module.asname if module.asname is not None else module.name)
    return names
+
def doBinaryOp(op, l, r):
    """Evaluate the AST binary operator node op on the values l and r.

    Division is performed as float division and raises
    Exception("Repeating Float") when the result has digits beyond ten
    decimal places. Unrecognised operators yield None.
    """
    kind = type(op)
    if kind == ast.Div:
        # Don't bother if this will be a really long float- it won't work properly!
        # Also, in Python 3 this is floating division, so perform it accordingly.
        quotient = 1.0 * l / r
        if (quotient * 1e10 % 1.0) != 0:
            raise Exception("Repeating Float")
        return quotient

    handlers = {
        ast.Add: lambda x, y: x + y,
        ast.Sub: lambda x, y: x - y,
        ast.Mult: lambda x, y: x * y,
        ast.Mod: lambda x, y: x % y,
        ast.Pow: lambda x, y: x ** y,
        ast.LShift: lambda x, y: x << y,
        ast.RShift: lambda x, y: x >> y,
        ast.BitOr: lambda x, y: x | y,
        ast.BitXor: lambda x, y: x ^ y,
        ast.BitAnd: lambda x, y: x & y,
        ast.FloorDiv: lambda x, y: x // y,
    }
    handler = handlers.get(kind)
    if handler is None:
        return None
    return handler(l, r)
+
def doUnaryOp(op, val):
    """Evaluate the AST unary operator node op on val.
    Unrecognised operators yield None."""
    kind = type(op)
    if kind == ast.UAdd:
        return val
    if kind == ast.USub:
        return -val
    if kind == ast.Not:
        return not val
    if kind == ast.Invert:
        return ~ val
    return None
+
def doCompare(op, left, right):
    """Evaluate the AST comparison operator node op on left and right.
    Unrecognised operators yield None."""
    comparisons = {
        ast.Eq: lambda x, y: x == y,
        ast.NotEq: lambda x, y: x != y,
        ast.Lt: lambda x, y: x < y,
        ast.LtE: lambda x, y: x <= y,
        ast.Gt: lambda x, y: x > y,
        ast.GtE: lambda x, y: x >= y,
        ast.Is: lambda x, y: x is y,
        ast.IsNot: lambda x, y: x is not y,
        ast.In: lambda x, y: x in y,
        ast.NotIn: lambda x, y: x not in y,
    }
    check = comparisons.get(type(op))
    if check is None:
        return None
    return check(left, right)
+
def num_negate(op):
    """Return the arithmetic negation of an operator or value node.

    Add and Sub are swapped; ast.Num and ast.Name values are wrapped in a
    unary minus (its USub is tagged addedNeg=True). Operators that cannot be
    negated this way (Mult, Div, Mod, Pow, shifts, bitwise ops, FloorDiv)
    yield None, and so does any other node type, which is also logged as a
    bug. Metadata is carried over with transferMetaData, and the result's
    `num_negated` flag is the inverse of op's flag (True when op has none).
    """
    top = type(op)
    neg = not op.num_negated if hasattr(op, "num_negated") else True
    if top == ast.Add:
        newOp = ast.Sub()
    elif top == ast.Sub:
        newOp = ast.Add()
    elif top in [ast.Mult, ast.Div, ast.Mod, ast.Pow, ast.LShift,
                 ast.RShift, ast.BitOr, ast.BitXor, ast.BitAnd, ast.FloorDiv]:
        return None  # can't negate this
    elif top in [ast.Num, ast.Name]:
        # this is a normal value, so put a - in front of it
        newOp = ast.UnaryOp(ast.USub(addedNeg=True), op)
    else:
        log("astTools\tnum_negate\tUnusual type: " + str(top), "bug")
        # nothing was built, so report "can't negate" rather than
        # falling through to use an unbound newOp
        return None
    transferMetaData(op, newOp)
    newOp.num_negated = neg
    return newOp
+
def negate(op):
    """Return the logical negation of the provided operator or expression.

    Boolean and comparison operators are swapped for their opposites.
    True/False constants and single-operator Compare nodes are flipped in
    place and returned. A chained Compare becomes an Or of the negated
    pairwise comparisons. `not x` is unwrapped to x when x is known to be a
    bool. Anything else is wrapped in a Not (tagged addedNot=True).
    The `negated` flag on the result is the inverse of op's flag (True when
    op has none). None is passed through.
    """
    if op is None:
        return None

    opposites = {
        ast.And: ast.Or, ast.Or: ast.And,
        ast.Eq: ast.NotEq, ast.NotEq: ast.Eq,
        ast.Lt: ast.GtE, ast.GtE: ast.Lt,
        ast.Gt: ast.LtE, ast.LtE: ast.Gt,
        ast.Is: ast.IsNot, ast.IsNot: ast.Is,
        ast.In: ast.NotIn, ast.NotIn: ast.In,
    }
    kind = type(op)
    flipped = not op.negated if hasattr(op, "negated") else True

    if kind in opposites:
        result = opposites[kind]()
    elif kind == ast.NameConstant and op.value in [True, False]:
        op.value = not op.value
        op.negated = flipped
        return op
    elif kind == ast.Compare:
        if len(op.ops) == 1:
            op.ops[0] = negate(op.ops[0])
            op.negated = flipped
            return op
        # a < b < c  ==>  (a >= b) or (b >= c)
        operands = [op.left] + op.comparators
        parts = []
        for idx, cmpOp in enumerate(op.ops):
            parts.append(ast.Compare(operands[idx], [negate(cmpOp)],
                                     [operands[idx + 1]], multiCompPart=True))
        result = ast.BoolOp(ast.Or(multiCompOp=True), parts, multiComp=True)
    elif kind == ast.UnaryOp and type(op.op) == ast.Not and \
            eventualType(op.operand) == bool:  # this can mess things up type-wise
        return op.operand
    else:
        # this is a normal value, so put a not around it
        result = ast.UnaryOp(ast.Not(addedNot=True), op)

    transferMetaData(op, result)
    result.negated = flipped
    return result
+
def couldCrash(a):
    """Conservatively determine whether evaluating the given AST could crash.

    Returns False for non-AST input. For a Try node only the children of the
    handlers, orelse and finalbody statements are examined (the try body is
    not). For every other node, the node could crash if any child could, or
    if the node itself fails one of the per-type checks below. The checks
    rely on eventualType for type information and on the name tables
    imported from namesets (supportedLibraries, libraryMap, builtInFunctions,
    builtInSafeFunctions, builtInStringFunctions, safeStringFunctions, ...).
    """
    typeCrashes = True  # toggle based on whether you care about potential crashes caused by types
    if not isinstance(a, ast.AST):
        return False

    if type(a) == ast.Try:
        for handler in a.handlers:
            for child in ast.iter_child_nodes(handler):
                if couldCrash(child):
                    return True
        for other in a.orelse:
            for child in ast.iter_child_nodes(other):
                if couldCrash(child):
                    return True
        for line in a.finalbody:
            for child in ast.iter_child_nodes(line):
                if couldCrash(child):
                    return True
        return False

    # If any child could crash, this can crash
    for child in ast.iter_child_nodes(a):
        if couldCrash(child):
            return True

    if type(a) == ast.FunctionDef:
        argNames = []
        for arg in a.args.args:
            if arg.arg in argNames:  # conflicting arg names!
                return True
            else:
                argNames.append(arg.arg)
    if type(a) == ast.Assign:
        for target in a.targets:
            if type(target) != ast.Name:  # can crash if it's a tuple and we can't unpack the value
                return True
    elif type(a) in [ast.For, ast.comprehension]:  # check if the target or iter will break things
        if type(a.target) not in [ast.Name, ast.Tuple, ast.List]:
            return True
        elif type(a.target) in [ast.Tuple, ast.List]:
            for x in a.target.elts:
                if type(x) != ast.Name:
                    return True
        # NOTE(review): this reports a crash when the iterable IS a known
        # iterable type; the condition looks inverted -- confirm intent.
        elif isIterableType(eventualType(a.iter)):
            return True
    elif type(a) == ast.Import:
        for name in a.names:
            # `name` is an ast.alias; compare the imported module's name
            if name.name not in supportedLibraries:
                return True
    elif type(a) == ast.ImportFrom:
        if a.module not in supportedLibraries:
            return True
        # NOTE(review): the ast docs give level 0 for absolute imports, so
        # this is true for any non-None level, including 0 -- confirm intent.
        if a.level != None:
            return True
        for name in a.names:
            # `name` is an ast.alias; compare the imported attribute's name
            if name.name not in libraryMap[a.module]:
                return True
    elif type(a) == ast.BinOp:
        l = eventualType(a.left)
        r = eventualType(a.right)
        if type(a.op) == ast.Add:
            if not ((l == r == str) or (l in [int, float] and r in [int, float])):
                return typeCrashes
        elif type(a.op) == ast.Mult:
            if not ((l == str and r == int) or (l == int and r == str) or
                    (l in [int, float] and r in [int, float])):
                return typeCrashes
        elif type(a.op) in [ast.Sub, ast.LShift, ast.RShift, ast.BitOr, ast.BitXor, ast.BitAnd]:
            if not (l in [int, float] and r in [int, float]):
                return typeCrashes
        elif type(a.op) == ast.Pow:
            if not ((l in [int, float] and r == int) or
                    (l in [int, float] and type(a.right) == ast.Num and
                     type(a.right.n) != complex and
                     (a.right.n >= 1 or a.right.n == 0 or a.right.n <= -1))):
                return True
        else:  # ast.Div, ast.FloorDiv, ast.Mod
            if type(a.right) == ast.Num and a.right.n != 0:
                if l not in [int, float]:
                    return typeCrashes
            else:
                return True  # Divide by zero error
    elif type(a) == ast.UnaryOp:
        if type(a.op) in [ast.UAdd, ast.USub]:
            if eventualType(a.operand) not in [int, float]:
                return typeCrashes
        elif type(a.op) == ast.Invert:
            if eventualType(a.operand) != int:
                return typeCrashes
    elif type(a) == ast.Compare:
        if len(a.ops) != len(a.comparators):
            return True
        elif type(a.ops[0]) in [ast.In, ast.NotIn]:
            if not isIterableType(eventualType(a.comparators[0])):
                return True
            elif eventualType(a.comparators[0]) in [str, bytes] and eventualType(a.left) not in [str, bytes]:
                return True
        elif type(a.ops[0]) in [ast.Lt, ast.LtE, ast.Gt, ast.GtE]:
            # In Python3, you can't compare different types. BOOOOOO!!
            firstType = eventualType(a.left)
            if firstType == None:
                return True
            for comp in a.comparators:
                if eventualType(comp) != firstType:
                    return True
    elif type(a) == ast.Call:
        env = []  # TODO: what if the environments aren't imported?
        # First, gather up the needed variables
        if type(a.func) == ast.Name:
            funName = a.func.id
            if funName not in builtInSafeFunctions:
                return True
            funDict = builtInFunctions
        elif type(a.func) == ast.Attribute:
            # was `a.Name`, which is an attribute lookup on the Call node
            if type(a.func.value) == ast.Name and \
                    (not hasattr(a.func.value, "varID")) and \
                    a.func.value.id in supportedLibraries:
                funName = a.func.attr
                # NOTE(review): safeLibraryMap is called here, while the
                # sibling libraryMap is subscripted -- confirm which applies.
                if funName not in safeLibraryMap(a.func.value.id):
                    return True
                funDict = libraryMap[a.func.value.id]
            elif eventualType(a.func.value) == str:
                funName = a.func.attr
                if funName not in safeStringFunctions:
                    return True
                funDict = builtInStringFunctions
            else:  # list and dict are definitely crashable
                return True
        else:
            return True

        if funName in ["max", "min"]:
            return False  # Special functions that have infinite args

        # First, load up the arg types
        argTypes = []
        for i in range(len(a.args)):
            eventual = eventualType(a.args[i])
            if (eventual == None and typeCrashes):
                return True
            argTypes.append(eventual)

        if funDict[funName] != None:
            for argSet in funDict[funName]:  # the given possibilities of arg types
                if len(argSet) != len(argTypes):
                    continue
                if not typeCrashes:  # If we don't care about types, stop now
                    return False

                for i in range(len(argSet)):
                    if not (argSet[i] == argTypes[i] or issubclass(argTypes[i], argSet[i])):
                        break
                else:  # if all types matched
                    return False
            return True  # Didn't fit any of the options
    elif type(a) == ast.Subscript:  # can only get an index from a string or list
        return eventualType(a.value) not in [str, list, tuple]
    elif type(a) == ast.Name:
        # If it's an undefined variable, it might crash
        if hasattr(a, "randomVar"):
            return True
    elif type(a) == ast.Slice:
        if a.lower != None and eventualType(a.lower) != int:
            return True
        if a.upper != None and eventualType(a.upper) != int:
            return True
        if a.step != None and eventualType(a.step) != int:
            return True
    elif type(a) in [ast.Raise, ast.Assert, ast.Pass, ast.Break,
                     ast.Continue, ast.Yield, ast.Attribute, ast.ExtSlice, ast.Index,
                     ast.Starred]:
        # All of these node types are treated as able to crash.
        return True
    return False
+
def eventualType(a):
    """Get the type the expression will eventually be, if possible
    The expression might also crash! But we don't care about that here,
    we'll deal with it elsewhere.
    Returning 'None' means that we cannot say at the moment

    `a` may be an AST node or a plain value: a value whose type is listed in
    builtInTypes yields that type, and any other non-AST value yields None.
    Call nodes are resolved through the function tables imported from
    namesets (builtInFunctions, libraryDictMap, builtInStringFunctions,
    builtInListFunctions, builtInDictFunctions). The lookups below treat each
    table entry as either None or a mapping from a tuple of argument types
    to a return type."""
    if type(a) in builtInTypes:
        return type(a)
    if not isinstance(a, ast.AST):
        return None

    elif type(a) == ast.BoolOp:
        # In Python, it's the type of all the values in it
        # this may work differently in other languages
        t = eventualType(a.values[0])
        for i in range(1, len(a.values)):
            if eventualType(a.values[i]) != t:
                return None
        return t
    elif type(a) == ast.BinOp:
        l = eventualType(a.left)
        r = eventualType(a.right)
        # It is possible to add/multiply sequences
        if type(a.op) in [ast.Add, ast.Mult]:
            if isIterableType(l):
                return l
            elif isIterableType(r):
                return r
            elif l == float or r == float:
                return float
            elif l == int and r == int:
                return int
            return None
        elif type(a.op) == ast.Div:
            return float  # always a float now
        # For others, check if we know whether it's a float or an int
        elif type(a.op) in [ast.FloorDiv, ast.LShift, ast.RShift, ast.BitOr,
                            ast.BitAnd, ast.BitXor]:
            return int
        elif float in [l, r]:
            return float
        elif l == int and r == int:
            return int
        else:
            return None  # Otherwise, it could be a float- we don't know
    elif type(a) == ast.UnaryOp:
        if type(a.op) == ast.Invert:
            return int
        elif type(a.op) in [ast.UAdd, ast.USub]:
            return eventualType(a.operand)
        else:  # Not op
            return bool
    elif type(a) == ast.Lambda:
        # NOTE(review): `function` is not a Python built-in; presumably it is
        # provided by the star import from namesets -- confirm.
        return function
    elif type(a) == ast.IfExp:
        l = eventualType(a.body)
        r = eventualType(a.orelse)
        if l == r:
            return l
        else:
            return None
    elif type(a) in [ast.Dict, ast.DictComp]:
        return dict
    elif type(a) in [ast.Set, ast.SetComp]:
        return set
    elif type(a) in [ast.List, ast.ListComp]:
        return list
    elif type(a) == ast.GeneratorExp:
        return None  # can't represent a generator
    elif type(a) == ast.Yield:
        return None  # we don't know
    elif type(a) == ast.Compare:
        return bool
    elif type(a) == ast.Call:
        # Go through our different sets of known functions to see if we know the type
        argTypes = [eventualType(x) for x in a.args]
        if type(a.func) == ast.Name:
            funDict = builtInFunctions
            funName = a.func.id
        elif type(a.func) == ast.Attribute:
            # TODO: get a better solution than this
            funName = a.func.attr
            # A bare library name (a Name without a varID attribute) selects
            # that library's table; otherwise the receiver's type decides.
            if type(a.func.value) == ast.Name and \
                    (not hasattr(a.func.value, "varID")) and \
                    a.func.value.id in supportedLibraries:
                funDict = libraryDictMap[a.func.value.id]
                if a.func.value.id in ["string", "str", "list", "dict"] and len(argTypes) > 0:
                    argTypes.pop(0)  # get rid of the first string arg
            elif eventualType(a.func.value) == str:
                funDict = builtInStringFunctions
            elif eventualType(a.func.value) == list:
                funDict = builtInListFunctions
            elif eventualType(a.func.value) == dict:
                funDict = builtInDictFunctions
            else:
                return None
        else:
            return None

        if funName in ["max", "min"]:
            # If all args are the same type, that's our type
            uniqueTypes = set(argTypes)
            if len(uniqueTypes) == 1:
                return uniqueTypes.pop()
            return None

        if funName in funDict and funDict[funName] != None:
            # Collect the return type of every signature the arguments fit
            possibleTypes = []
            for argSet in funDict[funName]:
                if len(argSet) == len(argTypes):
                    # All types must match!
                    for i in range(len(argSet)):
                        if argSet[i] == None or argTypes[i] == None:  # We don't know, but that's okay
                            continue
                        if not (argSet[i] == argTypes[i] or (issubclass(argTypes[i], argSet[i]))):
                            break
                    else:
                        possibleTypes.append(funDict[funName][argSet])
            possibleTypes = set(possibleTypes)
            if len(possibleTypes) == 1:  # If there's only one possibility, that's our type!
                return possibleTypes.pop()
        return None
    elif type(a) in [ast.Str, ast.Bytes]:
        if containsTokenStepString(a):
            return None
        # NOTE(review): ast.Bytes also reports str here -- confirm intended.
        return str
    elif type(a) == ast.Num:
        return type(a.n)
    elif type(a) == ast.Attribute:
        return None  # we have no way of knowing
    elif type(a) == ast.Subscript:
        # We're slicing the object, so the type will stay the same
        t = eventualType(a.value)
        if t == None:
            return None
        elif t == str:
            return str  # indexing a string
        elif t in [list, tuple]:
            if type(a.slice) == ast.Slice:
                return t
            # Otherwise, we need the types of the elements
            if type(a.value) in [ast.List, ast.Tuple]:
                if len(a.value.elts) == 0:
                    return None  # We don't know
                else:
                    eltType = eventualType(a.value.elts[0])
                    for elt in a.value.elts:
                        if eventualType(elt) != eltType:
                            return None  # Disagreement!
                    return eltType
        elif t in [dict, int]:
            return None
        else:
            log("astTools\teventualType\tUnknown type in subscript: " + str(t), "bug")
        return None  # We can't know for now...
    elif type(a) == ast.NameConstant:
        if a.value == True or a.value == False:
            return bool
        elif a.value == None:
            return type(None)
        return None
    elif type(a) == ast.Name:
        if hasattr(a, "type"):  # If it's a variable we categorized
            return a.type
        return None
    elif type(a) == ast.Tuple:
        return tuple
    elif type(a) == ast.Starred:
        return None  # too complicated
    else:
        log("astTools\teventualType\tUnimplemented type " + str(type(a)), "bug")
        return None
+
def depthOfAST(a):
    """Return the height of the AST: 0 for non-AST input, 1 for a node
    without child nodes, and one more than the deepest child otherwise."""
    if not isinstance(a, ast.AST):
        return 0
    childDepths = [depthOfAST(kid) for kid in ast.iter_child_nodes(a)]
    return max(childDepths, default=0) + 1
+
def compareASTs(a, b, checkEquality=False):
    """A cmp-style comparison function for ASTs.

    Returns a negative number, zero or a positive number when a sorts
    before, equal to or after b. Accepts AST nodes, lists, None and
    primitive values. Ordering rules, in the order they are applied:
      * None sorts before everything else;
      * lists compare by length, then element by element;
      * AST nodes sort before primitives; primitives of different types are
        ordered bool, int, float, str, bytes, complex, with any other type
        (logged as a bug) after those, ordered by type name;
      * nodes of different types follow the fixed `types` ordering below
        (expression contexts are mutually equal and sort first);
      * unless checkEquality is set, deeper trees sort before shallower ones;
      * otherwise nodes are compared field by field.
    With checkEquality=True the depth comparison is skipped, so the result
    is only meaningful as an equality test (== 0).
    """
    # None before others
    if a is None and b is None:
        return 0
    elif a is None or b is None:
        return -1 if a is None else 1

    if type(a) == type(b) == list:
        if len(a) != len(b):
            return len(a) - len(b)
        for i in range(len(a)):
            r = compareASTs(a[i], b[i], checkEquality=checkEquality)
            if r != 0:
                return r
        return 0

    # AST before primitive
    if (not isinstance(a, ast.AST)) and (not isinstance(b, ast.AST)):
        if type(a) != type(b):
            builtins = [bool, int, float, str, bytes, complex]
            if type(a) not in builtins or type(b) not in builtins:
                log("MISSING BUILT-IN TYPE: " + str(type(a)) + "," + str(type(b)), "bug")
            # Unknown types rank after every known one instead of making
            # list.index raise ValueError; ties are broken by type name.
            aRank = builtins.index(type(a)) if type(a) in builtins else len(builtins)
            bRank = builtins.index(type(b)) if type(b) in builtins else len(builtins)
            if aRank == bRank:
                return cmp(type(a).__name__, type(b).__name__)
            return aRank - bRank
        return cmp(a, b)
    elif (not isinstance(a, ast.AST)) or (not isinstance(b, ast.AST)):
        return -1 if isinstance(a, ast.AST) else 1

    # Order by differing types
    if type(a) != type(b):
        # Here is a brief ordering of types that we care about
        blehTypes = [ast.Load, ast.Store, ast.Del, ast.AugLoad, ast.AugStore, ast.Param]
        if type(a) in blehTypes and type(b) in blehTypes:
            return 0
        elif type(a) in blehTypes or type(b) in blehTypes:
            return -1 if type(a) in blehTypes else 1

        types = [ast.Module, ast.Interactive, ast.Expression, ast.Suite,

                 ast.Break, ast.Continue, ast.Pass, ast.Global,
                 ast.Expr, ast.Assign, ast.AugAssign, ast.Return,
                 ast.Assert, ast.Delete, ast.If, ast.For, ast.While,
                 ast.With, ast.Import, ast.ImportFrom, ast.Raise,
                 ast.Try, ast.FunctionDef,
                 ast.ClassDef,

                 ast.BinOp, ast.BoolOp, ast.Compare, ast.UnaryOp,
                 ast.DictComp, ast.ListComp, ast.SetComp, ast.GeneratorExp,
                 ast.Yield, ast.Lambda, ast.IfExp, ast.Call, ast.Subscript,
                 ast.Attribute, ast.Dict, ast.List, ast.Tuple,
                 ast.Set, ast.Name, ast.Str, ast.Bytes, ast.Num,
                 ast.NameConstant, ast.Starred,

                 ast.Ellipsis, ast.Index, ast.Slice, ast.ExtSlice,

                 ast.And, ast.Or, ast.Add, ast.Sub, ast.Mult, ast.Div,
                 ast.Mod, ast.Pow, ast.LShift, ast.RShift, ast.BitOr,
                 ast.BitXor, ast.BitAnd, ast.FloorDiv, ast.Invert, ast.Not,
                 ast.UAdd, ast.USub, ast.Eq, ast.NotEq, ast.Lt, ast.LtE,
                 ast.Gt, ast.GtE, ast.Is, ast.IsNot, ast.In, ast.NotIn,

                 ast.alias, ast.keyword, ast.arguments, ast.arg, ast.comprehension,
                 ast.ExceptHandler, ast.withitem
                 ]
        if (type(a) not in types) or (type(b) not in types):
            log("astTools\tcompareASTs\tmissing type:" + str(type(a)) + "," + str(type(b)), "bug")
            return 0
        return types.index(type(a)) - types.index(type(b))

    # Then, more complex expressions- but don't bother with this if we're just checking equality
    if not checkEquality:
        ad = depthOfAST(a)
        bd = depthOfAST(b)
        if ad != bd:
            return bd - ad

    # NameConstants are special
    if type(a) == ast.NameConstant:
        if a.value == None or b.value == None:
            return 1 if a.value != None else (0 if b.value == None else -1)  # short and works

        if a.value in [True, False] or b.value in [True, False]:
            return 1 if a.value not in [True, False] else (cmp(a.value, b.value) if b.value in [True, False] else -1)

    if type(a) == ast.Name:
        return cmp(a.id, b.id)

    # Operations and attributes are all ok
    elif type(a) in [ast.And, ast.Or, ast.Add, ast.Sub, ast.Mult, ast.Div,
                     ast.Mod, ast.Pow, ast.LShift, ast.RShift, ast.BitOr,
                     ast.BitXor, ast.BitAnd, ast.FloorDiv, ast.Invert,
                     ast.Not, ast.UAdd, ast.USub, ast.Eq, ast.NotEq, ast.Lt,
                     ast.LtE, ast.Gt, ast.GtE, ast.Is, ast.IsNot, ast.In,
                     ast.NotIn, ast.Load, ast.Store, ast.Del, ast.AugLoad,
                     ast.AugStore, ast.Param, ast.Ellipsis, ast.Pass,
                     ast.Break, ast.Continue
                     ]:
        return 0

    # Now compare based on the attributes in the identical types
    for attr in a._fields:
        r = compareASTs(getattr(a, attr), getattr(b, attr), checkEquality=checkEquality)
        if r != 0:
            return r
    # If all attributes are identical, they're equal
    return 0
+
def deepcopyList(l):
    """Deep-copy a list of ASTs using this module's deepcopy.

    None is passed through, a single AST node is handed to deepcopy, and
    any other non-list value is logged as a bug and copied with
    copy.deepcopy.
    """
    if l is None:
        return None
    if isinstance(l, ast.AST):
        return deepcopy(l)
    if type(l) != list:
        log("astTools\tdeepcopyList\tNot a list: " + str(type(l)), "bug")
        return copy.deepcopy(l)
    return [deepcopy(item) for item in l]
+
def deepcopy(a):
    """Copy an AST value by hand, avoiding the generic copy machinery where possible.

    None, ints, floats, strings and bools are returned unchanged; lists are
    delegated to deepcopyList.  Anything else that is not an AST node is
    logged as a bug and copied with copy.deepcopy.

    Operator and expression-context nodes are returned as the same object.
    Every other recognised node type is rebuilt from copies of its fields,
    then transferMetaData is used to carry the original's metadata over to
    the copy.  Unrecognised node types are logged and copied with
    copy.deepcopy before the metadata transfer.
    """
    if a is None:
        return None

    kind = type(a)
    if kind is list:
        return deepcopyList(a)
    if kind in (int, float, str, bool):
        return a
    if not isinstance(a, ast.AST):
        log("astTools\tdeepcopy\tNot an AST: " + str(type(a)), "bug")
        return copy.deepcopy(a)

    # Objects without lineno, col_offset: handed back as-is, not copied
    if kind in (ast.And, ast.Or, ast.Add, ast.Sub, ast.Mult, ast.Div,
                ast.Mod, ast.Pow, ast.LShift, ast.RShift, ast.BitOr,
                ast.BitXor, ast.BitAnd, ast.FloorDiv, ast.Invert,
                ast.Not, ast.UAdd, ast.USub, ast.Eq, ast.NotEq, ast.Lt,
                ast.LtE, ast.Gt, ast.GtE, ast.Is, ast.IsNot, ast.In,
                ast.NotIn, ast.Load, ast.Store, ast.Del, ast.AugLoad,
                ast.AugStore, ast.Param):
        return a

    # Short local aliases for the two recursive helpers
    one = deepcopy
    many = deepcopyList

    # --- module-level containers ---
    if kind is ast.Module:
        cp = ast.Module(many(a.body))
    elif kind is ast.Interactive:
        cp = ast.Interactive(many(a.body))
    elif kind is ast.Expression:
        cp = ast.Expression(one(a.body))
    elif kind is ast.Suite:
        cp = ast.Suite(many(a.body))

    # --- statements ---
    elif kind is ast.FunctionDef:
        cp = ast.FunctionDef(a.name, one(a.args), many(a.body),
                             many(a.decorator_list), one(a.returns))
    elif kind is ast.ClassDef:
        cp = ast.ClassDef(a.name, many(a.bases), many(a.keywords),
                          many(a.body), many(a.decorator_list))
    elif kind is ast.Return:
        cp = ast.Return(one(a.value))
    elif kind is ast.Delete:
        cp = ast.Delete(many(a.targets))
    elif kind is ast.Assign:
        cp = ast.Assign(many(a.targets), one(a.value))
    elif kind is ast.AugAssign:
        cp = ast.AugAssign(one(a.target), one(a.op), one(a.value))
    elif kind is ast.For:
        cp = ast.For(one(a.target), one(a.iter), many(a.body), many(a.orelse))
    elif kind is ast.While:
        cp = ast.While(one(a.test), many(a.body), many(a.orelse))
    elif kind is ast.If:
        cp = ast.If(one(a.test), many(a.body), many(a.orelse))
    elif kind is ast.With:
        cp = ast.With(many(a.items), many(a.body))
    elif kind is ast.Raise:
        cp = ast.Raise(one(a.exc), one(a.cause))
    elif kind is ast.Try:
        cp = ast.Try(many(a.body), many(a.handlers),
                     many(a.orelse), many(a.finalbody))
    elif kind is ast.Assert:
        cp = ast.Assert(one(a.test), one(a.msg))
    elif kind is ast.Import:
        cp = ast.Import(many(a.names))
    elif kind is ast.ImportFrom:
        cp = ast.ImportFrom(a.module, many(a.names), a.level)
    elif kind is ast.Global:
        cp = ast.Global(a.names[:])
    elif kind is ast.Expr:
        cp = ast.Expr(one(a.value))
    elif kind is ast.Pass:
        cp = ast.Pass()
    elif kind is ast.Break:
        cp = ast.Break()
    elif kind is ast.Continue:
        cp = ast.Continue()

    # --- expressions (operator nodes are shared, not copied) ---
    elif kind is ast.BoolOp:
        cp = ast.BoolOp(a.op, many(a.values))
    elif kind is ast.BinOp:
        cp = ast.BinOp(one(a.left), a.op, one(a.right))
    elif kind is ast.UnaryOp:
        cp = ast.UnaryOp(a.op, one(a.operand))
    elif kind is ast.Lambda:
        cp = ast.Lambda(one(a.args), one(a.body))
    elif kind is ast.IfExp:
        cp = ast.IfExp(one(a.test), one(a.body), one(a.orelse))
    elif kind is ast.Dict:
        cp = ast.Dict(many(a.keys), many(a.values))
    elif kind is ast.Set:
        cp = ast.Set(many(a.elts))
    elif kind is ast.ListComp:
        cp = ast.ListComp(one(a.elt), many(a.generators))
    elif kind is ast.SetComp:
        cp = ast.SetComp(one(a.elt), many(a.generators))
    elif kind is ast.DictComp:
        cp = ast.DictComp(one(a.key), one(a.value), many(a.generators))
    elif kind is ast.GeneratorExp:
        cp = ast.GeneratorExp(one(a.elt), many(a.generators))
    elif kind is ast.Yield:
        cp = ast.Yield(one(a.value))
    elif kind is ast.Compare:
        cp = ast.Compare(one(a.left), a.ops[:], many(a.comparators))
    elif kind is ast.Call:
        cp = ast.Call(one(a.func), many(a.args), many(a.keywords))
    elif kind is ast.Num:
        cp = ast.Num(a.n)
    elif kind is ast.Str:
        cp = ast.Str(a.s)
    elif kind is ast.Bytes:
        cp = ast.Bytes(a.s)
    elif kind is ast.NameConstant:
        cp = ast.NameConstant(a.value)
    elif kind is ast.Attribute:
        cp = ast.Attribute(one(a.value), a.attr, a.ctx)
    elif kind is ast.Subscript:
        cp = ast.Subscript(one(a.value), one(a.slice), a.ctx)
    elif kind is ast.Name:
        cp = ast.Name(a.id, a.ctx)
    elif kind is ast.List:
        cp = ast.List(many(a.elts), a.ctx)
    elif kind is ast.Tuple:
        cp = ast.Tuple(many(a.elts), a.ctx)
    elif kind is ast.Starred:
        cp = ast.Starred(one(a.value), a.ctx)

    # --- slices ---
    elif kind is ast.Slice:
        cp = ast.Slice(one(a.lower), one(a.upper), one(a.step))
    elif kind is ast.ExtSlice:
        cp = ast.ExtSlice(many(a.dims))
    elif kind is ast.Index:
        cp = ast.Index(one(a.value))

    # --- helper nodes ---
    elif kind is ast.comprehension:
        cp = ast.comprehension(one(a.target), one(a.iter), many(a.ifs))
    elif kind is ast.ExceptHandler:
        cp = ast.ExceptHandler(one(a.type), a.name, many(a.body))
    elif kind is ast.arguments:
        cp = ast.arguments(many(a.args), one(a.vararg),
                           many(a.kwonlyargs), many(a.kw_defaults),
                           one(a.kwarg), many(a.defaults))
    elif kind is ast.arg:
        cp = ast.arg(a.arg, one(a.annotation))
    elif kind is ast.keyword:
        cp = ast.keyword(a.arg, one(a.value))
    elif kind is ast.alias:
        cp = ast.alias(a.name, a.asname)
    elif kind is ast.withitem:
        cp = ast.withitem(one(a.context_expr), one(a.optional_vars))
    else:
        log("astTools\tdeepcopy\tNot implemented: " + str(type(a)), "bug")
        cp = copy.deepcopy(a)

    transferMetaData(a, cp)
    return cp
+
def exportToJson(a):
    """Export the ast to json format.

    Primitive values are rendered as JSON scalars: None becomes null,
    booleans become true/false, ints and floats use str(), and strings are
    quoted and escaped with json.dumps so that embedded quotes, backslashes
    and control characters still yield valid JSON.

    An AST node whose type is a key of astNames is rendered as
    {"<astNames entry>": {"<field>": <value>, ...}}, where list-valued fields
    become JSON arrays and every value is exported recursively.

    Any other value is logged as a bug and rendered as an empty object.
    """
    import json  # local import: only needed here, for string escaping

    if a is None:
        return "null"
    if isinstance(a, bool):
        # bool values (e.g. NameConstant.value) are valid JSON scalars
        return "true" if a else "false"
    if type(a) in (int, float):
        return str(a)
    if type(a) == str:
        return json.dumps(a, ensure_ascii=False)
    if not isinstance(a, ast.AST):
        log("astTools\texportToJson\tMissing type: " + str(type(a)), "bug")

    if type(a) not in astNames:
        log("astTools\texportToJson\tMissing AST type: " + str(type(a)), "bug")
        return "{\n}"

    rendered_fields = []
    for field in a._fields:
        value = getattr(a, field)
        if type(value) == list:
            rendered = "[" + ", ".join(exportToJson(item) for item in value) + "]"
        else:
            rendered = exportToJson(value)
        rendered_fields.append('"' + field + '": ' + rendered)

    return ("{\n" + '"' + astNames[type(a)] + '": {\n'
            + ", ".join(rendered_fields) + "}" + "}")
+
+### ITAP/Canonicalization Functions ###
+
def isTokenStepString(s):
    """Report whether s is a placeholder string.

    A placeholder is at least two characters long and both starts and ends
    with a tilde, e.g. "~op~".
    """
    return len(s) >= 2 and s[0] == "~" and s[-1] == "~"
+
def getParentFunction(s):
    """Return the part of s after its first underscore.

    None is returned when s contains no underscore, or when the remainder is
    one of the reserved suffixes "newvar" or "global".
    """
    _, separator, remainder = s.partition("_")
    if not separator:
        return None
    if remainder in ("newvar", "global"):
        return None
    return remainder
+
def isAnonVariable(s):
    """Specify whether the given string is an anonymized variable name.

    Only the text before the first underscore is inspected: it must be one
    of the letters g, p, v, r, n or z followed by one or more digits.
    """
    prefix = s.partition("_")[0]  # the part before the function name
    if len(prefix) <= 1:
        return False
    tag, counter = prefix[0], prefix[1:]
    return tag in ("g", "p", "v", "r", "n", "z") and counter.isdigit()
+
def isDefault(a):
    """Detect the default starter program (our programs default to return 42).

    True is returned only for a Module holding exactly one FunctionDef whose
    body is empty, or is a single return statement that returns nothing or
    the number 42.  Everything else yields False.
    """
    if not (type(a) == ast.Module and len(a.body) == 1):
        return False

    function = a.body[0]
    if type(function) != ast.FunctionDef:
        return False

    statements = function.body
    if len(statements) == 0:
        return True
    if len(statements) != 1 or type(statements[0]) != ast.Return:
        return False

    returned = statements[0].value
    if returned is None:
        return True
    return type(returned) == ast.Num and returned.n == 42
+
def transferMetaData(a, b):
    """Copy every tracked metadata attribute that exists on a over to b.

    Attributes missing from a are skipped; attributes outside the tracked
    list are never copied.  b is modified in place and nothing is returned.
    """
    tracked = (
        "global_id", "second_global_id", "lineno", "col_offset",
        "originalId", "varID", "variableGlobalId",
        "randomVar", "propagatedVariable", "loadedVariable", "dontChangeName",
        "reversed", "negated", "inverted",
        "augAssignVal", "augAssignBinOp",
        "combinedConditional", "combinedConditionalOp",
        "multiComp", "multiCompPart", "multiCompMiddle", "multiCompOp",
        "addedNot", "addedNotOp", "addedOther", "addedOtherOp", "addedNeg",
        "collapsedExpr", "removedLines",
        "helperVar", "helperReturnVal", "helperParamAssign", "helperReturnAssign",
        "orderedBinOp", "typeCastFunction", "moved_line",
    )
    absent = object()  # sentinel: distinguishes "not set" from a stored None
    for name in tracked:
        value = getattr(a, name, absent)
        if value is not absent:
            setattr(b, name, value)
+
def assignPropertyToAll(a, prop):
    """Set the attribute named prop to True on a and every node beneath it.

    Lists are processed element by element; values that are neither lists
    nor AST nodes are ignored.
    """
    if type(a) == list:
        for element in a:
            assignPropertyToAll(element, prop)
        return
    if not isinstance(a, ast.AST):
        return
    for descendant in ast.walk(a):
        setattr(descendant, prop, True)
+
def removePropertyFromAll(a, prop):
    """Delete the attribute named prop from a and every node beneath it.

    Lists are processed element by element; nodes that do not carry the
    attribute are left alone, as are values that are neither lists nor AST
    nodes.
    """
    if type(a) == list:
        for element in a:
            removePropertyFromAll(element, prop)
        return
    if not isinstance(a, ast.AST):
        return
    # filter() is lazy, so each node is checked right before it is cleaned
    for tagged in filter(lambda node: hasattr(node, prop), ast.walk(a)):
        delattr(tagged, prop)
+
def containsTokenStepString(a):
    """Report whether any string node inside a is a placeholder string.

    This is used to keep token-level hint chaining from breaking.  Values
    that are not AST nodes never contain a placeholder.
    """
    if not isinstance(a, ast.AST):
        return False
    return any(type(node) == ast.Str and isTokenStepString(node.s)
               for node in ast.walk(a))
+
def applyVariableMap(a, variableMap):
    """Rename identifiers in a according to variableMap.

    Name nodes are renamed through their id, and function and class
    definitions through their name, whenever the current identifier is a key
    of variableMap.  The node is changed in place, and the same renaming is
    then handed to applyToChildren, whose result is returned.  Values that
    are not AST nodes are returned untouched.
    """
    if not isinstance(a, ast.AST):
        return a

    if type(a) == ast.Name:
        field = "id"
    elif type(a) in [ast.FunctionDef, ast.ClassDef]:
        field = "name"
    else:
        field = None

    if field is not None:
        current = getattr(a, field)
        if current in variableMap:
            setattr(a, field, variableMap[current])

    def renameChild(child):
        return applyVariableMap(child, variableMap)

    return applyToChildren(a, renameChild)
+
def applyHelperMap(a, helperMap):
    """Rename helper-function references in a according to helperMap.

    Name nodes are renamed through their id and function definitions through
    their name whenever the current identifier is a key of helperMap.  The
    node is changed in place, and the same renaming is then handed to
    applyToChildren, whose result is returned.  Values that are not AST nodes
    are returned untouched.
    """
    if not isinstance(a, ast.AST):
        return a

    if type(a) == ast.Name:
        field = "id"
    elif type(a) == ast.FunctionDef:
        field = "name"
    else:
        field = None

    if field is not None:
        current = getattr(a, field)
        if current in helperMap:
            setattr(a, field, helperMap[current])

    def renameChild(child):
        return applyHelperMap(child, helperMap)

    return applyToChildren(a, renameChild)
+
+
def astFormat(x, gid=None):
    """Given a value, turn it into an AST if it's a constant; otherwise, leave it alone.

    Numbers, booleans, None, strings and bytes become the matching constant
    node.  Lists, tuples, dicts, sets and slices are converted recursively;
    an empty set becomes a call to set() because there is no literal for it.
    A class from the small table of built-in types becomes a Name node that
    loads the type's name.  AST nodes are returned unchanged.

    Anything else, including a class that is not in the table, is logged as
    a bug and None is returned.

    gid is accepted for interface compatibility; this function does not
    read it.
    """
    if type(x) in (int, float, complex):
        return ast.Num(x)
    if type(x) == bool or x is None:
        return ast.NameConstant(x)
    if type(x) == type:
        # Python 3 has no separate `unicode` builtin, so it is not listed here
        type_names = {bool: "bool", int: "int", float: "float",
                      complex: "complex", str: "str", bytes: "bytes",
                      list: "list", tuple: "tuple", dict: "dict"}
        if x in type_names:
            return ast.Name(type_names[x], ast.Load())
        # unknown classes fall through to the logging fallback below
    elif type(x) == str:
        return ast.Str(x)
    elif type(x) == bytes:
        return ast.Bytes(x)
    elif type(x) == list:
        return ast.List([astFormat(val) for val in x], ast.Load())
    elif type(x) == dict:
        keys = []
        vals = []
        for key, val in x.items():
            keys.append(astFormat(key))
            vals.append(astFormat(val))
        return ast.Dict(keys, vals)
    elif type(x) == tuple:
        return ast.Tuple([astFormat(val) for val in x], ast.Load())
    elif type(x) == set:
        if len(x) == 0:  # needs to be a call instead
            return ast.Call(ast.Name("set", ast.Load()), [], [])
        return ast.Set([astFormat(val) for val in x])
    elif type(x) == slice:
        return ast.Slice(astFormat(x.start), astFormat(x.stop), astFormat(x.step))
    elif isinstance(x, ast.AST):
        return x  # Do not change if it's not constant!

    log("astTools\tastFormat\t" + str(type(x)) + "," + str(x), "bug")
    return None
+
def basicFormat(x):
    """Given an AST, turn it into its value if it's constant; otherwise, leave it alone.

    Num, NameConstant, Str and Bytes nodes are unwrapped to the Python value
    they hold; every other input is returned as it was given.
    """
    # (node class name, attribute holding the constant), checked in order
    unwrappers = (("Num", "n"), ("NameConstant", "value"),
                  ("Str", "s"), ("Bytes", "s"))
    for className, attribute in unwrappers:
        if type(x) == getattr(ast, className):
            return getattr(x, attribute)
    return x  # Do not change if it's not a constant!
+
def structureTree(a):
    """Replace the concrete details of an AST with "~...~" placeholders, in place.

    Lists are rewritten element by element and returned; values that are not
    AST nodes are returned unchanged.  For AST nodes, identifiers are
    overwritten with placeholder strings, operators and literal/container
    nodes are swapped for placeholder Str nodes, and remaining children are
    processed recursively.  The function of a Call and the attribute name of
    an Attribute are left as they are.  Node types without a dedicated rule
    have every field processed recursively.

    The (possibly replaced) node is returned, so callers must use the return
    value rather than rely on in-place mutation alone.
    """
    if type(a) == list:
        for position, item in enumerate(a):
            a[position] = structureTree(item)
        return a
    if not isinstance(a, ast.AST):
        return a

    def mark(label):
        """Build the placeholder node for the given label."""
        return ast.Str(label)

    kind = type(a)
    if kind is ast.FunctionDef:
        a.name = "~name~"
        a.args = structureTree(a.args)
        a.body = structureTree(a.body)
        a.decorator_list = structureTree(a.decorator_list)
        a.returns = structureTree(a.returns)
    elif kind is ast.ClassDef:
        a.name = "~name~"
        a.bases = structureTree(a.bases)
        a.keywords = structureTree(a.keywords)
        a.body = structureTree(a.body)
        a.decorator_list = structureTree(a.decorator_list)
    elif kind is ast.AugAssign:
        a.target = structureTree(a.target)
        a.op = mark("~op~")
        a.value = structureTree(a.value)
    elif kind is ast.Import:
        a.names = [mark("~module~")]
    elif kind is ast.ImportFrom:
        a.module = "~module~"
        a.names = [mark("~names~")]
    elif kind is ast.Global:
        # note: a single node, not a list, is stored here
        a.names = mark("~var~")
    elif kind is ast.BoolOp:
        a.op = mark("~op~")
        a.values = structureTree(a.values)
    elif kind is ast.BinOp:
        a.op = mark("~op~")
        a.left = structureTree(a.left)
        a.right = structureTree(a.right)
    elif kind is ast.UnaryOp:
        a.op = mark("~op~")
        a.operand = structureTree(a.operand)
    elif kind is ast.Dict:
        return mark("~dictionary~")
    elif kind is ast.Set:
        return mark("~set~")
    elif kind is ast.Compare:
        # one placeholder object, repeated once per original operator
        shared = mark("~op~")
        a.ops = [shared for _ in range(len(a.ops))]
        a.left = structureTree(a.left)
        a.comparators = structureTree(a.comparators)
    elif kind is ast.Call:
        # leave the function alone
        a.args = structureTree(a.args)
        a.keywords = structureTree(a.keywords)
    elif kind is ast.Num:
        return mark("~number~")
    elif kind is ast.Str:
        return mark("~string~")
    elif kind is ast.Bytes:
        return mark("~bytes~")
    elif kind is ast.Attribute:
        a.value = structureTree(a.value)
    elif kind is ast.Name:
        a.id = "~var~"
    elif kind is ast.List:
        return mark("~list~")
    elif kind is ast.Tuple:
        return mark("~tuple~")
    elif kind in (ast.And, ast.Or, ast.Add, ast.Sub, ast.Mult, ast.Div,
                  ast.Mod, ast.Pow, ast.LShift, ast.RShift, ast.BitOr,
                  ast.BitXor, ast.BitAnd, ast.FloorDiv, ast.Invert,
                  ast.Not, ast.UAdd, ast.USub, ast.Eq, ast.NotEq,
                  ast.Lt, ast.LtE, ast.Gt, ast.GtE, ast.Is, ast.IsNot,
                  ast.In, ast.NotIn):
        return mark("~op~")
    elif kind is ast.arguments:
        a.args = structureTree(a.args)
        a.vararg = mark("~arg~") if a.vararg is not None else None
        a.kwonlyargs = structureTree(a.kwonlyargs)
        a.kw_defaults = structureTree(a.kw_defaults)
        a.kwarg = mark("~keyword~") if a.kwarg is not None else None
        a.defaults = structureTree(a.defaults)
    elif kind is ast.arg:
        a.arg = "~arg~"
        a.annotation = structureTree(a.annotation)
    elif kind is ast.keyword:
        a.arg = "~keyword~"
        a.value = structureTree(a.value)
    elif kind is ast.alias:
        a.name = "~name~"
        a.asname = "~asname~" if a.asname is not None else None
    else:
        # no dedicated rule: recurse into every declared field
        for field in a._fields:
            setattr(a, field, structureTree(getattr(a, field)))
    return a
+
+
+