import ast
import collections
from itertools import chain, combinations, product
import re
import sys

from .asttokens import asttokens
from .tree import Tree


def make_tree(code):
    """Parse *code* and convert its AST into a ``Tree``.

    Every AST node becomes a ``Tree`` labeled with the node's class name, and
    each of its fields becomes a child ``Tree`` labeled with the field name.
    Fields whose value is ``None`` are omitted. Each ``Tree`` carries ``data``
    with the source span (``'start'``/``'end'`` character offsets taken from
    asttokens' ``first_token``/``last_token``) and the AST node (``'ast'``).
    Values without token information (e.g. identifier strings) borrow the
    span and AST node of their parent.
    """
    def make(node, parent=None):
        if hasattr(node, 'first_token'):
            data = {'start': node.first_token.startpos,
                    'end': node.last_token.endpos,
                    'ast': node}
        else:
            data = {'start': parent.first_token.startpos,
                    'end': parent.last_token.endpos,
                    'ast': parent}

        # Plain values (identifier strings, numbers, ...) become leaves
        # labeled with their str() representation.
        if not hasattr(node, '_fields'):
            return Tree(str(node), data=data)

        children = []
        for field, value in ast.iter_fields(node):
            if value is None:
                continue
            if isinstance(value, ast.AST):
                if not value._fields:
                    # Field-less nodes (e.g. expression contexts, operators)
                    # are stored by class name only, e.g. (ctx Load).
                    child = Tree(field, [Tree(value.__class__.__name__)], data=data)
                else:
                    child = Tree(field, [make(value, node)], data=data)
            elif isinstance(value, list):
                child = Tree(field, [make(v, node) for v in value], data=data)
            else:
                child = Tree(field, [make(value, node)], data=data)
            children.append(child)

        return Tree(node.__class__.__name__, children, data=data)

    pytree = asttokens.ASTTokens(code, parse=True).tree
    return make(pytree)


def make_pattern(tree, selected):
    """Induce a pattern from the *selected* nodes in *tree*.

    Returns a ``Tree`` holding the path from the root of *tree* to every
    selected node, plus some context nodes that make the pattern more
    specific (function/attribute/method names, operators, argument
    positions). Selected nodes are abstracted: variable names become
    ``VAR``, string values and call arguments are dropped.

    Nodes in *selected* are matched by identity (``is``), not equality.
    Returns ``None`` if no selected node occurs in *tree*.
    """
    label = tree.label

    # Found a selected node: abstract it and do not recurse further.
    if any(n is tree for n in selected):
        if label == 'Name':
            # Replace the variable name with 'VAR', keep the context node.
            return Tree(label, [Tree('id', [Tree('VAR')]), tree[1]])
        elif label == 'arg':
            # Replace the function-argument name with 'VAR'.
            return Tree(label, [Tree('arg', [Tree('VAR')])])
        elif label == 'Str':
            # Omit the string value.
            return Tree(label)
        elif label == 'Call':
            # Omit arguments, keep the original function name.
            if tree[0][0].label == 'Attribute':
                # Method call: keep the Attribute's fields after the
                # receiver (the method name and context), drop the receiver.
                return Tree(label, [Tree('func', [Tree('Attribute', tree[0][0][1:])])])
            else:
                return Tree(label, [tree[0]])
        else:
            return tree

    subpats = [make_pattern(child, selected) for child in tree]
    if all(s is None for s in subpats):
        # No selected nodes found in this branch.
        return None

    # At this point one or more children of tree are ancestors of some
    # selected nodes. Depending on the label, include extra context nodes.
    if label == 'FunctionDef':
        # Include the function name.
        pat = Tree(label, [tree[0]] + [s for s in subpats if s is not None])
    elif label == 'Attribute':
        # Include the attribute name (all fields after the value).
        children = []
        if subpats[0] is not None:
            children += [subpats[0]]
        children += tree[1:]
        pat = Tree(label, children)
    elif label == 'Call':
        # Include the function/method name.
        children = []
        if subpats[0] is not None:
            children += [subpats[0]]
        else:
            func = tree[0]
            if func[0].label == 'Attribute':
                # Method call: include the attribute (method) name.
                children += [Tree('func', [Tree('Attribute', func[0][1:])])]
            else:
                # Function call: include the function name.
                children += [func]
        children += [s for s in subpats[1:] if s is not None]
        pat = Tree(label, children)
    elif label == 'args':
        # (args arg1 … argN): include at least an 'arg' node for each
        # argument so that subpatterns found at different argument
        # positions remain distinguishable.
        pat = Tree(label, [s if s is not None else Tree('arg') for s in subpats])
    elif label == 'UnaryOp':
        # Store the operator.
        pat = Tree(label, [tree[0], subpats[1]])
    elif label == 'BinOp' or label == 'Compare':
        # Store the operator(s) and the left and/or right side.
        children = []
        if subpats[0] is not None:
            children += [subpats[0]]
        children += [tree[1]]
        if subpats[2] is not None:
            children += [subpats[2]]
        pat = Tree(label, children)
    else:
        # Store the children that match some part of the pattern.
        pat = Tree(label, [s for s in subpats if s is not None])
    return pat


def _parameter_nodes(arguments):
    """Yield the positional-parameter nodes of an ``arguments`` Tree.

    The parameter lists are looked up by field label rather than position,
    because Python 3.8 added a leading ``posonlyargs`` field to
    ``ast.arguments``; on older versions ``args`` is the first field.
    ``vararg``/``kwarg``/keyword-only parameters are not included.
    """
    for field in arguments:
        if field.label in ('posonlyargs', 'args'):
            yield from field


def get_variables(tree):
    """Group variable nodes in *tree* by variable name.

    Collects ``Name`` nodes (except those directly under a ``func`` field,
    i.e. names of called functions) and the positional parameter nodes of
    ``FunctionDef`` and ``Lambda`` definitions. Returns a
    ``defaultdict(list)`` mapping each name to its nodes, in the order
    produced by ``tree.subtrees_with_parents()``.
    """
    variables = collections.defaultdict(list)
    for node, parent in tree.subtrees_with_parents():
        if node.label == 'Name' and parent.label != 'func':
            name = node[0][0].label
            variables[name].append(node)
        elif node.label == 'FunctionDef':
            # AST form: (FunctionDef (name …) (args (arguments … (args arg0 … argN) …)) …)
            for arg in _parameter_nodes(node[1][0]):
                if arg:
                    name = arg[0][0].label
                    variables[name].append(arg)
        elif node.label == 'Lambda':
            # AST form: (Lambda (args (arguments … (args arg0 … argN) …)) …)
            for arg in _parameter_nodes(node[0][0]):
                if arg:
                    name = arg[0][0].label
                    variables[name].append(arg)
    return variables


def get_literals(tree):
    """Return all literal nodes (``Num``, ``Str``, ``NameConstant``) in *tree*."""
    # NOTE(review): Python 3.8+ parses literals as 'Constant' nodes, which
    # this set does not match — confirm the targeted Python version.
    return [n for n in tree.subtrees() if n.label in {'Num', 'Str', 'NameConstant'}]


def get_calls(tree):
    """Return all ``Call`` nodes in *tree*."""
    return [n for n in tree.subtrees() if n.label == 'Call']


# Statements and expressions within which variable/literal pairs and
# literal subsets are extracted as patterns (see get_patterns).
short_range = {
    'Return', 'Delete', 'Assign', 'AugAssign', 'Raise',
    'BoolOp', 'BinOp', 'UnaryOp', 'Lambda', 'Call',
    'Dict', 'Set', 'List'
}


def get_patterns(code):
    """Yield ``(pattern, nodes)`` pairs describing *code*.

    Patterns are induced (via ``make_pattern``) from, in this order:
    occurrences of the same variable, function calls, literals (and
    literal/variable pairs) inside ``short_range`` nodes, and control
    structures together with their ancestor path.
    """
    tree = make_tree(code)
    variables = get_variables(tree)
    # Identities of all variable nodes; the nodes stay alive inside `tree`,
    # so id() membership is equivalent to an identity (`is`) test.
    variable_ids = {id(n) for nodes in variables.values() for n in nodes}

    def subsets(iterable):
        """Return all non-empty combinations of *iterable*, smallest first."""
        s = list(iterable)
        return chain.from_iterable(combinations(s, r) for r in range(1, len(s)+1))

    def patterns_with_variables():
        """Yield patterns for single-use variables and for pairs of uses."""
        for nodes in variables.values():
            # Singletons.
            if len(nodes) == 1:
                pat = make_pattern(tree, nodes)
                if pat:
                    yield pat, nodes
            # Pairs of nodes with the same variable.
            for selected in combinations(nodes, 2):
                pat = make_pattern(tree, selected)
                if pat:
                    yield pat, selected

    def patterns_with_literals(node):
        """Yield patterns for literal subsets and variable-literal pairs.

        Only considered within the outermost ``short_range`` nodes.
        """
        if not node:
            return
        if node.label in short_range:
            var_nodes = [n for n in node.subtrees() if id(n) in variable_ids]
            literals = get_literals(node)
            for selected in chain(
                    subsets(literals),
                    product(literals, var_nodes)):
                pat = make_pattern(tree, selected)
                if pat:
                    yield pat, selected
        else:
            for child in node:
                yield from patterns_with_literals(child)

    def patterns_with_calls(tree):
        """Yield a pattern for each function call."""
        for node in get_calls(tree):
            pat = make_pattern(tree, [node])
            if pat:
                yield pat, [node]

    def patterns_with_control_structures(node, path=None):
        """Yield each control-structure node wrapped in its ancestor labels."""
        if path is None:
            path = []
        if node.label in ['FunctionDef', 'If', 'Else', 'For', 'While',
                          'Break', 'Continue', 'Pass', 'With', 'Raise',
                          'Try', 'Return']:
            pat = Tree(node.label)
            for parent in reversed(path):
                pat = Tree(parent.label, [pat])
            yield pat, [node]
        # Build a new list for the children: extending `path` in place would
        # leak this subtree's nodes into the ancestor path of later siblings.
        child_path = path + [node]
        for child in node:
            yield from patterns_with_control_structures(child, child_path)

    yield from patterns_with_variables()
    yield from patterns_with_calls(tree)
    yield from patterns_with_literals(tree)
    yield from patterns_with_control_structures(tree)


def get_attributes(programs, min_support=5):
    """Extract frequent patterns from training *programs* as attributes.

    Each pattern is counted at most once per program. Patterns occurring in
    at least *min_support* programs are returned as an ``OrderedDict``
    mapping ``'regex-<i>'`` to ``{'desc': pattern, 'programs': [...]}``,
    ordered by decreasing support.
    """
    patterns = collections.defaultdict(list)
    for code in programs:
        # Deduplicate per program so that support counts programs, not
        # occurrences of a pattern within one program.
        seen = set()
        for pat, _ in get_patterns(code):
            if pat in seen:
                continue
            seen.add(pat)
            patterns[pat].append(code)

    attrs = collections.OrderedDict()
    for pat, progs in sorted(patterns.items(), key=lambda x: len(x[1]), reverse=True):
        if len(progs) < min_support:
            break
        attrs['regex-{}'.format(len(attrs))] = {'desc': pat, 'programs': progs}

    return attrs