"""Induce AST patterns from Python programs.

Programs are parsed into ``Tree`` objects (see ``make_tree``); patterns are
smaller trees induced from selected nodes of those trees (see
``make_pattern``), and ``get_attributes`` collects the patterns that occur
often enough in a set of programs.
"""

import ast
import collections
from itertools import chain, combinations, product
import re
import sys

from .asttokens import asttokens
from .tree import Tree


def make_tree(code):
    """Parse *code* and convert its Python AST into a ``Tree``.

    Every AST node becomes a tree labelled with the node's class name.  Each
    of its non-``None`` fields becomes a child labelled with the field name,
    whose own children are:

    * a single leaf labelled with the class name, for an AST node that has no
      fields of its own;
    * the converted node, for any other AST node;
    * the converted elements, for a list;
    * a leaf labelled ``str(value)``, for any other value.

    Trees carry a ``data`` dict with the keys ``'start'`` and ``'end'``
    (character offsets taken from the node's first/last token) and ``'ast'``
    (the AST node those offsets belong to).  A value without token
    information inherits the data of its nearest ancestor that has it.
    """

    def span_of(node):
        # Position data for a node that asttokens has marked with tokens.
        return {
            'start': node.first_token.startpos,
            'end': node.last_token.endpos,
            'ast': node,
        }

    def make(node, inherited=None):
        # `inherited` is the data dict of the nearest enclosing AST node.
        # Passing the resolved dict (rather than the parent node itself)
        # keeps this working when the immediate parent has no token
        # information either; reading `parent.first_token` would raise
        # AttributeError in that case.
        # NOTE(review): some asttokens versions reportedly leave nodes inside
        # f-strings unmarked, which would produce such nesting -- confirm
        # against the vendored copy.
        if hasattr(node, 'first_token'):
            data = span_of(node)
        elif inherited is not None:
            data = dict(inherited)
        else:
            # Unmarked root: fall back to the span of the whole source.
            data = {'start': 0, 'end': len(code), 'ast': node}

        # Plain values (identifiers, numbers, strings, ...) become leaves.
        if not hasattr(node, '_fields'):
            return Tree(str(node), data=data)

        children = []
        for field, value in ast.iter_fields(node):
            if value is None:
                continue
            if isinstance(value, ast.AST):
                if not value._fields:
                    # Field-less nodes are represented by their class name
                    # only.
                    child = Tree(field, [Tree(value.__class__.__name__)],
                                 data=data)
                else:
                    child = Tree(field, [make(value, data)], data=data)
            elif isinstance(value, list):
                child = Tree(field, [make(v, data) for v in value], data=data)
            else:
                child = Tree(field, [make(value, data)], data=data)
            children.append(child)
        return Tree(node.__class__.__name__, children, data=data)

    pytree = asttokens.ASTTokens(code, parse=True).tree
    return make(pytree)


def stringify(tree):
    """Return the representation of *tree* on a single line.

    Every newline in ``repr(tree)``, together with the spaces that follow it,
    is replaced by one space.
    """
    return re.sub('\n *', ' ', repr(tree))


def make_pattern(tree, nodes):
    """Induce a pattern from the specified *nodes* in *tree*.

    Selected nodes are matched by identity (``is``).  The result is a new
    ``Tree`` containing the selected nodes and the path from the root to
    them, plus some extra context depending on the node type (see the
    comments below).  Returns ``None`` if nothing in *tree* matches.

    A childless ``Return`` node always yields ``(Return $)``, whether or not
    it was selected.
    """
    label = tree.label

    # Childless node (this relies on such trees being falsy).
    if not tree:
        if label == 'Return':
            return Tree(label, [Tree('$')])
        return None

    if any(n is tree for n in nodes):
        # remove variable names
        if label == 'Name':
            return Tree(label, [Tree('id'), tree[1]])
        elif label == 'arg':
            return Tree(label, [Tree('arg')])
        else:
            return tree

    subpats = [make_pattern(child, nodes) for child in tree]
    if not any(subpats):
        return None

    matched = [s for s in subpats if s is not None]

    # extra nodes to include in the pattern
    if label in ('Call', 'FunctionDef'):
        # store function names
        pat = Tree(label, [tree[0]] + matched)
    elif label == 'args':
        # (args arg1 … argN)
        # store position of matching argument, by including (only) the
        # root node of all nonmatching arguments
        pat = Tree(label, [s if s is not None else Tree(tree[i].label)
                           for i, s in enumerate(subpats)])
    elif label == 'UnaryOp':
        # store operator
        pat = Tree(label, [tree[0], subpats[1]])
    elif label == 'BinOp' or label == 'Compare':
        # store operator and left and/or right side
        children = []
        if subpats[0]:
            children.append(subpats[0])
        children.append(tree[1])
        if subpats[2]:
            children.append(subpats[2])
        pat = Tree(label, children)
    else:
        # store children if they match some part of the pattern
        pat = Tree(label, matched)
    return pat


def get_variables(tree):
    """Map each variable name in *tree* to the nodes where it occurs.

    Collected are all non-empty ``Name`` nodes whose parent is not labelled
    ``func``, and the non-empty arguments of every ``FunctionDef``.  Returns
    a ``defaultdict(list)`` keyed by name.
    """
    variables = collections.defaultdict(list)
    for node, parent in tree.subtrees_with_parents():
        if not node:
            continue
        if node.label == 'Name' and parent.label != 'func':
            # (Name (id <name>) ...)
            name = node[0][0].label
            variables[name].append(node)
        elif node.label == 'FunctionDef':
            # (FunctionDef name (args (arguments (args arg0 arg1 … argN))))
            for arg in node[1][0][0]:
                if arg:
                    name = arg[0][0].label
                    variables[name].append(arg)
    return variables


# Node labels treated as literals when generating patterns.
literals = {'Num', 'Str', 'NameConstant'}

# Statements and expressions within which literal patterns are generated.
short_range = {
    'Return', 'Delete', 'Assign', 'AugAssign', 'Raise', 'BoolOp', 'BinOp',
    'UnaryOp', 'Lambda', 'Call', 'Dict', 'Set', 'List'
}


def get_patterns(code):
    """Generate ``(pattern, selected_nodes)`` pairs for the program *code*.

    Patterns are yielded in this order:

    1. one for every pair of occurrences of the same variable;
    2. one for every ``Return`` node;
    3. one for every literal-literal and literal-variable pair found inside
       the same ``short_range`` statement or expression.

    Selections for which ``make_pattern`` returns a falsy result are skipped.
    """
    tree = make_tree(code)
    variables = get_variables(tree)

    for nodes in variables.values():
        for selected in combinations(nodes, 2):
            pat = make_pattern(tree, selected)
            if pat:
                yield pat, selected

    # Variable occurrences are recognised by identity.  The nodes stay alive
    # in `variables`, so their ids are stable and unique, and the set lookup
    # replaces a scan over all occurrences for every visited node.
    variable_ids = {id(n) for occurrences in variables.values()
                    for n in occurrences}

    # yield patterns for variable-literal / literal-literal pairs
    # (only within certain statements and expressions); patterns for
    # singleton literals are not generated
    def patterns_with_literals(node):
        if not node:
            return
        if node.label in short_range:
            var_nodes = [n for n in node.subtrees() if id(n) in variable_ids]
            lits = [n for n in node.subtrees() if n.label in literals]
            for selected in chain(
                    combinations(lits, 2),
                    product(lits, var_nodes)):
                pat = make_pattern(tree, selected)
                if pat:
                    yield pat, selected
        else:
            # Keep descending until a short-range node is found.
            for child in node:
                yield from patterns_with_literals(child)

    # extra nodes for Python
    for n in tree.subtrees():
        # empty return
        if n.label == 'Return':
            pat = make_pattern(tree, [n])
            if pat:
                yield pat, [n]

    yield from patterns_with_literals(tree)


def get_attributes(programs, min_programs=5):
    """Extract attributes (frequent patterns) from training data.

    Args:
        programs: iterable of program sources.
        min_programs: a pattern is kept only if it was collected at least
            this many times.

    Returns:
        An ``OrderedDict`` mapping ``'regex-N'`` (numbered from 0, most
        frequent pattern first) to ``{'desc': pattern, 'programs': [...]}``.
        A program is listed once for every time ``get_patterns`` yields the
        pattern for it, so the threshold counts occurrences rather than
        distinct programs.
    """
    patterns = collections.defaultdict(list)
    for code in programs:
        for pat, _ in get_patterns(code):
            patterns[pat].append(code)

    attrs = collections.OrderedDict()
    for pat, progs in sorted(patterns.items(), key=lambda x: len(x[1]),
                             reverse=True):
        # Sorted by decreasing frequency, so all later patterns are rarer.
        if len(progs) < min_programs:
            break
        attrs['regex-{}'.format(len(attrs))] = {'desc': pat, 'programs': progs}

    return attrs