diff options
-rw-r--r-- | regex/asttokens/__init__.py | 22 | ||||
-rw-r--r-- | regex/asttokens/asttokens.py | 196 | ||||
-rw-r--r-- | regex/asttokens/line_numbers.py | 71 | ||||
-rw-r--r-- | regex/asttokens/mark_tokens.py | 275 | ||||
-rw-r--r-- | regex/asttokens/util.py | 236 |
5 files changed, 800 insertions, 0 deletions
diff --git a/regex/asttokens/__init__.py b/regex/asttokens/__init__.py new file mode 100644 index 0000000..cde4aab --- /dev/null +++ b/regex/asttokens/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2016 Grist Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module enhances the Python AST tree with token and source code information, sufficent to +detect the source text of each AST node. This is helpful for tools that make source code +transformations. +""" + +from .line_numbers import LineNumbers +from .asttokens import ASTTokens diff --git a/regex/asttokens/asttokens.py b/regex/asttokens/asttokens.py new file mode 100644 index 0000000..130f53d --- /dev/null +++ b/regex/asttokens/asttokens.py @@ -0,0 +1,196 @@ +# Copyright 2016 Grist Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ast +import bisect +import token +import tokenize +import io +import six +from six.moves import xrange # pylint: disable=redefined-builtin +from .line_numbers import LineNumbers +from .util import Token, match_token +from .mark_tokens import MarkTokens + +class ASTTokens(object): + """ + ASTTokens maintains the text of Python code in several forms: as a string, as line numbers, and + as tokens, and is used to mark and access token and position information. + + ``source_text`` must be a unicode or UTF8-encoded string. If you pass in UTF8 bytes, remember + that all offsets you'll get are to the unicode text, which is available as the ``.text`` + property. + + If ``parse`` is set, the ``source_text`` will be parsed with ``ast.parse()``, and the resulting + tree marked with token info and made available as the ``.tree`` property. + + If ``tree`` is given, it will be marked and made available as the ``.tree`` property. In + addition to the trees produced by the ``ast`` module, ASTTokens will also mark trees produced + using ``astroid`` library <https://www.astroid.org>. + + If only ``source_text`` is given, you may use ``.mark_tokens(tree)`` to mark the nodes of an AST + tree created separately. + """ + def __init__(self, source_text, parse=False, tree=None): + if isinstance(source_text, six.binary_type): + source_text = source_text.decode('utf8') + + self._tree = ast.parse(source_text) if parse else tree + + self._text = source_text + self._line_numbers = LineNumbers(source_text) + + # Tokenize the code. + self._tokens = list(self._generate_tokens(source_text)) + + # Extract the start positions of all tokens, so that we can quickly map positions to tokens. + self._token_offsets = [tok.startpos for tok in self._tokens] + + if self._tree: + self.mark_tokens(self._tree) + + + def mark_tokens(self, root_node): + """ + Given the root of the AST or Astroid tree produced from source_text, visits all nodes marking + them with token and position information by adding ``.first_token`` and + ``.last_token``attributes. This is done automatically in the constructor when ``parse`` or + ``tree`` arguments are set, but may be used manually with a separate AST or Astroid tree. + """ + # The hard work of this class is done by MarkTokens + MarkTokens(self).visit_tree(root_node) + + + def _generate_tokens(self, text): + """ + Generates tokens for the given code. + """ + # This is technically an undocumented API for Python3, but allows us to use the same API as for + # Python2. See http://stackoverflow.com/a/4952291/328565. + for index, tok in enumerate(tokenize.generate_tokens(io.StringIO(text).readline)): + tok_type, tok_str, start, end, line = tok + yield Token(tok_type, tok_str, start, end, line, index, + self._line_numbers.line_to_offset(start[0], start[1]), + self._line_numbers.line_to_offset(end[0], end[1])) + + @property + def text(self): + """The source code passed into the constructor.""" + return self._text + + @property + def tokens(self): + """The list of tokens corresponding to the source code from the constructor.""" + return self._tokens + + @property + def tree(self): + """The root of the AST tree passed into the constructor or parsed from the source code.""" + return self._tree + + def get_token_from_offset(self, offset): + """ + Returns the token containing the given character offset (0-based position in source text), + or the preceeding token if the position is between tokens. + """ + return self._tokens[bisect.bisect(self._token_offsets, offset) - 1] + + def get_token(self, lineno, col_offset): + """ + Returns the token containing the given (lineno, col_offset) position, or the preceeding token + if the position is between tokens. + """ + # TODO: add test for multibyte unicode. We need to translate offsets from ast module (which + # are in utf8) to offsets into the unicode text. tokenize module seems to use unicode offsets + # but isn't explicit. + return self.get_token_from_offset(self._line_numbers.line_to_offset(lineno, col_offset)) + + def get_token_from_utf8(self, lineno, col_offset): + """ + Same as get_token(), but interprets col_offset as a UTF8 offset, which is what `ast` uses. + """ + return self.get_token(lineno, self._line_numbers.from_utf8_col(lineno, col_offset)) + + def next_token(self, tok, include_extra=False): + """ + Returns the next token after the given one. If include_extra is True, includes non-coding + tokens from the tokenize module, such as NL and COMMENT. + """ + i = tok.index + 1 + if not include_extra: + while self._tokens[i].type >= token.N_TOKENS: + i += 1 + return self._tokens[i] + + def prev_token(self, tok, include_extra=False): + """ + Returns the previous token before the given one. If include_extra is True, includes non-coding + tokens from the tokenize module, such as NL and COMMENT. + """ + i = tok.index - 1 + if not include_extra: + while self._tokens[i].type >= token.N_TOKENS: + i -= 1 + return self._tokens[i] + + def find_token(self, start_token, tok_type, tok_str=None, reverse=False): + """ + Looks for the first token, starting at start_token, that matches tok_type and, if given, the + token string. Searches backwards if reverse is True. + """ + t = start_token + advance = self.prev_token if reverse else self.next_token + while not match_token(t, tok_type, tok_str) and not token.ISEOF(t.type): + t = advance(t) + return t + + def token_range(self, first_token, last_token, include_extra=False): + """ + Yields all tokens in order from first_token through and including last_token. If + include_extra is True, includes non-coding tokens such as tokenize.NL and .COMMENT. + """ + for i in xrange(first_token.index, last_token.index + 1): + if include_extra or self._tokens[i].type < token.N_TOKENS: + yield self._tokens[i] + + def get_tokens(self, node, include_extra=False): + """ + Yields all tokens making up the given node. If include_extra is True, includes non-coding + tokens such as tokenize.NL and .COMMENT. + """ + return self.token_range(node.first_token, node.last_token, include_extra=include_extra) + + def get_text_range(self, node): + """ + After mark_tokens() has been called, returns the (startpos, endpos) positions in source text + corresponding to the given node. Returns (0, 0) for nodes (like `Load`) that don't correspond + to any particular text. + """ + if not hasattr(node, 'first_token'): + return (0, 0) + + start = node.first_token.startpos + if any(match_token(t, token.NEWLINE) for t in self.get_tokens(node)): + # Multi-line nodes would be invalid unless we keep the indentation of the first node. + start = self._text.rfind('\n', 0, start) + 1 + + return (start, node.last_token.endpos) + + def get_text(self, node): + """ + After mark_tokens() has been called, returns the text corresponding to the given node. Returns + '' for nodes (like `Load`) that don't correspond to any particular text. + """ + start, end = self.get_text_range(node) + return self._text[start : end] diff --git a/regex/asttokens/line_numbers.py b/regex/asttokens/line_numbers.py new file mode 100644 index 0000000..b91b00f --- /dev/null +++ b/regex/asttokens/line_numbers.py @@ -0,0 +1,71 @@ +# Copyright 2016 Grist Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import bisect +import re + +_line_start_re = re.compile(r'^', re.M) + +class LineNumbers(object): + """ + Class to convert between character offsets in a text string, and pairs (line, column) of 1-based + line and 0-based column numbers, as used by tokens and AST nodes. + + This class expects unicode for input and stores positions in unicode. But it supports + translating to and from utf8 offsets, which are used by ast parsing. + """ + def __init__(self, text): + # A list of character offsets of each line's first character. + self._line_offsets = [m.start(0) for m in _line_start_re.finditer(text)] + self._text = text + self._text_len = len(text) + self._utf8_offset_cache = {} # maps line num to list of char offset for each byte in line + + def from_utf8_col(self, line, utf8_column): + """ + Given a 1-based line number and 0-based utf8 column, returns a 0-based unicode column. + """ + offsets = self._utf8_offset_cache.get(line) + if offsets is None: + end_offset = self._line_offsets[line] if line < len(self._line_offsets) else self._text_len + line_text = self._text[self._line_offsets[line - 1] : end_offset] + + offsets = [i for i,c in enumerate(line_text) for byte in c.encode('utf8')] + offsets.append(len(line_text)) + self._utf8_offset_cache[line] = offsets + + return offsets[max(0, min(len(offsets), utf8_column))] + + def line_to_offset(self, line, column): + """ + Converts 1-based line number and 0-based column to 0-based character offset into text. + """ + line -= 1 + if line >= len(self._line_offsets): + return self._text_len + elif line < 0: + return 0 + else: + return min(self._line_offsets[line] + max(0, column), self._text_len) + + def offset_to_line(self, offset): + """ + Converts 0-based character offset to pair (line, col) of 1-based line and 0-based column + numbers. + """ + offset = max(0, min(self._text_len, offset)) + line_index = bisect.bisect_right(self._line_offsets, offset) - 1 + return (line_index + 1, offset - self._line_offsets[line_index]) + + diff --git a/regex/asttokens/mark_tokens.py b/regex/asttokens/mark_tokens.py new file mode 100644 index 0000000..f48fb77 --- /dev/null +++ b/regex/asttokens/mark_tokens.py @@ -0,0 +1,275 @@ +# Copyright 2016 Grist Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import six +import numbers +import token +from . import util + + +# Mapping of matching braces. To find a token here, look up token[:2]. +_matching_pairs_left = { + (token.OP, '('): (token.OP, ')'), + (token.OP, '['): (token.OP, ']'), + (token.OP, '{'): (token.OP, '}'), +} + +_matching_pairs_right = { + (token.OP, ')'): (token.OP, '('), + (token.OP, ']'): (token.OP, '['), + (token.OP, '}'): (token.OP, '{'), +} + + +class MarkTokens(object): + """ + Helper that visits all nodes in the AST tree and assigns .first_token and .last_token attributes + to each of them. This is the heart of the token-marking logic. + """ + def __init__(self, code): + self._code = code + self._methods = util.NodeMethods() + self._iter_children = None + + def visit_tree(self, node): + self._iter_children = util.iter_children_func(node) + util.visit_tree(node, self._visit_before_children, self._visit_after_children) + + def _visit_before_children(self, node, parent_token): + col = getattr(node, 'col_offset', None) + token = self._code.get_token_from_utf8(node.lineno, col) if col is not None else None + + if not token and util.is_module(node): + # We'll assume that a Module node starts at the start of the source code. + token = self._code.get_token(1, 0) + + # Use our own token, or our parent's if we don't have one, to pass to child calls as + # parent_token argument. The second value becomes the token argument of _visit_after_children. + return (token or parent_token, token) + + def _visit_after_children(self, node, parent_token, token): + # This processes the node generically first, after all children have been processed. + + # Get the first and last tokens that belong to children. Note how this doesn't assume that we + # iterate through children in order that corresponds to occurrence in source code. This + # assumption can fail (e.g. with return annotations). + first = token + last = None + for child in self._iter_children(node): + if not first or child.first_token.index < first.index: + first = child.first_token + if not last or child.last_token.index > last.index: + last = child.last_token + + # If we don't have a first token from _visit_before_children, and there were no children, then + # use the parent's token as the first token. + first = first or parent_token + + # If no children, set last token to the first one. + last = last or first + + # Statements continue to before NEWLINE. This helps cover a few different cases at once. + if util.is_stmt(node): + last = self._find_last_in_line(last) + + # Capture any unmatched brackets. + first, last = self._expand_to_matching_pairs(first, last, node) + + # Give a chance to node-specific methods to adjust. + nfirst, nlast = self._methods.get(self, node.__class__)(node, first, last) + + if (nfirst, nlast) != (first, last): + # If anything changed, expand again to capture any unmatched brackets. + nfirst, nlast = self._expand_to_matching_pairs(nfirst, nlast, node) + + node.first_token = nfirst + node.last_token = nlast + + def _find_last_in_line(self, start_token): + try: + newline = self._code.find_token(start_token, token.NEWLINE) + except IndexError: + newline = self._code.find_token(start_token, token.ENDMARKER) + return self._code.prev_token(newline) + + def _iter_non_child_tokens(self, first_token, last_token, node): + """ + Generates all tokens in [first_token, last_token] range that do not belong to any children of + node. E.g. `foo(bar)` has children `foo` and `bar`, but we would yield the `(`. + """ + tok = first_token + for n in self._iter_children(node): + for t in self._code.token_range(tok, self._code.prev_token(n.first_token)): + yield t + if n.last_token.index >= last_token.index: + return + tok = self._code.next_token(n.last_token) + + for t in self._code.token_range(tok, last_token): + yield t + + def _expand_to_matching_pairs(self, first_token, last_token, node): + """ + Scan tokens in [first_token, last_token] range that are between node's children, and for any + unmatched brackets, adjust first/last tokens to include the closing pair. + """ + # We look for opening parens/braces among non-child tokens (i.e. tokens between our actual + # child nodes). If we find any closing ones, we match them to the opens. + to_match_right = [] + to_match_left = [] + for tok in self._iter_non_child_tokens(first_token, last_token, node): + tok_info = tok[:2] + if to_match_right and tok_info == to_match_right[-1]: + to_match_right.pop() + elif tok_info in _matching_pairs_left: + to_match_right.append(_matching_pairs_left[tok_info]) + elif tok_info in _matching_pairs_right: + to_match_left.append(_matching_pairs_right[tok_info]) + + # Once done, extend `last_token` to match any unclosed parens/braces. + for match in reversed(to_match_right): + last = self._code.next_token(last_token) + # Allow for a trailing comma before the closing delimiter. + if util.match_token(last, token.OP, ','): + last = self._code.next_token(last) + # Now check for the actual closing delimiter. + if util.match_token(last, *match): + last_token = last + + # And extend `first_token` to match any unclosed opening parens/braces. + for match in to_match_left: + first = self._code.prev_token(first_token) + if util.match_token(first, *match): + first_token = first + + return (first_token, last_token) + + #---------------------------------------------------------------------- + # Node visitors. Each takes a preliminary first and last tokens, and returns the adjusted pair + # that will actually be assigned. + + def visit_default(self, node, first_token, last_token): + # pylint: disable=no-self-use + # By default, we don't need to adjust the token we computed earlier. + return (first_token, last_token) + + def handle_comp(self, open_brace, node, first_token, last_token): + # For list/set/dict comprehensions, we only get the token of the first child, so adjust it to + # include the opening brace (the closing brace will be matched automatically). + before = self._code.prev_token(first_token) + util.expect_token(before, token.OP, open_brace) + return (before, last_token) + + def visit_listcomp(self, node, first_token, last_token): + return self.handle_comp('[', node, first_token, last_token) + + if six.PY2: + # We shouldn't do this on PY3 because its SetComp/DictComp already have a correct start. + def visit_setcomp(self, node, first_token, last_token): + return self.handle_comp('{', node, first_token, last_token) + + def visit_dictcomp(self, node, first_token, last_token): + return self.handle_comp('{', node, first_token, last_token) + + def visit_comprehension(self, node, first_token, last_token): + # The 'comprehension' node starts with 'for' but we only get first child; we search backwards + # to find the 'for' keyword. + first = self._code.find_token(first_token, token.NAME, 'for', reverse=True) + return (first, last_token) + + def handle_attr(self, node, first_token, last_token): + # Attribute node has ".attr" (2 tokens) after the last child. + dot = self._code.find_token(last_token, token.OP, '.') + name = self._code.next_token(dot) + util.expect_token(name, token.NAME) + return (first_token, name) + + visit_attribute = handle_attr + visit_assignattr = handle_attr + visit_delattr = handle_attr + + def handle_doc(self, node, first_token, last_token): + # With astroid, nodes that start with a doc-string can have an empty body, in which case we + # need to adjust the last token to include the doc string. + if not node.body and getattr(node, 'doc', None): + last_token = self._code.find_token(last_token, token.STRING) + return (first_token, last_token) + + visit_classdef = handle_doc + visit_funcdef = handle_doc + + def visit_call(self, node, first_token, last_token): + # A function call isn't over until we see a closing paren. Remember that last_token is at the + # end of all children, so we are not worried about encountering a paren that belongs to a + # child. + return (first_token, self._code.find_token(last_token, token.OP, ')')) + + def visit_subscript(self, node, first_token, last_token): + # A subscript operations isn't over until we see a closing bracket. Similar to function calls. + return (first_token, self._code.find_token(last_token, token.OP, ']')) + + def visit_tuple(self, node, first_token, last_token): + # A tuple doesn't include parens; if there is a trailing comma, make it part of the tuple. + try: + maybe_comma = self._code.next_token(last_token) + if util.match_token(maybe_comma, token.OP, ','): + last_token = maybe_comma + except IndexError: + pass + return (first_token, last_token) + + def visit_str(self, node, first_token, last_token): + # Multiple adjacent STRING tokens form a single string. + last = self._code.next_token(last_token) + while util.match_token(last, token.STRING): + last_token = last + last = self._code.next_token(last_token) + return (first_token, last_token) + + def visit_num(self, node, first_token, last_token): + # A constant like '-1' gets turned into two tokens; this will skip the '-'. + while util.match_token(last_token, token.OP): + last_token = self._code.next_token(last_token) + return (first_token, last_token) + + # In Astroid, the Num and Str nodes are replaced by Const. + def visit_const(self, node, first_token, last_token): + if isinstance(node.value, numbers.Number): + return self.visit_num(node, first_token, last_token) + elif isinstance(node.value, six.string_types): + return self.visit_str(node, first_token, last_token) + return (first_token, last_token) + + def visit_keyword(self, node, first_token, last_token): + if node.arg is not None: + equals = self._code.find_token(first_token, token.OP, '=', reverse=True) + name = self._code.prev_token(equals) + util.expect_token(name, token.NAME, node.arg) + first_token = name + return (first_token, last_token) + + def visit_starred(self, node, first_token, last_token): + # Astroid has 'Starred' nodes (for "foo(*bar)" type args), but they need to be adjusted. + if not util.match_token(first_token, token.OP, '*'): + star = self._code.prev_token(first_token) + if util.match_token(star, token.OP, '*'): + first_token = star + return (first_token, last_token) + + def visit_assignname(self, node, first_token, last_token): + # Astroid may turn 'except' clause into AssignName, but we need to adjust it. + if util.match_token(first_token, token.NAME, 'except'): + colon = self._code.find_token(last_token, token.OP, ':') + first_token = last_token = self._code.prev_token(colon) + return (first_token, last_token) diff --git a/regex/asttokens/util.py b/regex/asttokens/util.py new file mode 100644 index 0000000..4dd2f27 --- /dev/null +++ b/regex/asttokens/util.py @@ -0,0 +1,236 @@ +# Copyright 2016 Grist Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ast +import collections +import token +from six import iteritems + + +def token_repr(tok_type, string): + """Returns a human-friendly representation of a token with the given type and string.""" + # repr() prefixes unicode with 'u' on Python2 but not Python3; strip it out for consistency. + return '%s:%s' % (token.tok_name[tok_type], repr(string).lstrip('u')) + + +class Token(collections.namedtuple('Token', 'type string start end line index startpos endpos')): + """ + TokenInfo is an 8-tuple containing the same 5 fields as the tokens produced by the tokenize + module, and 3 additional ones useful for this module: + + - [0] .type Token type (see token.py) + - [1] .string Token (a string) + - [2] .start Starting (row, column) indices of the token (a 2-tuple of ints) + - [3] .end Ending (row, column) indices of the token (a 2-tuple of ints) + - [4] .line Original line (string) + - [5] .index Index of the token in the list of tokens that it belongs to. + - [6] .startpos Starting character offset into the input text. + - [7] .endpos Ending character offset into the input text. + """ + def __str__(self): + return token_repr(self.type, self.string) + + +def match_token(token, tok_type, tok_str=None): + """Returns true if token is of the given type and, if a string is given, has that string.""" + return token.type == tok_type and (tok_str is None or token.string == tok_str) + + +def expect_token(token, tok_type, tok_str=None): + """ + Verifies that the given token is of the expected type. If tok_str is given, the token string + is verified too. If the token doesn't match, raises an informative ValueError. + """ + if not match_token(token, tok_type, tok_str): + raise ValueError("Expected token %s, got %s on line %s col %s" % ( + token_repr(tok_type, tok_str), str(token), + token.start[0], token.start[1] + 1)) + + +def iter_children(node): + """ + Yields all direct children of a AST node, skipping children that are singleton nodes. + """ + return iter_children_astroid(node) if hasattr(node, 'get_children') else iter_children_ast(node) + + +def iter_children_func(node): + """ + Returns a slightly more optimized function to use in place of ``iter_children``, depending on + whether ``node`` is from ``ast`` or from the ``astroid`` module. + """ + return iter_children_astroid if hasattr(node, 'get_children') else iter_children_ast + + +def iter_children_astroid(node): + # Don't attempt to process children of JoinedStr nodes, which we can't fully handle yet. + if is_joined_str(node): + return [] + + return node.get_children() + + +SINGLETONS = {c for n, c in iteritems(ast.__dict__) if isinstance(c, type) and + issubclass(c, (ast.expr_context, ast.boolop, ast.operator, ast.unaryop, ast.cmpop))} + +def iter_children_ast(node): + # Don't attempt to process children of JoinedStr nodes, which we can't fully handle yet. + if is_joined_str(node): + return + + for child in ast.iter_child_nodes(node): + # Skip singleton children; they don't reflect particular positions in the code and break the + # assumptions about the tree consisting of distinct nodes. Note that collecting classes + # beforehand and checking them in a set is faster than using isinstance each time. + if child.__class__ not in SINGLETONS: + yield child + + +stmt_class_names = {n for n, c in iteritems(ast.__dict__) + if isinstance(c, type) and issubclass(c, ast.stmt)} +expr_class_names = ({n for n, c in iteritems(ast.__dict__) + if isinstance(c, type) and issubclass(c, ast.expr)} | + {'AssignName', 'DelName', 'Const', 'AssignAttr', 'DelAttr'}) + +# These feel hacky compared to isinstance() but allow us to work with both ast and astroid nodes +# in the same way, and without even importing astroid. +def is_expr(node): + """Returns whether node is an expression node.""" + return node.__class__.__name__ in expr_class_names + +def is_stmt(node): + """Returns whether node is a statement node.""" + return node.__class__.__name__ in stmt_class_names + +def is_module(node): + """Returns whether node is a module node.""" + return node.__class__.__name__ == 'Module' + +def is_joined_str(node): + """Returns whether node is a JoinedStr node, used to represent f-strings.""" + # At the moment, nodes below JoinedStr have wrong line/col info, and trying to process them only + # leads to errors. + return node.__class__.__name__ == 'JoinedStr' + + +# Sentinel value used by visit_tree(). +_PREVISIT = object() + +def visit_tree(node, previsit, postvisit): + """ + Scans the tree under the node depth-first using an explicit stack. It avoids implicit recursion + via the function call stack to avoid hitting 'maximum recursion depth exceeded' error. + + It calls ``previsit()`` and ``postvisit()`` as follows: + + * ``previsit(node, par_value)`` - should return ``(par_value, value)`` + ``par_value`` is as returned from ``previsit()`` of the parent. + + * ``postvisit(node, par_value, value)`` - should return ``value`` + ``par_value`` is as returned from ``previsit()`` of the parent, and ``value`` is as + returned from ``previsit()`` of this node itself. The return ``value`` is ignored except + the one for the root node, which is returned from the overall ``visit_tree()`` call. + + For the initial node, ``par_value`` is None. Either ``previsit`` and ``postvisit`` may be None. + """ + if not previsit: + previsit = lambda node, pvalue: (None, None) + if not postvisit: + postvisit = lambda node, pvalue, value: None + + iter_children = iter_children_func(node) + done = set() + ret = None + stack = [(node, None, _PREVISIT)] + while stack: + current, par_value, value = stack.pop() + if value is _PREVISIT: + assert current not in done # protect againt infinite loop in case of a bad tree. + done.add(current) + + pvalue, post_value = previsit(current, par_value) + stack.append((current, par_value, post_value)) + + # Insert all children in reverse order (so that first child ends up on top of the stack). + ins = len(stack) + for n in iter_children(current): + stack.insert(ins, (n, pvalue, _PREVISIT)) + else: + ret = postvisit(current, par_value, value) + return ret + + + +def walk(node): + """ + Recursively yield all descendant nodes in the tree starting at ``node`` (including ``node`` + itself), using depth-first pre-order traversal (yieling parents before their children). + + This is similar to ``ast.walk()``, but with a different order, and it works for both ``ast`` and + ``astroid`` trees. Also, as ``iter_children()``, it skips singleton nodes generated by ``ast``. + """ + iter_children = iter_children_func(node) + done = set() + stack = [node] + while stack: + current = stack.pop() + assert current not in done # protect againt infinite loop in case of a bad tree. + done.add(current) + + yield current + + # Insert all children in reverse order (so that first child ends up on top of the stack). + # This is faster than building a list and reversing it. + ins = len(stack) + for c in iter_children(current): + stack.insert(ins, c) + + +def replace(text, replacements): + """ + Replaces multiple slices of text with new values. This is a convenience method for making code + modifications of ranges e.g. as identified by ``ASTTokens.get_text_range(node)``. Replacements is + an iterable of ``(start, end, new_text)`` tuples. + + For example, ``replace("this is a test", [(0, 4, "X"), (8, 1, "THE")])`` produces + ``"X is THE test"``. + """ + p = 0 + parts = [] + for (start, end, new_text) in sorted(replacements): + parts.append(text[p:start]) + parts.append(new_text) + p = end + parts.append(text[p:]) + return ''.join(parts) + + +class NodeMethods(object): + """ + Helper to get `visit_{node_type}` methods given a node's class and cache the results. + """ + def __init__(self): + self._cache = {} + + def get(self, obj, cls): + """ + Using the lowercase name of the class as node_type, returns `obj.visit_{node_type}`, + or `obj.visit_default` if the type-specific method is not found. + """ + method = self._cache.get(cls) + if not method: + name = "visit_" + cls.__name__.lower() + method = getattr(obj, name, obj.visit_default) + self._cache[cls] = method + return method |