Process, Analyze, and Transform Python Code with ASTs
Stefanie Molin
Bio
- 👩🏽‍💻 Software engineer at Bloomberg in NYC
- ✨ Core developer of numpydoc and creator of numpydoc's pre-commit hook, which uses ASTs
- ✍ Author of "Hands-On Data Analysis with Pandas"
- 🎓 Bachelor's degree in operations research from Columbia University
- 🎓 Master's degree in computer science from Georgia Tech
Prerequisites
- Comfort writing Python code, especially object-oriented programming
- Have Python and Git installed on your computer, as well as a text editor for writing code (e.g., Visual Studio Code)
- Fork and clone this repository: github.com/stefmolin/ast-workshop
- Open up these slides in your browser and use the arrow keys to follow along: stefaniemolin.com/ast-workshop
- Open up the documentation for the ast module in your browser to consult during the exercises: docs.python.org/3/library/ast.html
Introduction to ASTs
Abstract Syntax Tree (AST)
- Represents the structure of the source code as a tree
- Nodes in the tree are language constructs (e.g., module, class, function)
- Each node has a single parent (e.g., a class is a child of a single module)
- Parent nodes can have multiple children (e.g., a class can have several methods)
Let's see what this code snippet (greet.py) looks like when represented as an AST:
class Greeter:
def __init__(self, enthusiasm: int = 1) -> None:
self.enthusiasm = enthusiasm
def greet(self, name: str = 'World') -> str:
return f'Hello, {name}{"!" * self.enthusiasm}'
The AST for the
greet.py snippet visualized with Graphviz.
Popular open source tools that use ASTs
We can generate ASTs without installing the codebase being analyzed or its dependencies, which makes them a popular choice for building tools:
- Linters and formatters, like ruff (Rust) and black (Python)
- Documentation tools, like sphinx and the numpydoc-validation pre-commit hook
- Automatic Python syntax upgrade tools, like pyupgrade
- Next-generation notebooks, like marimo
- Type checkers, like mypy
- Code security tools, like bandit
- Code and testing coverage tools, like vulture and coverage.py
- Testing frameworks that instrument your code or generate tests based on it, like hypothesis and pytest
ASTs in Python
- Represent syntactically-correct Python code (cannot be generated in the presence of syntax errors)
- Created by the parser as an intermediate step when compiling source code into bytecode (necessary to run it)
- Available in the standard library via the ast module
Parsing Python source code into an AST
1. Read in the source code
>>> from pathlib import Path
>>> source_code = Path('snippets/greet.py').read_text()
2. Parse it with the ast module
If the code is syntactically-correct, we get an AST back:
>>> import ast
>>> tree = ast.parse(source_code)
>>> print(type(tree))
<class 'ast.Module'>
Inspecting the AST
We can use the ast.dump() function to display the AST:
The root node is an ast.Module node:
It contains everything else in its body attribute:
The greet.py file first defines a class, named Greeter:
The ast.ClassDef node also contains the body of the Greeter class:
The first entry is the Greeter.__init__() method:
The ast.FunctionDef node includes information about the arguments:
Its body contains the AST representation of the function's code:
The return annotation is stored in the returns attribute:
The final entry is the Greeter.greet() method:
>>> print(ast.dump(tree, indent=2))
Module(
body=[
ClassDef(
name='Greeter',
body=[
FunctionDef(
name='__init__',
args=arguments(
args=[
arg(arg='self'),
arg(
arg='enthusiasm',
annotation=Name(id='int', ctx=Load()))],
defaults=[
Constant(value=1)]),
body=[
Assign(
targets=[
Attribute(
value=Name(id='self', ctx=Load()),
attr='enthusiasm',
ctx=Store())],
value=Name(id='enthusiasm', ctx=Load()))],
returns=Constant(value=None)),
FunctionDef(
name='greet',
args=arguments(
args=[
arg(arg='self'),
arg(
arg='name',
annotation=Name(id='str', ctx=Load()))],
defaults=[
Constant(value='World')]),
body=[
Return(
value=JoinedStr(
values=[
Constant(value='Hello, '),
FormattedValue(
value=Name(id='name', ctx=Load()),
conversion=-1),
FormattedValue(
value=BinOp(
left=Constant(value='!'),
op=Mult(),
right=Attribute(
value=Name(id='self', ctx=Load()),
attr='enthusiasm',
ctx=Load())),
conversion=-1)]))],
returns=Name(id='str', ctx=Load()))])])
Exercise 1
- Try passing source code that has a SyntaxError into the ast.parse() function. What happens?
- What about if the code has an error unrelated to syntax, for instance, a NameError or TypeError?
Example solution
1. Syntactically-incorrect source code
Let's use the following malformed import statement as an example of invalid source code:
>>> import timedelta from datetime
File "<stdin>", line 1
import timedelta from datetime
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
SyntaxError: Did you mean to use 'from ... import ...' instead?
We also encounter a SyntaxError when attempting to parse this into an AST:
>>> import ast
>>> tree = ast.parse('import timedelta from datetime')
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
ast.parse('import timedelta from datetime')
~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../ast.py", line 46, in parse
return compile(source, filename, mode, flags,
_feature_version=feature_version,
optimize=optimize)
File "<unknown>", line 1
ast.parse('import timedelta from datetime')
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
SyntaxError: Did you mean to use 'from ... import ...' instead?
2. Syntactically-correct source code with logic errors
The following code raises a NameError at runtime:
>>> a + 5
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
a + 5
^
NameError: name 'a' is not defined
However, it is syntactically-correct, so we can parse it into an AST:
>>> ast.parse('a + 5')
Module(body=[Expr(value=BinOp(...))], type_ignores=[])
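To confirm that the error is deferred to runtime, we can compile the parsed tree and execute it (a small sketch; the filename passed to compile() is arbitrary):

```python
import ast

tree = ast.parse('a + 5')                 # parsing succeeds: the syntax is valid
code = compile(tree, '<example>', 'exec')
try:
    exec(code)                            # the NameError only surfaces at runtime
except NameError as error:
    print(type(error).__name__)           # NameError
```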
Working with ASTs
AST node attributes
Not only is the AST a highly-nested structure, but the attributes that contain child nodes may also be named differently across node types. To see this, let's take a look at the AST for the following snippet in assert.py:
def duplicate_list(x):
assert isinstance(x, list)
return x + x
The AST for the
assert.py snippet with node attributes visualized with Graphviz.
Traversing the AST
To effectively analyze code using the AST, we need to traverse it and inspect the nodes we care about. Depending on how much of the tree we want to explore and how much context we need about each node, there are different approaches. Let's walk through the different ways using the assert.py example:
import ast
from pathlib import Path
source_code = Path('snippets/assert.py').read_text()
tree = ast.parse(source_code)
ast.iter_fields()
We can use the ast.iter_fields() function to iterate over all fields that a node has. Our AST is rooted at an ast.Module node, so there isn't much here:
>>> print(list(ast.iter_fields(tree)))
[('body', [<ast.FunctionDef at 0x1086bea10>]),
('type_ignores', [])]
However, if we look at this for the ast.FunctionDef node in the body of the ast.Module node, we have more information:
>>> func_def = tree.body[0]
>>> print(list(ast.iter_fields(func_def)))
[('name', 'duplicate_list'),
('args', <ast.arguments at 0x1085794e0>),
('body', [<ast.Assert at 0x10884d6c0>,
<ast.Return at 0x10884d9f0>]),
('decorator_list', []),
('returns', None),
('type_comment', None)]
ast.iter_child_nodes()
The ast.iter_fields() function is helpful when figuring out how to work with individual node types. The ast.iter_child_nodes() function builds on top of it, yielding all nodes that are direct children of the starting node. The children can be in any field, but grandchildren are excluded; for example, the children of the ast.Assert node below are grandchildren of the ast.FunctionDef node, so they are not yielded:
>>> print(list(ast.iter_child_nodes(func_def)))
[<ast.arguments at 0x1085794e0>,
<ast.Assert at 0x10884d6c0>,
<ast.Return at 0x10884d9f0>]
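We can see the direct-children behavior concretely by inlining the assert.py contents (a small sketch):

```python
import ast

source = (
    'def duplicate_list(x):\n'
    '    assert isinstance(x, list)\n'
    '    return x + x\n'
)
func_def = ast.parse(source).body[0]

# only the arguments node and the two body statements are direct children
children = list(ast.iter_child_nodes(func_def))
print([type(node).__name__ for node in children])  # ['arguments', 'Assert', 'Return']

# ast.walk() (covered next) also reaches grandchildren, so it yields more nodes
print(len(list(ast.walk(func_def))) > len(children))  # True
```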
To traverse the entire tree, we need the recursive behavior provided in the ast.walk() function or the ast.NodeVisitor/ast.NodeTransformer classes.
Each of these builds upon the functions we just looked at (ast.iter_fields() and ast.iter_child_nodes()).
Let's start with the ast.walk() function.
ast.walk()
The ast.walk() function recursively yields all descendant nodes in the AST. Let's use it to make sure all assert statements provide a message for when the assertion is false and an AssertionError is raised.
For those unfamiliar with the syntax, here's a comparison using the contents of the assert.py snippet:
# without custom message
def duplicate_list(x):
assert isinstance(x, list)
return x + x
# with custom message
def duplicate_list(x):
assert isinstance(x, list), 'Input is not a list'
return x + x
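The difference is visible in the AST: the custom message lives in the Assert node's msg field, which is None when no message was provided (a quick check):

```python
import ast

no_msg = ast.parse("assert isinstance(x, list)").body[0]
with_msg = ast.parse(
    "assert isinstance(x, list), 'Input is not a list'"
).body[0]

print(no_msg.msg)              # None: no failure message was provided
print(ast.dump(with_msg.msg))  # a Constant node holding the message
```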
Modifying code before running it with ast.walk()
The ast.walk() function yields all nodes descending from tree:
We want to modify all ast.Assert nodes that do not have a message (msg):
We set the msg to a placeholder value, so it's easy to find in the logs:
All nodes must have line numbers in order to compile the AST into a code object:
The compile() function turns our modified AST into a code object:
We can execute code objects with the exec() function:
This runs the function definition for duplicate_list(), so we can now call it:
The input we passed fails the assert, raising an AssertionError:
Notice that we get the message we injected when we modified the AST:
>>> for node in ast.walk(tree):
... if isinstance(node, ast.Assert) and not node.msg:
... node.msg = ast.Constant('TODO: Add failure info')
... ast.fix_missing_locations(node)
>>>
>>> code = compile(tree, '<ast_workshop>', 'exec')
>>> exec(code)
>>> duplicate_list(1)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
duplicate_list(1)
~~~~~~~~~~~~~~^^^
File "<ast_workshop>", line 3, in duplicate_list
AssertionError: TODO: Add failure info
Can we convert this back into source code to save it?
With a small example like this, we can also use the ast.unparse() function to convert the modified AST back into Python source code:
>>> print(ast.unparse(tree))
def duplicate_list(x):
assert isinstance(x, list), 'TODO: Add failure info'
return x + x
The ast.unparse() function comes with some caveats:
- It's not recommended for larger trees since it can hit recursion limits.
- If we first convert source code to an AST, and then attempt to convert it back without any changes, the result will be equivalent, but not necessarily equal to the original.
Introduced in Python 3.14, the ast.compare() function compares ASTs recursively. Here, we can see an example of structurally-equivalent, but stylistically-different programs:
# check if the ASTs are equivalent
>>> ast.compare(
... ast.parse('a + (b * c)'),
... ast.parse('a + b * c'),
... )
True
# compare line numbers and column offsets as well
>>> ast.compare(
... ast.parse('a + (b * c)'),
... ast.parse('a + b * c'),
... compare_attributes=True,
... )
False
If we scale this example up to a slightly larger program with comments and more stylistic formatting, there will be even more differences:
import contextlib
def strip_password(
credentials: dict[str, str]
) -> None:
'''
Strip out the password from the credentials.
'''
# remove the password if it is there
with contextlib.suppress(KeyError):
del credentials["password"]
When we parse this into an AST and back again (a process called round-trip parsing), the resulting code is equivalent, but different:
import contextlib
def strip_password(credentials: dict[str, str]) -> None:
"""
Strip out the password from the credentials.
"""
with contextlib.suppress(KeyError):
del credentials['password']
There are no longer two blank lines after the import:
The function definition is now written entirely on one line:
The docstring now uses """ instead of ''':
There is no longer a blank line between the docstring and the code:
The comment has been removed:
Single quotes are used for keying into the dictionary:
import contextlib
-
- def strip_password(
+ def strip_password(credentials: dict[str, str]) -> None:
- credentials: dict[str, str]
- ) -> None:
- '''
+ """
Strip out the password from the credentials.
- '''
+ """
-
- # remove the password if it is there
with contextlib.suppress(KeyError):
- del credentials["password"]
+ del credentials['password']
Exercise 2
Use the ast.walk() function and the ast.get_docstring() function to traverse the AST for the greet.py snippet and report any items that are missing docstrings.
Example solution
Similar setup to the previous examples, except we also import contextlib:
Multiple node types can have docstrings, so we don't limit to a type here:
We try to access each node's docstring and suppress any TypeErrors:
If there isn't a docstring on a node that can have one, we report it:
We use getattr() here because ast.Module nodes don't have names:
The module, Greeter class, and the Greeter class's methods all lack docstrings:
>>> import ast
>>> import contextlib
>>> from pathlib import Path
>>>
>>> tree = ast.parse(Path('snippets/greet.py').read_text())
>>> for node in ast.walk(tree):
... with contextlib.suppress(TypeError):
... if not ast.get_docstring(node):
... print(getattr(node, 'name', 'module'))
module
Greeter
__init__
greet
The ast.walk() function yields the nodes in no specific order, so we don't have context beyond the node itself. In the case of the previous exercise, larger files can easily make the results confusing. Furthermore, we may want to flag missing docstrings on the __init__() method only if the class doesn't have one. For these use cases, we need the context provided by traversing the tree in a specific order.
Depth-first traversal
The ast module provides two classes that perform depth-first traversal of an AST:
- ast.NodeVisitor: visits nodes in an AST
- ast.NodeTransformer: a special version of the above that can also modify nodes
ast.NodeVisitor
When we subclass ast.NodeVisitor, we create visit_<NodeType>() methods for each AST node we want to visit, and the ast.NodeVisitor will take care of calling them as nodes of that type are encountered.
Suppose we want to check our code for the try/except/pass anti-pattern, like the following code from the try_except.py snippet:
def strip_password(x: dict[str, str]) -> None:
try:
del x['password']
except KeyError:
pass
Instead, we want to encourage the use of contextlib.suppress():
import contextlib
def strip_password(x: dict[str, str]) -> None:
with contextlib.suppress(KeyError):
del x['password']
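Both versions behave identically at runtime; contextlib.suppress() simply swallows the named exception. A quick check with a hypothetical dict:

```python
import contextlib

x = {'user': 'stef'}
with contextlib.suppress(KeyError):
    del x['password']  # the KeyError from the missing key is swallowed
print(x)  # {'user': 'stef'}
```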
We need to visit each ast.Try node and inspect its handlers – if there is only one handler and its body ends with an ast.Pass node then we will report it:
import ast
class TryExceptVisitor(ast.NodeVisitor):
def visit_Try(self, node):
if len(node.handlers) == 1 and isinstance(
node.handlers[0].body[-1], ast.Pass
):
print(
'try/except/pass block on line',
f'{node.lineno}, use contextlib.suppress',
)
examples/try_except_visitor_1.py
To use our visitor, we instantiate it and call its visit() method, passing in the AST to start the traversal:
>>> from pathlib import Path
>>>
>>> source_code = Path('snippets/try_except.py').read_text()
>>> tree = ast.parse(source_code)
>>> visitor = TryExceptVisitor()
>>> visitor.visit(tree)
try/except/pass block on line 3, use contextlib.suppress
We aren't done yet though. The visit_Try() method is currently cutting off the traversal to descendants of ast.Try nodes, meaning our visitor never visits nested try blocks (only the outermost one). Consider this example of nested try blocks from the try_except_nested.py snippet, where we want to detect the anti-pattern in the inner try:
def strip_password(x: dict[str, str]) -> None:
try:
print(f'Received dict with keys: {x.keys()}')
try:
del x['password']
except KeyError:
pass
except Exception as e:
raise TypeError('Invalid input, expected dict') from e
The TryExceptVisitor doesn't find anything with this input because it doesn't go any deeper after it visits the outermost try:
>>> source_code = Path(
... 'snippets/try_except_nested.py'
... ).read_text()
>>> tree = ast.parse(source_code)
>>> visitor = TryExceptVisitor()
>>> visitor.visit(tree)
Partial AST traversal of the
try_except_nested.py snippet with the initial TryExceptVisitor visualized with Graphviz.
The generic_visit() method
When we don't define a dedicated visit_<NodeType>() method for an AST node, the ast.NodeVisitor calls the generic_visit() method, which continues the traversal. The visit_Try() method we defined does not currently call generic_visit() on that node, so the traversal does not go any deeper.
We need to call generic_visit() ourselves. Note the indentation level – it is outside of the if because we want to visit all nodes, regardless of whether their ancestors had the issue we are looking for:
import ast
class TryExceptVisitor(ast.NodeVisitor):
def visit_Try(self, node):
if len(node.handlers) == 1 and isinstance(
node.handlers[0].body[-1], ast.Pass
):
print(
'try/except/pass block on line',
f'{node.lineno}, use contextlib.suppress',
)
self.generic_visit(node)
examples/try_except_visitor_2.py
The TryExceptVisitor now visits the innermost try and detects the issue:
>>> tree = ast.parse(source_code)
>>> visitor = TryExceptVisitor()
>>> visitor.visit(tree)
try/except/pass block on line 5, use contextlib.suppress
Full AST traversal of the
try_except_nested.py snippet visualized with Graphviz.
Exercise 3
Create a GenericExceptionVisitor class that detects both bare except blocks and the usage of generic Exceptions. Your visitor will need to visit both ast.Raise and ast.ExceptHandler nodes. You can test it using the source code in the generic_exception.py snippet, which has multiple variations of what we want to detect.
try:
del x['non_existent_key']
except: # bare except
raise Exception('No such key') # generic Exception
Bonus: If you have time, use the ast.get_source_segment() function to print any problematic code you detect.
Example solution
In addition to ast, we will also use textwrap for text formatting:
We start by inheriting from ast.NodeVisitor:
In order to use ast.get_source_segment(), we need the source code string:
_print_source_segment() will print the source code we reference:
ast.get_source_segment() needs the full source code and the AST node:
visit_Raise() defines our actions when we encounter ast.Raise nodes:
Here, we look for raise Exception:
or raise Exception(...):
If either is true, we print the issue, line number, and the code itself for reference:
Regardless of whether we found something, we make sure to continue the traversal:
We also need to visit ast.ExceptHandler nodes:
With a bare except, there is no exception type provided in node.type:
With a generic exception, the exception type provided is Exception:
Again, regardless of whether we found something, we continue the traversal:
We generated the AST in __init__(), so we create run() to call visit():
import ast
from textwrap import dedent, indent
class GenericExceptionVisitor(ast.NodeVisitor):
def __init__(self, source_code):
self.source_code = source_code
self.tree = ast.parse(source_code)
def _print_source_segment(self, node):
code_segment = ast.get_source_segment(
self.source_code, node, padded=True
)
print(indent(dedent(code_segment), '| '), end='\n\n')
def visit_Raise(self, node):
match node.exc:
case (
ast.Name(id='Exception')
| ast.Call(func=ast.Name(id='Exception'))
):
print(
'Generic Exception raised on line',
f'{node.lineno}:',
)
self._print_source_segment(node)
self.generic_visit(node)
def visit_ExceptHandler(self, node):
match node.type:
case None:
print(f'Bare except on line {node.lineno}:')
self._print_source_segment(node)
case ast.Name(id='Exception'):
print(
f'Generic Exception on line {node.lineno}:'
)
self._print_source_segment(node)
self.generic_visit(node)
def run(self):
self.visit(self.tree)
examples/generic_exception_visitor.py
Using the GenericExceptionVisitor is a little different. This time, we pass in the source code when we initialize it, and we call the run() method to kick off the traversal:
>>> from pathlib import Path
>>>
>>> source_code = Path(
... 'snippets/generic_exception.py'
... ).read_text()
>>> visitor = GenericExceptionVisitor(source_code)
>>> visitor.run()
Bare except on line 4:
| except:
| pass
Generic Exception on line 11:
| except Exception:
| pass
Generic Exception raised on line 17:
| raise Exception('Improper input format')
Generic Exception raised on line 22:
| raise Exception
Bare except on line 28:
| except:
| print('Shame on you!')
| raise
ast.NodeTransformer
The ast.NodeTransformer performs the traversal in the same way that the ast.NodeVisitor does, but it can modify the AST. So far, our visit_*() methods haven't returned anything (implicitly, they all returned None). However, with the ast.NodeTransformer, the return value modifies the AST:
- Returning None deletes the subtree rooted at that node (i.e., that node and all of its descendants)
- Returning transformed_node replaces the subtree rooted at the visited node with transformed_node (or keeps it if it wasn't modified)
Circling back to our try/except/pass detector, we can create an ast.NodeTransformer to rewrite that code to use contextlib.suppress() instead of just suggesting it:
def strip_password(x: dict[str, str]) -> None:
try:
del x['password']
except KeyError:
pass
↓
import contextlib
def strip_password(x: dict[str, str]) -> None:
with contextlib.suppress(KeyError):
del x['password']
Once again, we will use the ast module along with textwrap:
We start by inheriting from ast.NodeTransformer:
has_changed indicates whether we need to add an import for contextlib:
_get_suppress_block() will take an ast.Try node and convert it:
Rather than write the AST directly, we will write source code, parse it, then edit it:
The ast.With node we need is stored in the body of an ast.Module node:
We suppress() the exception from the except (when it's not a bare except):
The body of the with block will be the code that was in the body of the try:
Finally, we return this new node so that we can update the AST:
The ast.NodeTransformer will call our visit_Try() method during traversal:
First, we get the updates from all descendant nodes:
If we need to rewrite a try block, we will report it and start the AST update:
We call _get_suppress_block() to get the new node and track the change:
By returning the new node, the try block is replaced by the new with block:
Notice that we don't have to specify the indentation level – the AST handles this:
Finally, we create a run() method as the entry point:
We start by calling visit() to traverse the entire AST:
If any edits were made, we will add import contextlib to the top of the module:
We return the modified AST with all the location information required to compile:
import ast
from textwrap import dedent
class TryExceptTransformer(ast.NodeTransformer):
def __init__(self, source_code):
self.tree = ast.parse(source_code)
self.has_changed = False
def _get_suppress_block(self, node):
suppress_example = dedent("""
with contextlib.suppress(Exception):
pass
""")
with_block = ast.parse(suppress_example).body[0]
if exc_type := node.handlers[0].type:
with_block.items[0].context_expr.args = [exc_type]
with_block.body = node.body
return with_block
def visit_Try(self, node):
node = self.generic_visit(node)
if len(node.handlers) == 1 and isinstance(
node.handlers[0].body[-1], ast.Pass
):
print(
'Detected a try/except/pass block on',
f'line {node.lineno}, rewriting',
)
node = self._get_suppress_block(node)
self.has_changed = True
return node
def run(self):
self.tree = self.visit(self.tree)
if self.has_changed:
self.tree.body = [
ast.Import([ast.alias('contextlib')])
] + self.tree.body
return ast.fix_missing_locations(self.tree)
examples/try_except_transformer.py
We can use the TryExceptTransformer on the try_except.py snippet to generate the modified AST. Remember that using ast.unparse() may result in other changes to the code, such as the loss of comments and formatting:
>>> from pathlib import Path
>>>
>>> source_code = Path('snippets/try_except.py').read_text()
>>> transformer = TryExceptTransformer(source_code)
>>> updated_ast = transformer.run()
Detected a try/except/pass block on line 2, rewriting
>>> print(ast.unparse(updated_ast))
import contextlib
def strip_password(x: dict[str, str]) -> None:
with contextlib.suppress(KeyError):
del x['password']
Note that, in order to simplify the code, we didn't check if there was already an import of contextlib or even the suppress() function, but you could do that as well.
Exercise 4
Create an ast.NodeTransformer to add placeholder messages to all assert calls that don't have them. We did this earlier with ast.walk(), and this will look very similar. We want to visit the ast.Assert nodes and check for the presence of a message (available in the msg attribute). Don't forget to return the node after visiting it, or it will be removed from the tree.
Example solution
import ast
class AssertTransformer(ast.NodeTransformer):
def visit_Assert(self, node):
node = self.generic_visit(node)
if not node.msg:
node.msg = ast.Constant('TODO: Add failure info')
node = ast.fix_missing_locations(node)
return node
examples/assert_transformer.py
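To try the transformer out end-to-end (restating the class so the snippet runs standalone, with a hypothetical input function):

```python
import ast

class AssertTransformer(ast.NodeTransformer):
    def visit_Assert(self, node):
        node = self.generic_visit(node)
        if not node.msg:
            node.msg = ast.Constant('TODO: Add failure info')
        return ast.fix_missing_locations(node)

tree = ast.parse('def f(x):\n    assert x > 0\n    return x')
new_tree = AssertTransformer().visit(tree)
print(ast.unparse(new_tree))
```

This prints the function back out with the placeholder message injected into the assert.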
Building an Import Linter
Managing context during traversal
Up until this point, we were only concerned with a node and its immediate children, but sometimes we need to understand a node's ancestry (i.e., parents, grandparents, etc.). Since nodes don't hold references to their parents, this isn't possible without some extra accounting on our end. One such case is taking into account what is in scope (variables, imports, etc.) when processing nodes.
In this section, we will learn how to track this information as we build an import linter capable of flagging unused imports, along with masked and missing names.
Finding all imports
Let's start by finding all the imports in a module with a new ImportVisitor:
We start by importing ast and inheriting from ast.NodeVisitor:
We will track imports in the imports_available list:
We will be visiting both ast.Import and ast.ImportFrom nodes:
In both cases, we will track the import in the same way, so we add a helper method:
Each import statement can import multiple names, so we loop over them:
We will ignore any from x import * imports as a simplification:
For each named import, we track its name and alias (if any), along with its module:
We use getattr() here because this is only for ast.ImportFrom nodes:
We add the imports extracted from the node to imports_available:
As we have seen before, we call generic_visit() to continue the traversal:
Last, we have the initial version of our run() method, which just calls visit():
import ast
class ImportVisitor(ast.NodeVisitor):
def __init__(self, source_code):
self.source_code = source_code
self.tree = ast.parse(source_code)
self.imports_available = []
def _visit_import(self, node):
self.imports_available.extend(
[
{
'import': alias.name,
'from': getattr(node, 'module', None),
'alias': alias.asname,
}
for alias in node.names
if alias.name != '*'
]
)
self.generic_visit(node)
def visit_Import(self, node):
self._visit_import(node)
def visit_ImportFrom(self, node):
self._visit_import(node)
def run(self):
self.visit(self.tree)
checkpoints/initial.py
Let's try this out on the imports.py snippet, which has one import of each case we need to handle. Notice we have module-level imports and imports inside functions:
import json
from contextlib import suppress
def strip_password(x):
with suppress(KeyError):
del x['password']
def dump_info(x, out):
json.dump(strip_password(x), out)
def analyze_something(x):
import pandas as pd
df = pd.DataFrame(x)
Our ImportVisitor finds each of the imports:
>>> from pathlib import Path
>>> source_code = Path('snippets/imports.py').read_text()
>>> visitor = ImportVisitor(source_code)
>>> visitor.run()
>>> print(visitor.imports_available)
[{'import': 'json', 'from': None, 'alias': None},
{'import': 'suppress', 'from': 'contextlib', 'alias': None},
{'import': 'pandas', 'from': None, 'alias': 'pd'}]
Tracking import scope
In order to flag missing names and unused imports, we need to know which imports are available to us. However, right now, we don't have the full story – we need to account for their scope. For example, in the snippet below, the import of json on line 2 is only available within the scope of the get_data() function, which is narrower than the module scope, in which we call json.dump() on line 8:
def get_data():
import json
return json.loads('{"key": "value"}')
data = get_data()
# this results in a NameError
json.dump(data, 'data.json')
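We can verify this by compiling and executing the snippet's AST (a sketch; the code never reaches the file write because the module-level name lookup fails first):

```python
import ast
from textwrap import dedent

source = dedent('''
    def get_data():
        import json
        return json.loads('{"key": "value"}')

    data = get_data()
    json.dump(data, 'data.json')
''')
code = compile(ast.parse(source), '<scope_demo>', 'exec')
try:
    exec(code, {})
except NameError as error:
    print(type(error).__name__)  # NameError: json is not in the module scope
```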
Using a stack to track scope
Let's update our ImportVisitor to track import scope:
Our __init__() method now initializes a stack for tracking the ancestry:
Each time we visit an import node, we will now record the scope:
A node's scope is the path to it from the root of the tree:
We will override the generic_visit() method to keep the stack up-to-date:
Each time we visit a node that has a body attribute, our scope changes:
We call the superclass's generic_visit() method to continue the traversal:
Afterward, we remove the node from the stack (remember, this is depth first):
import ast
class ImportVisitor(ast.NodeVisitor):
def __init__(self, source_code):
self.source_code = source_code
self.tree = ast.parse(source_code)
self.imports_available = []
self.stack = []
def _visit_import(self, node):
import_scope = '.'.join(self.stack)
self.imports_available.extend(
[
{
'scope': import_scope,
'import': alias.name,
'from': getattr(node, 'module', None),
'alias': alias.asname,
}
for alias in node.names
if alias.name != '*'
]
)
self.generic_visit(node)
def visit_Import(self, node):
self._visit_import(node)
def visit_ImportFrom(self, node):
self._visit_import(node)
def generic_visit(self, node):
if hasattr(node, 'body'):
self.stack.append(getattr(node, 'name', 'module'))
super().generic_visit(node)
if hasattr(node, 'body'):
self.stack.pop()
def run(self):
self.visit(self.tree)
checkpoints/stack.py
Now the ImportVisitor includes the scope in which each of the imports can be used, and we are one step closer to detecting missing names and unused imports:
>>> visitor = ImportVisitor(source_code)
>>> visitor.run()
>>> print(visitor.imports_available)
[{'scope': 'module',
'import': 'json', 'from': None, 'alias': None},
{'scope': 'module',
'import': 'suppress', 'from': 'contextlib', 'alias': None},
{'scope': 'module.analyze_something',
'import': 'pandas', 'from': None, 'alias': 'pd'}]
What's currently in scope?
As we explore the AST, we need to be able to determine which names (e.g., imports, variables, function definitions) are in scope. A name is in scope if its definition scope matches the current scope or is a prefix of it:
| definition scope | current scope | is in scope? |
|---|---|---|
| module | module | True |
| module | module.function | True |
| module | module.function.with | True |
| module.function | module | False |
| module.function | module.function.with | True |
Exercise 5
Starting from checkpoints/exercise_5.py, write the following methods for the ImportVisitor class to add the functionality to determine which imports are in scope given the current state of the stack during traversal:
- _is_in_scope(self, definition_scope: str) -> bool, which given an import's scope (definition_scope) will return whether it is currently in scope (using the stack)
- get_in_scope_import(self, name: str) -> dict | None, which will filter the imports_available list down to the import of name that is currently in scope by calling _is_in_scope() and breaking ties by selecting the narrowest scope (e.g., module.x is narrower than module)
Note: We will be using _is_in_scope() later for checking whether other names are in scope, so don't pass in the full import dictionary inside imports_available.
Example solution
The definition_scope is in scope if the stack starts with that path. We slice the stack and compare the lists for equality instead of comparing strings to avoid any false-positives (e.g., module.a is a substring of module.abc, but they have different scopes):
def _is_in_scope(self, definition_scope: str) -> bool:
check_scope = definition_scope.split('.')
return self.stack[: len(check_scope)] == check_scope
The get_in_scope_import() method uses _is_in_scope() to filter imports:
- First, we grab all imports of name that are in scope
- For aliased imports, we only compare name to that alias
- If nothing is in scope, we return None
- Otherwise, we take the narrowest scope (more dots means deeper in the stack)
def get_in_scope_import(self, name: str) -> dict | None:
scoped_imports = [
import_info
for import_info in self.imports_available
if self._is_in_scope(import_info['scope'])
and name
== (import_info['alias'] or import_info['import'])
]
if not scoped_imports:
return None
return max(
scoped_imports, key=lambda x: x['scope'].count('.')
)
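To see the tie-breaking logic in isolation, consider two in-scope imports of the same name (the entries below are made up for illustration; `module.analyze` stands in for a function that re-imports `pandas` locally):

```python
# hypothetical imports_available entries: pandas imported at module level
# and again inside a function named analyze
imports_available = [
    {'scope': 'module', 'import': 'pandas', 'from': None, 'alias': 'pd'},
    {'scope': 'module.analyze', 'import': 'pandas', 'from': None, 'alias': 'pd'},
]

# when both are in scope, the narrowest scope (most dots) wins
narrowest = max(imports_available, key=lambda x: x['scope'].count('.'))
print(narrowest['scope'])  # → module.analyze
```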
Tracking name definitions
As alluded to before, in order to see if an import is missing, we also need to track all the names used and the scopes in which they were defined. Imports, class definitions, function definitions, function arguments, and variables are all names. Think about what happens if you try to run ast.parse() without first running import ast; you get a NameError:
>>> ast.parse('x = 1')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
    ast.parse('x = 1')
    ^^^
NameError: name 'ast' is not defined. Did you forget to import 'ast'?
We will use a defaultdict to track names, where the key is the name we find in the source code we are processing, and the value is a list of dictionaries that each contain the scope upon declaration, the type of name it is (e.g., builtin or import), and the line number (None for builtins, like dict, sum, or KeyError):
import ast
import builtins
from collections import defaultdict
class ImportVisitor(ast.NodeVisitor):
def __init__(self, source_code: str) -> None:
self.source_code = source_code
self.tree = ast.parse(source_code)
self.stack = []
self.imports_available = []
self.names_defined = defaultdict(list)
for builtin in builtins.__dict__.keys():
self.names_defined[builtin].append(
{
'scope': 'module',
'type': 'builtin',
'line_number': None,
}
)
...
checkpoints/exercise_6.py
Exercise 6
Starting from checkpoints/exercise_6.py, update the ImportVisitor to include name tracking for imports (ast.Import and ast.ImportFrom), class definitions (ast.ClassDef), function definitions (ast.FunctionDef and ast.AsyncFunctionDef), function arguments (ast.arg), and variables (visit assignments by visiting ast.Name when ctx is of type ast.Store). Note that we will be ignoring the ast.Del context on ast.Name nodes to keep things simple.
Bonus: If you have time, print out a warning whenever a name is redefined within a given scope, for example:
>>> dict = {} # this masks the builtin dict()
>>> my_dict = dict(x=1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
    my_dict = dict(x=1)
TypeError: 'dict' object is not callable
Example solution
Let's update our ImportVisitor to track names and report name masking:
- The _track_name_definition() method updates names_defined and makes a call to _flag_if_masked() each time (we will come back to this)
- We are already processing imports, but we need to track the names now; if the import was aliased, that will be the name, otherwise it is the import name
- Here, we track names from variable assignments (note the ast.Store context)
- Rather than writing several similar visit_*() methods, we override visit(); this replaces the visit_Import() and visit_ImportFrom() methods
- Here, we handle class and function declarations, as well as function arguments; the only difference is how we extract the name itself
- For all other nodes, we preserve the superclass's behavior
- Coming back to flagging masked names upon redefinition: for it to be a redefinition, we need at least two occurrences of the name
- If there is indeed a redefinition, we store the latest one since this is depth-first, and every other definition that is still in scope needs to be flagged as masked
- Builtins aren't explicitly defined, so we don't have a line number
- Print the warning, e.g., builtin dict is masked by the Name of the same name on line 1
import ast
import builtins
from collections import defaultdict
class ImportVisitor(ast.NodeVisitor):
...
def _flag_if_masked(self, name):
if len(definitions := self.names_defined[name]) < 2:
return
latest_def = definitions[-1]
# mark all others still in scope as masked
for older_def in definitions[:-1]:
if self._is_in_scope(older_def['scope']):
other_line_number = (
f' on line {older_def["line_number"]}'
if older_def['line_number'] is not None
else ''
) # empty for builtins only
print(
f'{older_def["type"]} {name}{other_line_number}',
f'is masked by the {latest_def["type"]}',
'of the same name',
f'on line {latest_def["line_number"]}',
)
def _track_name_definition(self, node, name):
self.names_defined[name].append(
{
'scope': '.'.join(self.stack),
'type': node.__class__.__name__,
'line_number': node.lineno,
}
)
self._flag_if_masked(name)
def _visit_import(self, node):
import_scope = '.'.join(self.stack)
self.imports_available.extend(
[
{
'scope': import_scope,
'import': alias.name,
'from': getattr(node, 'module', None),
'alias': alias.asname,
}
for alias in node.names
if alias.name != '*'
]
)
for alias in node.names:
self._track_name_definition(
node, alias.asname or alias.name
)
self.generic_visit(node)
def visit_Name(self, node):
if isinstance(node.ctx, ast.Store):
self._track_name_definition(node, node.id)
self.generic_visit(node)
def generic_visit(self, node):
if hasattr(node, 'body'):
# we have entered a new scope
self.stack.append(getattr(node, 'name', 'module'))
super().generic_visit(node)
if hasattr(node, 'body'):
self.stack.pop()
def visit(self, node):
if isinstance(node, ast.Import | ast.ImportFrom):
self._visit_import(node)
elif isinstance(
node,
ast.ClassDef
| ast.FunctionDef
| ast.AsyncFunctionDef
| ast.arg,
):
self._track_name_definition(
node,
node.arg
if isinstance(node, ast.arg)
else node.name,
)
self.generic_visit(node)
else:
super().visit(node)
...
checkpoints/exercise_7.py
Exercise 7
We are now ready to detect missing name definitions and unused imports. Missing name definitions can be detected during the AST traversal, but unused imports will have to be checked at the end (after we have counted the number of times each import is used). Starting from checkpoints/exercise_7.py, make the following changes to the ImportVisitor to add this functionality:
- Update `visit_Name()` to handle the `ast.Load` context. Here, you should flag missing name definitions.
- Track the number of times an import is accessed (not defined), and use this information to flag unused imports after the traversal has finished.
Example solution
Let's go over the changes to the ImportVisitor:
- The first change is in the _visit_import() method: we now track the import's line number and the number of times it was accessed
- In the visit_Name() method, we now handle the ast.Load context, which means we accessed the name
- First, we check if the name is in scope (remember, this includes imports); if not, we report that the name is missing
- Otherwise, we grab the narrowest in-scope import of that name (None, if it's not an import); if there is one, we increment the number of times it was accessed
- We handle flagging unused imports in the run() method: first, we perform the full traversal of the AST, and then we check which imports haven't been accessed and flag them
import ast
import builtins
from collections import defaultdict
class ImportVisitor(ast.NodeVisitor):
...
def _visit_import(self, node):
import_scope = '.'.join(self.stack)
self.imports_available.extend(
[
{
'scope': import_scope,
'import': alias.name,
'from': getattr(node, 'module', None),
'alias': alias.asname,
'times_accessed': 0,
'line_number': node.lineno,
}
for alias in node.names
if alias.name != '*'
]
)
for alias in node.names:
self._track_name_definition(
node, alias.asname or alias.name
)
self.generic_visit(node)
def visit_Name(self, node):
if isinstance(node.ctx, ast.Store):
self._track_name_definition(node, node.id)
elif isinstance(node.ctx, ast.Load):
if not (
any(
self._is_in_scope(name_info['scope'])
for name_info
in self.names_defined[node.id]
)
):
print(
f'Missing definition for {node.id}',
f'on line {node.lineno}'
)
elif import_of_name := self.get_in_scope_import(
node.id
):
import_of_name['times_accessed'] += 1
self.generic_visit(node)
...
def run(self):
self.visit(self.tree)
for import_info in self.imports_available:
if not import_info['times_accessed']:
print(
f'Unused import {import_info["import"]}',
f'on line {import_info["line_number"]}',
)
checkpoints/final.py
Potential next steps
As far as this workshop is concerned, we are done with the ImportVisitor, but, if you would like more practice, it can be extended further. You can use the version found in checkpoints/final.py as a starting point for further enhancements:
- Account for deleting names with `del` (this is the `ast.Del` context)
- Handle `from x import *`
- Suggest imports when names are missing, like Python does for certain `NameError`s
- Remove the unused imports by converting the visitor to an `ast.NodeTransformer`
- Flag and remove duplicate imports
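One of these extensions, removing unused imports, might be sketched with an `ast.NodeTransformer` as follows (this is a plain-`import`, module-scope sketch, not a full solution; the class name `UnusedImportRemover` is made up for illustration):

```python
import ast

class UnusedImportRemover(ast.NodeTransformer):
    """Drop plain imports whose bound names are never loaded."""

    def __init__(self, used_names: set[str]):
        self.used_names = used_names

    def visit_Import(self, node):
        # keep only the aliases whose bound name is actually used
        node.names = [
            alias for alias in node.names
            if (alias.asname or alias.name) in self.used_names
        ]
        # returning None removes the whole statement when nothing survives
        return node if node.names else None

tree = ast.parse('import json\nimport os\nprint(json.dumps({}))')
used = {
    n.id for n in ast.walk(tree)
    if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load)
}
new_tree = ast.fix_missing_locations(UnusedImportRemover(used).visit(tree))
print(ast.unparse(new_tree))  # 'import os' is gone
```

A complete version would also override visit_ImportFrom(), track scopes like the ImportVisitor does, and use the times_accessed counts instead of a flat set of loaded names.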
Related content
My PyCon Lithuania 2025 keynote "Build Your Own (Simple) Static Code Analyzer" is another introduction to using ASTs, but it focuses on generating docstrings from type annotations and signatures.
Thank you!
I hope you enjoyed the session. You can follow my work on these platforms: