summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--regex/asttokens/__init__.py22
-rw-r--r--regex/asttokens/asttokens.py196
-rw-r--r--regex/asttokens/line_numbers.py71
-rw-r--r--regex/asttokens/mark_tokens.py275
-rw-r--r--regex/asttokens/util.py236
5 files changed, 800 insertions, 0 deletions
diff --git a/regex/asttokens/__init__.py b/regex/asttokens/__init__.py
new file mode 100644
index 0000000..cde4aab
--- /dev/null
+++ b/regex/asttokens/__init__.py
@@ -0,0 +1,22 @@
+# Copyright 2016 Grist Labs, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This module enhances the Python AST tree with token and source code information, sufficent to
+detect the source text of each AST node. This is helpful for tools that make source code
+transformations.
+"""
+
+from .line_numbers import LineNumbers
+from .asttokens import ASTTokens
diff --git a/regex/asttokens/asttokens.py b/regex/asttokens/asttokens.py
new file mode 100644
index 0000000..130f53d
--- /dev/null
+++ b/regex/asttokens/asttokens.py
@@ -0,0 +1,196 @@
+# Copyright 2016 Grist Labs, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import ast
+import bisect
+import token
+import tokenize
+import io
+import six
+from six.moves import xrange # pylint: disable=redefined-builtin
+from .line_numbers import LineNumbers
+from .util import Token, match_token
+from .mark_tokens import MarkTokens
+
+class ASTTokens(object):
+ """
+ ASTTokens maintains the text of Python code in several forms: as a string, as line numbers, and
+ as tokens, and is used to mark and access token and position information.
+
+ ``source_text`` must be a unicode or UTF8-encoded string. If you pass in UTF8 bytes, remember
+ that all offsets you'll get are to the unicode text, which is available as the ``.text``
+ property.
+
+ If ``parse`` is set, the ``source_text`` will be parsed with ``ast.parse()``, and the resulting
+ tree marked with token info and made available as the ``.tree`` property.
+
+ If ``tree`` is given, it will be marked and made available as the ``.tree`` property. In
+ addition to the trees produced by the ``ast`` module, ASTTokens will also mark trees produced
+ using ``astroid`` library <https://www.astroid.org>.
+
+ If only ``source_text`` is given, you may use ``.mark_tokens(tree)`` to mark the nodes of an AST
+ tree created separately.
+ """
+ def __init__(self, source_text, parse=False, tree=None):
+ if isinstance(source_text, six.binary_type):
+ source_text = source_text.decode('utf8')
+
+ self._tree = ast.parse(source_text) if parse else tree
+
+ self._text = source_text
+ self._line_numbers = LineNumbers(source_text)
+
+ # Tokenize the code.
+ self._tokens = list(self._generate_tokens(source_text))
+
+ # Extract the start positions of all tokens, so that we can quickly map positions to tokens.
+ self._token_offsets = [tok.startpos for tok in self._tokens]
+
+ if self._tree:
+ self.mark_tokens(self._tree)
+
+
+ def mark_tokens(self, root_node):
+ """
+ Given the root of the AST or Astroid tree produced from source_text, visits all nodes marking
+ them with token and position information by adding ``.first_token`` and
+ ``.last_token``attributes. This is done automatically in the constructor when ``parse`` or
+ ``tree`` arguments are set, but may be used manually with a separate AST or Astroid tree.
+ """
+ # The hard work of this class is done by MarkTokens
+ MarkTokens(self).visit_tree(root_node)
+
+
+ def _generate_tokens(self, text):
+ """
+ Generates tokens for the given code.
+ """
+ # This is technically an undocumented API for Python3, but allows us to use the same API as for
+ # Python2. See http://stackoverflow.com/a/4952291/328565.
+ for index, tok in enumerate(tokenize.generate_tokens(io.StringIO(text).readline)):
+ tok_type, tok_str, start, end, line = tok
+ yield Token(tok_type, tok_str, start, end, line, index,
+ self._line_numbers.line_to_offset(start[0], start[1]),
+ self._line_numbers.line_to_offset(end[0], end[1]))
+
+ @property
+ def text(self):
+ """The source code passed into the constructor."""
+ return self._text
+
+ @property
+ def tokens(self):
+ """The list of tokens corresponding to the source code from the constructor."""
+ return self._tokens
+
+ @property
+ def tree(self):
+ """The root of the AST tree passed into the constructor or parsed from the source code."""
+ return self._tree
+
+ def get_token_from_offset(self, offset):
+ """
+ Returns the token containing the given character offset (0-based position in source text),
+ or the preceeding token if the position is between tokens.
+ """
+ return self._tokens[bisect.bisect(self._token_offsets, offset) - 1]
+
+ def get_token(self, lineno, col_offset):
+ """
+ Returns the token containing the given (lineno, col_offset) position, or the preceeding token
+ if the position is between tokens.
+ """
+ # TODO: add test for multibyte unicode. We need to translate offsets from ast module (which
+ # are in utf8) to offsets into the unicode text. tokenize module seems to use unicode offsets
+ # but isn't explicit.
+ return self.get_token_from_offset(self._line_numbers.line_to_offset(lineno, col_offset))
+
+ def get_token_from_utf8(self, lineno, col_offset):
+ """
+ Same as get_token(), but interprets col_offset as a UTF8 offset, which is what `ast` uses.
+ """
+ return self.get_token(lineno, self._line_numbers.from_utf8_col(lineno, col_offset))
+
+ def next_token(self, tok, include_extra=False):
+ """
+ Returns the next token after the given one. If include_extra is True, includes non-coding
+ tokens from the tokenize module, such as NL and COMMENT.
+ """
+ i = tok.index + 1
+ if not include_extra:
+ while self._tokens[i].type >= token.N_TOKENS:
+ i += 1
+ return self._tokens[i]
+
+ def prev_token(self, tok, include_extra=False):
+ """
+ Returns the previous token before the given one. If include_extra is True, includes non-coding
+ tokens from the tokenize module, such as NL and COMMENT.
+ """
+ i = tok.index - 1
+ if not include_extra:
+ while self._tokens[i].type >= token.N_TOKENS:
+ i -= 1
+ return self._tokens[i]
+
+ def find_token(self, start_token, tok_type, tok_str=None, reverse=False):
+ """
+ Looks for the first token, starting at start_token, that matches tok_type and, if given, the
+ token string. Searches backwards if reverse is True.
+ """
+ t = start_token
+ advance = self.prev_token if reverse else self.next_token
+ while not match_token(t, tok_type, tok_str) and not token.ISEOF(t.type):
+ t = advance(t)
+ return t
+
+ def token_range(self, first_token, last_token, include_extra=False):
+ """
+ Yields all tokens in order from first_token through and including last_token. If
+ include_extra is True, includes non-coding tokens such as tokenize.NL and .COMMENT.
+ """
+ for i in xrange(first_token.index, last_token.index + 1):
+ if include_extra or self._tokens[i].type < token.N_TOKENS:
+ yield self._tokens[i]
+
+ def get_tokens(self, node, include_extra=False):
+ """
+ Yields all tokens making up the given node. If include_extra is True, includes non-coding
+ tokens such as tokenize.NL and .COMMENT.
+ """
+ return self.token_range(node.first_token, node.last_token, include_extra=include_extra)
+
+ def get_text_range(self, node):
+ """
+ After mark_tokens() has been called, returns the (startpos, endpos) positions in source text
+ corresponding to the given node. Returns (0, 0) for nodes (like `Load`) that don't correspond
+ to any particular text.
+ """
+ if not hasattr(node, 'first_token'):
+ return (0, 0)
+
+ start = node.first_token.startpos
+ if any(match_token(t, token.NEWLINE) for t in self.get_tokens(node)):
+ # Multi-line nodes would be invalid unless we keep the indentation of the first node.
+ start = self._text.rfind('\n', 0, start) + 1
+
+ return (start, node.last_token.endpos)
+
+ def get_text(self, node):
+ """
+ After mark_tokens() has been called, returns the text corresponding to the given node. Returns
+ '' for nodes (like `Load`) that don't correspond to any particular text.
+ """
+ start, end = self.get_text_range(node)
+ return self._text[start : end]
diff --git a/regex/asttokens/line_numbers.py b/regex/asttokens/line_numbers.py
new file mode 100644
index 0000000..b91b00f
--- /dev/null
+++ b/regex/asttokens/line_numbers.py
@@ -0,0 +1,71 @@
+# Copyright 2016 Grist Labs, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import bisect
+import re
+
+_line_start_re = re.compile(r'^', re.M)
+
+class LineNumbers(object):
+ """
+ Class to convert between character offsets in a text string, and pairs (line, column) of 1-based
+ line and 0-based column numbers, as used by tokens and AST nodes.
+
+ This class expects unicode for input and stores positions in unicode. But it supports
+ translating to and from utf8 offsets, which are used by ast parsing.
+ """
+ def __init__(self, text):
+ # A list of character offsets of each line's first character.
+ self._line_offsets = [m.start(0) for m in _line_start_re.finditer(text)]
+ self._text = text
+ self._text_len = len(text)
+ self._utf8_offset_cache = {} # maps line num to list of char offset for each byte in line
+
+ def from_utf8_col(self, line, utf8_column):
+ """
+ Given a 1-based line number and 0-based utf8 column, returns a 0-based unicode column.
+ """
+ offsets = self._utf8_offset_cache.get(line)
+ if offsets is None:
+ end_offset = self._line_offsets[line] if line < len(self._line_offsets) else self._text_len
+ line_text = self._text[self._line_offsets[line - 1] : end_offset]
+
+ offsets = [i for i,c in enumerate(line_text) for byte in c.encode('utf8')]
+ offsets.append(len(line_text))
+ self._utf8_offset_cache[line] = offsets
+
+ return offsets[max(0, min(len(offsets), utf8_column))]
+
+ def line_to_offset(self, line, column):
+ """
+ Converts 1-based line number and 0-based column to 0-based character offset into text.
+ """
+ line -= 1
+ if line >= len(self._line_offsets):
+ return self._text_len
+ elif line < 0:
+ return 0
+ else:
+ return min(self._line_offsets[line] + max(0, column), self._text_len)
+
+ def offset_to_line(self, offset):
+ """
+ Converts 0-based character offset to pair (line, col) of 1-based line and 0-based column
+ numbers.
+ """
+ offset = max(0, min(self._text_len, offset))
+ line_index = bisect.bisect_right(self._line_offsets, offset) - 1
+ return (line_index + 1, offset - self._line_offsets[line_index])
+
+
diff --git a/regex/asttokens/mark_tokens.py b/regex/asttokens/mark_tokens.py
new file mode 100644
index 0000000..f48fb77
--- /dev/null
+++ b/regex/asttokens/mark_tokens.py
@@ -0,0 +1,275 @@
+# Copyright 2016 Grist Labs, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import six
+import numbers
+import token
+from . import util
+
+
+# Mapping of matching braces. To find a token here, look up token[:2].
+_matching_pairs_left = {
+ (token.OP, '('): (token.OP, ')'),
+ (token.OP, '['): (token.OP, ']'),
+ (token.OP, '{'): (token.OP, '}'),
+}
+
+_matching_pairs_right = {
+ (token.OP, ')'): (token.OP, '('),
+ (token.OP, ']'): (token.OP, '['),
+ (token.OP, '}'): (token.OP, '{'),
+}
+
+
+class MarkTokens(object):
+ """
+ Helper that visits all nodes in the AST tree and assigns .first_token and .last_token attributes
+ to each of them. This is the heart of the token-marking logic.
+ """
+ def __init__(self, code):
+ self._code = code
+ self._methods = util.NodeMethods()
+ self._iter_children = None
+
+ def visit_tree(self, node):
+ self._iter_children = util.iter_children_func(node)
+ util.visit_tree(node, self._visit_before_children, self._visit_after_children)
+
+ def _visit_before_children(self, node, parent_token):
+ col = getattr(node, 'col_offset', None)
+ token = self._code.get_token_from_utf8(node.lineno, col) if col is not None else None
+
+ if not token and util.is_module(node):
+ # We'll assume that a Module node starts at the start of the source code.
+ token = self._code.get_token(1, 0)
+
+ # Use our own token, or our parent's if we don't have one, to pass to child calls as
+ # parent_token argument. The second value becomes the token argument of _visit_after_children.
+ return (token or parent_token, token)
+
+ def _visit_after_children(self, node, parent_token, token):
+ # This processes the node generically first, after all children have been processed.
+
+ # Get the first and last tokens that belong to children. Note how this doesn't assume that we
+ # iterate through children in order that corresponds to occurrence in source code. This
+ # assumption can fail (e.g. with return annotations).
+ first = token
+ last = None
+ for child in self._iter_children(node):
+ if not first or child.first_token.index < first.index:
+ first = child.first_token
+ if not last or child.last_token.index > last.index:
+ last = child.last_token
+
+ # If we don't have a first token from _visit_before_children, and there were no children, then
+ # use the parent's token as the first token.
+ first = first or parent_token
+
+ # If no children, set last token to the first one.
+ last = last or first
+
+ # Statements continue to before NEWLINE. This helps cover a few different cases at once.
+ if util.is_stmt(node):
+ last = self._find_last_in_line(last)
+
+ # Capture any unmatched brackets.
+ first, last = self._expand_to_matching_pairs(first, last, node)
+
+ # Give a chance to node-specific methods to adjust.
+ nfirst, nlast = self._methods.get(self, node.__class__)(node, first, last)
+
+ if (nfirst, nlast) != (first, last):
+ # If anything changed, expand again to capture any unmatched brackets.
+ nfirst, nlast = self._expand_to_matching_pairs(nfirst, nlast, node)
+
+ node.first_token = nfirst
+ node.last_token = nlast
+
+ def _find_last_in_line(self, start_token):
+ try:
+ newline = self._code.find_token(start_token, token.NEWLINE)
+ except IndexError:
+ newline = self._code.find_token(start_token, token.ENDMARKER)
+ return self._code.prev_token(newline)
+
+ def _iter_non_child_tokens(self, first_token, last_token, node):
+ """
+ Generates all tokens in [first_token, last_token] range that do not belong to any children of
+ node. E.g. `foo(bar)` has children `foo` and `bar`, but we would yield the `(`.
+ """
+ tok = first_token
+ for n in self._iter_children(node):
+ for t in self._code.token_range(tok, self._code.prev_token(n.first_token)):
+ yield t
+ if n.last_token.index >= last_token.index:
+ return
+ tok = self._code.next_token(n.last_token)
+
+ for t in self._code.token_range(tok, last_token):
+ yield t
+
+ def _expand_to_matching_pairs(self, first_token, last_token, node):
+ """
+ Scan tokens in [first_token, last_token] range that are between node's children, and for any
+ unmatched brackets, adjust first/last tokens to include the closing pair.
+ """
+ # We look for opening parens/braces among non-child tokens (i.e. tokens between our actual
+ # child nodes). If we find any closing ones, we match them to the opens.
+ to_match_right = []
+ to_match_left = []
+ for tok in self._iter_non_child_tokens(first_token, last_token, node):
+ tok_info = tok[:2]
+ if to_match_right and tok_info == to_match_right[-1]:
+ to_match_right.pop()
+ elif tok_info in _matching_pairs_left:
+ to_match_right.append(_matching_pairs_left[tok_info])
+ elif tok_info in _matching_pairs_right:
+ to_match_left.append(_matching_pairs_right[tok_info])
+
+ # Once done, extend `last_token` to match any unclosed parens/braces.
+ for match in reversed(to_match_right):
+ last = self._code.next_token(last_token)
+ # Allow for a trailing comma before the closing delimiter.
+ if util.match_token(last, token.OP, ','):
+ last = self._code.next_token(last)
+ # Now check for the actual closing delimiter.
+ if util.match_token(last, *match):
+ last_token = last
+
+ # And extend `first_token` to match any unclosed opening parens/braces.
+ for match in to_match_left:
+ first = self._code.prev_token(first_token)
+ if util.match_token(first, *match):
+ first_token = first
+
+ return (first_token, last_token)
+
+ #----------------------------------------------------------------------
+ # Node visitors. Each takes a preliminary first and last tokens, and returns the adjusted pair
+ # that will actually be assigned.
+
+ def visit_default(self, node, first_token, last_token):
+ # pylint: disable=no-self-use
+ # By default, we don't need to adjust the token we computed earlier.
+ return (first_token, last_token)
+
+ def handle_comp(self, open_brace, node, first_token, last_token):
+ # For list/set/dict comprehensions, we only get the token of the first child, so adjust it to
+ # include the opening brace (the closing brace will be matched automatically).
+ before = self._code.prev_token(first_token)
+ util.expect_token(before, token.OP, open_brace)
+ return (before, last_token)
+
+ def visit_listcomp(self, node, first_token, last_token):
+ return self.handle_comp('[', node, first_token, last_token)
+
+ if six.PY2:
+ # We shouldn't do this on PY3 because its SetComp/DictComp already have a correct start.
+ def visit_setcomp(self, node, first_token, last_token):
+ return self.handle_comp('{', node, first_token, last_token)
+
+ def visit_dictcomp(self, node, first_token, last_token):
+ return self.handle_comp('{', node, first_token, last_token)
+
+ def visit_comprehension(self, node, first_token, last_token):
+ # The 'comprehension' node starts with 'for' but we only get first child; we search backwards
+ # to find the 'for' keyword.
+ first = self._code.find_token(first_token, token.NAME, 'for', reverse=True)
+ return (first, last_token)
+
+ def handle_attr(self, node, first_token, last_token):
+ # Attribute node has ".attr" (2 tokens) after the last child.
+ dot = self._code.find_token(last_token, token.OP, '.')
+ name = self._code.next_token(dot)
+ util.expect_token(name, token.NAME)
+ return (first_token, name)
+
+ visit_attribute = handle_attr
+ visit_assignattr = handle_attr
+ visit_delattr = handle_attr
+
+ def handle_doc(self, node, first_token, last_token):
+ # With astroid, nodes that start with a doc-string can have an empty body, in which case we
+ # need to adjust the last token to include the doc string.
+ if not node.body and getattr(node, 'doc', None):
+ last_token = self._code.find_token(last_token, token.STRING)
+ return (first_token, last_token)
+
+ visit_classdef = handle_doc
+ visit_funcdef = handle_doc
+
+ def visit_call(self, node, first_token, last_token):
+ # A function call isn't over until we see a closing paren. Remember that last_token is at the
+ # end of all children, so we are not worried about encountering a paren that belongs to a
+ # child.
+ return (first_token, self._code.find_token(last_token, token.OP, ')'))
+
+ def visit_subscript(self, node, first_token, last_token):
+ # A subscript operations isn't over until we see a closing bracket. Similar to function calls.
+ return (first_token, self._code.find_token(last_token, token.OP, ']'))
+
+ def visit_tuple(self, node, first_token, last_token):
+ # A tuple doesn't include parens; if there is a trailing comma, make it part of the tuple.
+ try:
+ maybe_comma = self._code.next_token(last_token)
+ if util.match_token(maybe_comma, token.OP, ','):
+ last_token = maybe_comma
+ except IndexError:
+ pass
+ return (first_token, last_token)
+
+ def visit_str(self, node, first_token, last_token):
+ # Multiple adjacent STRING tokens form a single string.
+ last = self._code.next_token(last_token)
+ while util.match_token(last, token.STRING):
+ last_token = last
+ last = self._code.next_token(last_token)
+ return (first_token, last_token)
+
+ def visit_num(self, node, first_token, last_token):
+ # A constant like '-1' gets turned into two tokens; this will skip the '-'.
+ while util.match_token(last_token, token.OP):
+ last_token = self._code.next_token(last_token)
+ return (first_token, last_token)
+
+ # In Astroid, the Num and Str nodes are replaced by Const.
+ def visit_const(self, node, first_token, last_token):
+ if isinstance(node.value, numbers.Number):
+ return self.visit_num(node, first_token, last_token)
+ elif isinstance(node.value, six.string_types):
+ return self.visit_str(node, first_token, last_token)
+ return (first_token, last_token)
+
+ def visit_keyword(self, node, first_token, last_token):
+ if node.arg is not None:
+ equals = self._code.find_token(first_token, token.OP, '=', reverse=True)
+ name = self._code.prev_token(equals)
+ util.expect_token(name, token.NAME, node.arg)
+ first_token = name
+ return (first_token, last_token)
+
+ def visit_starred(self, node, first_token, last_token):
+ # Astroid has 'Starred' nodes (for "foo(*bar)" type args), but they need to be adjusted.
+ if not util.match_token(first_token, token.OP, '*'):
+ star = self._code.prev_token(first_token)
+ if util.match_token(star, token.OP, '*'):
+ first_token = star
+ return (first_token, last_token)
+
+ def visit_assignname(self, node, first_token, last_token):
+ # Astroid may turn 'except' clause into AssignName, but we need to adjust it.
+ if util.match_token(first_token, token.NAME, 'except'):
+ colon = self._code.find_token(last_token, token.OP, ':')
+ first_token = last_token = self._code.prev_token(colon)
+ return (first_token, last_token)
diff --git a/regex/asttokens/util.py b/regex/asttokens/util.py
new file mode 100644
index 0000000..4dd2f27
--- /dev/null
+++ b/regex/asttokens/util.py
@@ -0,0 +1,236 @@
+# Copyright 2016 Grist Labs, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import ast
+import collections
+import token
+from six import iteritems
+
+
+def token_repr(tok_type, string):
+ """Returns a human-friendly representation of a token with the given type and string."""
+ # repr() prefixes unicode with 'u' on Python2 but not Python3; strip it out for consistency.
+ return '%s:%s' % (token.tok_name[tok_type], repr(string).lstrip('u'))
+
+
+class Token(collections.namedtuple('Token', 'type string start end line index startpos endpos')):
+ """
+ TokenInfo is an 8-tuple containing the same 5 fields as the tokens produced by the tokenize
+ module, and 3 additional ones useful for this module:
+
+ - [0] .type Token type (see token.py)
+ - [1] .string Token (a string)
+ - [2] .start Starting (row, column) indices of the token (a 2-tuple of ints)
+ - [3] .end Ending (row, column) indices of the token (a 2-tuple of ints)
+ - [4] .line Original line (string)
+ - [5] .index Index of the token in the list of tokens that it belongs to.
+ - [6] .startpos Starting character offset into the input text.
+ - [7] .endpos Ending character offset into the input text.
+ """
+ def __str__(self):
+ return token_repr(self.type, self.string)
+
+
+def match_token(token, tok_type, tok_str=None):
+ """Returns true if token is of the given type and, if a string is given, has that string."""
+ return token.type == tok_type and (tok_str is None or token.string == tok_str)
+
+
+def expect_token(token, tok_type, tok_str=None):
+ """
+ Verifies that the given token is of the expected type. If tok_str is given, the token string
+ is verified too. If the token doesn't match, raises an informative ValueError.
+ """
+ if not match_token(token, tok_type, tok_str):
+ raise ValueError("Expected token %s, got %s on line %s col %s" % (
+ token_repr(tok_type, tok_str), str(token),
+ token.start[0], token.start[1] + 1))
+
+
+def iter_children(node):
+ """
+ Yields all direct children of a AST node, skipping children that are singleton nodes.
+ """
+ return iter_children_astroid(node) if hasattr(node, 'get_children') else iter_children_ast(node)
+
+
+def iter_children_func(node):
+ """
+ Returns a slightly more optimized function to use in place of ``iter_children``, depending on
+ whether ``node`` is from ``ast`` or from the ``astroid`` module.
+ """
+ return iter_children_astroid if hasattr(node, 'get_children') else iter_children_ast
+
+
+def iter_children_astroid(node):
+ # Don't attempt to process children of JoinedStr nodes, which we can't fully handle yet.
+ if is_joined_str(node):
+ return []
+
+ return node.get_children()
+
+
+SINGLETONS = {c for n, c in iteritems(ast.__dict__) if isinstance(c, type) and
+ issubclass(c, (ast.expr_context, ast.boolop, ast.operator, ast.unaryop, ast.cmpop))}
+
+def iter_children_ast(node):
+ # Don't attempt to process children of JoinedStr nodes, which we can't fully handle yet.
+ if is_joined_str(node):
+ return
+
+ for child in ast.iter_child_nodes(node):
+ # Skip singleton children; they don't reflect particular positions in the code and break the
+ # assumptions about the tree consisting of distinct nodes. Note that collecting classes
+ # beforehand and checking them in a set is faster than using isinstance each time.
+ if child.__class__ not in SINGLETONS:
+ yield child
+
+
+stmt_class_names = {n for n, c in iteritems(ast.__dict__)
+ if isinstance(c, type) and issubclass(c, ast.stmt)}
+expr_class_names = ({n for n, c in iteritems(ast.__dict__)
+ if isinstance(c, type) and issubclass(c, ast.expr)} |
+ {'AssignName', 'DelName', 'Const', 'AssignAttr', 'DelAttr'})
+
+# These feel hacky compared to isinstance() but allow us to work with both ast and astroid nodes
+# in the same way, and without even importing astroid.
+def is_expr(node):
+ """Returns whether node is an expression node."""
+ return node.__class__.__name__ in expr_class_names
+
+def is_stmt(node):
+ """Returns whether node is a statement node."""
+ return node.__class__.__name__ in stmt_class_names
+
+def is_module(node):
+ """Returns whether node is a module node."""
+ return node.__class__.__name__ == 'Module'
+
+def is_joined_str(node):
+ """Returns whether node is a JoinedStr node, used to represent f-strings."""
+ # At the moment, nodes below JoinedStr have wrong line/col info, and trying to process them only
+ # leads to errors.
+ return node.__class__.__name__ == 'JoinedStr'
+
+
+# Sentinel value used by visit_tree().
+_PREVISIT = object()
+
+def visit_tree(node, previsit, postvisit):
+ """
+ Scans the tree under the node depth-first using an explicit stack. It avoids implicit recursion
+ via the function call stack to avoid hitting 'maximum recursion depth exceeded' error.
+
+ It calls ``previsit()`` and ``postvisit()`` as follows:
+
+ * ``previsit(node, par_value)`` - should return ``(par_value, value)``
+ ``par_value`` is as returned from ``previsit()`` of the parent.
+
+ * ``postvisit(node, par_value, value)`` - should return ``value``
+ ``par_value`` is as returned from ``previsit()`` of the parent, and ``value`` is as
+ returned from ``previsit()`` of this node itself. The return ``value`` is ignored except
+ the one for the root node, which is returned from the overall ``visit_tree()`` call.
+
+ For the initial node, ``par_value`` is None. Either ``previsit`` and ``postvisit`` may be None.
+ """
+ if not previsit:
+ previsit = lambda node, pvalue: (None, None)
+ if not postvisit:
+ postvisit = lambda node, pvalue, value: None
+
+ iter_children = iter_children_func(node)
+ done = set()
+ ret = None
+ stack = [(node, None, _PREVISIT)]
+ while stack:
+ current, par_value, value = stack.pop()
+ if value is _PREVISIT:
+ assert current not in done # protect againt infinite loop in case of a bad tree.
+ done.add(current)
+
+ pvalue, post_value = previsit(current, par_value)
+ stack.append((current, par_value, post_value))
+
+ # Insert all children in reverse order (so that first child ends up on top of the stack).
+ ins = len(stack)
+ for n in iter_children(current):
+ stack.insert(ins, (n, pvalue, _PREVISIT))
+ else:
+ ret = postvisit(current, par_value, value)
+ return ret
+
+
+
+def walk(node):
+ """
+ Recursively yield all descendant nodes in the tree starting at ``node`` (including ``node``
+ itself), using depth-first pre-order traversal (yieling parents before their children).
+
+ This is similar to ``ast.walk()``, but with a different order, and it works for both ``ast`` and
+ ``astroid`` trees. Also, as ``iter_children()``, it skips singleton nodes generated by ``ast``.
+ """
+ iter_children = iter_children_func(node)
+ done = set()
+ stack = [node]
+ while stack:
+ current = stack.pop()
+ assert current not in done # protect againt infinite loop in case of a bad tree.
+ done.add(current)
+
+ yield current
+
+ # Insert all children in reverse order (so that first child ends up on top of the stack).
+ # This is faster than building a list and reversing it.
+ ins = len(stack)
+ for c in iter_children(current):
+ stack.insert(ins, c)
+
+
+def replace(text, replacements):
+ """
+ Replaces multiple slices of text with new values. This is a convenience method for making code
+ modifications of ranges e.g. as identified by ``ASTTokens.get_text_range(node)``. Replacements is
+ an iterable of ``(start, end, new_text)`` tuples.
+
+ For example, ``replace("this is a test", [(0, 4, "X"), (8, 1, "THE")])`` produces
+ ``"X is THE test"``.
+ """
+ p = 0
+ parts = []
+ for (start, end, new_text) in sorted(replacements):
+ parts.append(text[p:start])
+ parts.append(new_text)
+ p = end
+ parts.append(text[p:])
+ return ''.join(parts)
+
+
+class NodeMethods(object):
+ """
+ Helper to get `visit_{node_type}` methods given a node's class and cache the results.
+ """
+ def __init__(self):
+ self._cache = {}
+
+ def get(self, obj, cls):
+ """
+ Using the lowercase name of the class as node_type, returns `obj.visit_{node_type}`,
+ or `obj.visit_default` if the type-specific method is not found.
+ """
+ method = self._cache.get(cls)
+ if not method:
+ name = "visit_" + cls.__name__.lower()
+ method = getattr(obj, name, obj.visit_default)
+ self._cache[cls] = method
+ return method