Supporting Dynamo in Python 3.11 - NULL

This is the start of a series covering some of the more difficult implementation details encountered while supporting Dynamo in Python 3.11. This post covers challenges resulting from changes in CPython revolving around the usage of NULL in the function call sequence.

(NOTE: we have bytecode examples in this post for illustration purposes only - they may not perfectly depict what is generated today.)

TL;DR

CPython made changes to the function call sequence at the bytecode level. In particular, most function calls now require a C NULL to be present on the stack below the callable. NULL can be pushed onto the stack at the beginning of the function call sequence (before the function is loaded), or at the end (before the function is called). Dynamo supports this new requirement through the use of additional push_null arguments in some codegen functions.

If you are developing Dynamo and need to touch codegen, you need to be aware of this new push_null behavior. In most cases, you should push NULL immediately before loading a function you wish to call in the future, either by creating a PUSH_NULL instruction directly (make sure to check that the runtime Python version is >= 3.11) or by using the push_null argument in the codegen function that generates the loading instruction (if this argument exists). If this is not possible, then you can push NULL immediately before calling the function using the push_null argument in create_call_function.

Background

What is CPython bytecode?


CPython implements a stack machine, operated on by CPython bytecode, which is compiled from Python source.

Consider the following function:

def foo(x, y):
    z = x + y
    print(z)
    return z

This function compiles to the following bytecode in Python 3.11.3 using dis.dis:

  1           0 RESUME                   0

  2           2 LOAD_FAST                0 (x)
              4 LOAD_FAST                1 (y)
              6 BINARY_OP                0 (+)
             10 STORE_FAST               2 (z)

  3          12 LOAD_GLOBAL              1 (NULL + print)
             24 LOAD_FAST                2 (z)
             26 PRECALL                  1
             30 CALL                     1
             40 POP_TOP

  4          42 LOAD_FAST                2 (z)
             44 RETURN_VALUE
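To reproduce a listing like this on your own interpreter, the standard dis module is enough. The sketch below also flags which instructions push a C NULL. It assumes the 3.11 rules covered later in this post (PUSH_NULL, and LOAD_GLOBAL with the low bit of its oparg set). The helper null_pushing_instructions is a hypothetical name, and the exact listing you get will differ between CPython versions.

```python
import dis
import sys


def foo(x, y):
    z = x + y
    print(z)
    return z


def null_pushing_instructions(func):
    """Return (offset, opname, argrepr) for each instruction that pushes a C NULL.

    Assumes the CPython 3.11+ rules: PUSH_NULL always pushes NULL, and
    LOAD_GLOBAL pushes one when the low bit of its oparg is set.
    Before 3.11 neither rule applies, so the result is empty.
    """
    found = []
    if sys.version_info < (3, 11):
        return found
    for inst in dis.get_instructions(func):
        pushes_null = inst.opname == "PUSH_NULL" or (
            inst.opname == "LOAD_GLOBAL" and inst.arg & 1
        )
        if pushes_null:
            found.append((inst.offset, inst.opname, inst.argrepr))
    return found


if __name__ == "__main__":
    dis.dis(foo)  # the exact listing varies across CPython versions
    print(null_pushing_instructions(foo))
```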

Here’s a brief description of what the instructions do:

# Push locals `x` and `y` to the stack
  2           2 LOAD_FAST                0 (x)
              4 LOAD_FAST                1 (y)
# Pop the top two values of the stack (`x`, `y`), add them, and push the result (`x+y`)
              6 BINARY_OP                0 (+)
# Pop the top value of the stack (`x+y`) and store it into local `z`.
             10 STORE_FAST               2 (z)
# Push NULL, then the global `print`, then the local `z`.
  3          12 LOAD_GLOBAL              1 (NULL + print)
             24 LOAD_FAST                2 (z)
# Pop 2 + 1 values, make a function call with 1 argument (i.e. `print(z)`) and push the result (`None`).
             26 PRECALL                  1
             30 CALL                     1
# Pop the top of the stack, since we don't use it.
             40 POP_TOP
# Push the local `z` and end the function, returning the value at the top of the stack (`z`).
  4          42 LOAD_FAST                2 (z)
             44 RETURN_VALUE
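The same walk-through can be replayed on a toy stack machine. This is a minimal sketch, not CPython's evaluation loop: NULL is modelled by a sentinel object, instructions are hypothetical (opname, arg) pairs, PRECALL is omitted, CALL only models the NULL-below-callable layout, and only the handful of opcodes used by foo are supported.

```python
NULL = object()  # stand-in for the C NULL; never a real Python value


def run(code, local_vars, global_vars):
    """Execute (opname, arg) pairs on a toy value stack and return the result."""
    stack = []
    for opname, arg in code:
        if opname == "LOAD_FAST":
            stack.append(local_vars[arg])
        elif opname == "STORE_FAST":
            local_vars[arg] = stack.pop()
        elif opname == "BINARY_OP":  # only "+" is modelled
            rhs = stack.pop()
            lhs = stack.pop()
            stack.append(lhs + rhs)
        elif opname == "LOAD_GLOBAL":  # arg is (name, push_null)
            name, push_null = arg
            if push_null:
                stack.append(NULL)
            stack.append(global_vars[name])
        elif opname == "CALL":  # arg is the number of arguments
            args = stack[len(stack) - arg:]
            del stack[len(stack) - arg:]
            func = stack.pop()
            if stack.pop() is not NULL:
                raise RuntimeError("expected NULL below the callable")
            stack.append(func(*args))
        elif opname == "POP_TOP":
            if stack[-1] is NULL:
                raise RuntimeError("POP_TOP on NULL would crash CPython")
            stack.pop()
        elif opname == "RETURN_VALUE":
            return stack.pop()
        else:
            raise NotImplementedError(opname)


# foo(x, y) from above, in the 3.11 shape (PRECALL omitted)
FOO = [
    ("LOAD_FAST", "x"),
    ("LOAD_FAST", "y"),
    ("BINARY_OP", "+"),
    ("STORE_FAST", "z"),
    ("LOAD_GLOBAL", ("print", True)),  # NULL + print
    ("LOAD_FAST", "z"),
    ("CALL", 1),
    ("POP_TOP", None),
    ("LOAD_FAST", "z"),
    ("RETURN_VALUE", None),
]

if __name__ == "__main__":
    # the simulated print(z) prints 3, then the returned 3 is printed
    print(run(FOO, {"x": 1, "y": 2}, {"print": print}))
```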

How does Dynamo use CPython bytecode?

Dynamo relies on bytecode for:

  1. Extracting the computation graph of torch ops
  2. Generating modified bytecode
  3. Generating continuation functions

Extracting the computation graph

Dynamo extracts computation graphs (represented as FX graphs) from PyTorch programs by analyzing Python bytecode, which it intercepts through the frame evaluation API introduced by PEP 523. When a function is called, this API passes the function’s code object (bytecode, variable names, function name, etc.) and context (arguments, locals, globals, closure variables, etc.) to Dynamo. Because Dynamo is given both code and context, it can effectively simulate running the function. Through the use of FakeTensor (objects that behave like torch.Tensor but do not contain data) and Variables (which represent values on the CPython stack), Dynamo can keep track of variable types throughout execution and detect which PyTorch operations are called, without actually executing any of the expensive operations. The file symbolic_convert.py defines the bulk of Dynamo’s bytecode simulator. Note in particular the methods of InstructionTranslatorBase and its derived classes that are named after CPython bytecode opnames (e.g. LOAD_FAST, STORE_FAST). These methods define how Dynamo simulates each bytecode operation. Dynamo’s just-in-time bytecode analysis differs from the approach of TorchScript, which analyzes the abstract syntax tree before execution.
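A heavily simplified sketch of that design is shown below. ToyTranslator is hypothetical and far smaller than InstructionTranslatorBase: each opname is handled by a method of the same name, values are opaque Symbol placeholders (loosely playing the role FakeTensor plays in Dynamo), and arithmetic is recorded into a graph list instead of being executed. Unknown opnames raise an error, which is roughly the point where a real tracer would have to consider a graph break.

```python
from collections import namedtuple

Instruction = namedtuple("Instruction", ["opname", "argval"])


class Symbol:
    """Opaque placeholder for a value whose contents are never computed."""

    def __init__(self, name):
        self.name = name


class ToyTranslator:
    def __init__(self, instructions, arg_names):
        self.instructions = instructions
        self.locals = {name: Symbol(name) for name in arg_names}
        self.stack = []
        self.graph = []  # recorded ops: (output name, op, input names)
        self.output = None

    def run(self):
        for inst in self.instructions:
            handler = getattr(self, inst.opname, None)  # dispatch on opname
            if handler is None:
                raise NotImplementedError(f"unsupported opcode {inst.opname}")
            handler(inst)
        return self.graph, self.output.name

    def LOAD_FAST(self, inst):
        self.stack.append(self.locals[inst.argval])

    def STORE_FAST(self, inst):
        self.locals[inst.argval] = self.stack.pop()

    def BINARY_OP(self, inst):
        rhs = self.stack.pop()
        lhs = self.stack.pop()
        out = Symbol(f"t{len(self.graph)}")
        # record the operation instead of computing it
        self.graph.append((out.name, inst.argval, (lhs.name, rhs.name)))
        self.stack.append(out)

    def RETURN_VALUE(self, inst):
        self.output = self.stack.pop()


# z = x + y; return z + x
PROGRAM = [
    Instruction("LOAD_FAST", "x"),
    Instruction("LOAD_FAST", "y"),
    Instruction("BINARY_OP", "+"),
    Instruction("STORE_FAST", "z"),
    Instruction("LOAD_FAST", "z"),
    Instruction("LOAD_FAST", "x"),
    Instruction("BINARY_OP", "+"),
    Instruction("RETURN_VALUE", None),
]

if __name__ == "__main__":
    graph, out = ToyTranslator(PROGRAM, ["x", "y"]).run()
    print(graph, out)
```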

Generating modified bytecode

How are Dynamo-extracted FX (computation) graphs actually executed? First, the FX graph is passed to a backend compiler, which returns a Python callable bound to a compiled (C++/Triton/etc.) kernel. Dynamo then generates modified bytecode that calls this callable. Finally, Dynamo returns a new code object with the modified bytecode, which the CPython interpreter then executes normally.

As a concrete example, consider the following code:

import torch

@torch.compile(backend="eager")
def f(x, y):
    z = x + y
    return torch.relu(z)

f(torch.randn(3, 3), torch.randn(3, 3))

Running with TORCH_LOGS="bytecode" shows the modified bytecode that Dynamo gives CPython to execute:

  3           0 RESUME                   0
              2 PUSH_NULL
              4 LOAD_GLOBAL              4 (__compiled_fn_0)
             16 LOAD_FAST                0 (x)
             18 LOAD_FAST                1 (y)
             20 PRECALL                  2
             24 CALL                     2
             34 UNPACK_SEQUENCE          1
             38 RETURN_VALUE

where __compiled_fn_0 is the callable returned by the backend compiler.
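At the Python source level, that bytecode corresponds roughly to the function below. This is only an analogy: compiled_fn stands in for __compiled_fn_0, and fake_compiled_fn is a hypothetical backend output that works on plain numbers (with max(z, 0) standing in for torch.relu) so the sketch runs without PyTorch. The single-element unpack mirrors UNPACK_SEQUENCE 1: the compiled callable returns a sequence of outputs even when there is only one.

```python
def make_modified_f(compiled_fn):
    """Build a Python-level analogue of Dynamo's modified bytecode for f."""

    def f(x, y):
        (out,) = compiled_fn(x, y)  # like UNPACK_SEQUENCE 1
        return out

    return f


def fake_compiled_fn(x, y):
    """Hypothetical backend output for f: relu(x + y), returned as a 1-tuple."""
    z = x + y
    return (max(z, 0),)


if __name__ == "__main__":
    f = make_modified_f(fake_compiled_fn)
    print(f(1, 2), f(-5, 2))  # 3 0
```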

We will see shortly that this level of indirection is necessary to support important Dynamo features.

Generating continuation functions

When Dynamo encounters an unsupported operation during bytecode simulation, such as a data-dependent conditional jump (e.g. POP_JUMP_FORWARD_IF_TRUE when the top of stack value is a tensor) or a call to a built-in function that causes side effects (e.g. print), it performs what is known as a graph break: Dynamo compiles and clears the accumulated torch operations (“breaking the computation graph”), runs the unsupported operation, then continues analyzing the remaining bytecode.

A continuation function is a function that can begin at any arbitrary instruction of the original function. Continuation functions are relevant in the implementation of graph breaks. In particular, a continuation function

  • restores the stack, variables, and contexts to the state following the instruction we graph broke on, and
  • jumps to the target instruction in the original function that we should continue from.

Any values that the continuation function needs in order to restore execution state are provided to it through function arguments.

Consider the following example:

import torch

v = 0
w = 0

@torch.compile(backend="eager")
def f(x, y):
    global v, w
    v = 1
    with torch.no_grad():
        z = x + y
        w = 2
        a = torch.sin(
            print("hello") or z
        )
        return torch.relu(a)

f(torch.randn(3, 3), torch.randn(3, 3))

The original bytecode of f is:

  6           0 RESUME                   0

  9           2 LOAD_CONST               1 (1)
              4 STORE_GLOBAL             0 (v)

 10           6 LOAD_GLOBAL              3 (NULL + torch)
             18 LOAD_ATTR                2 (no_grad)
             28 PRECALL                  0
             32 CALL                     0
             42 BEFORE_WITH
             44 POP_TOP

 11          46 LOAD_FAST                0 (x)
             48 LOAD_FAST                1 (y)
             50 BINARY_OP                0 (+)
             54 STORE_FAST               2 (z)

 12          56 LOAD_CONST               2 (2)
             58 STORE_GLOBAL             3 (w)

 13          60 LOAD_GLOBAL              3 (NULL + torch)
             72 LOAD_ATTR                4 (sin)

 14          82 LOAD_GLOBAL             11 (NULL + print)
             94 LOAD_CONST               3 ('hello')
             96 PRECALL                  1
            100 CALL                     1
            110 JUMP_IF_TRUE_OR_POP      1 (to 114)
            112 LOAD_FAST                2 (z)

 13     >>  114 PRECALL                  1
            118 CALL                     1
            128 STORE_FAST               3 (a)

 16         130 LOAD_GLOBAL              3 (NULL + torch)
            142 LOAD_ATTR                6 (relu)
            152 LOAD_FAST                3 (a)
            154 PRECALL                  1
            158 CALL                     1

 10         168 SWAP                     2
            170 LOAD_CONST               0 (None)
            172 LOAD_CONST               0 (None)
            174 LOAD_CONST               0 (None)
            176 PRECALL                  2
            180 CALL                     2
            190 POP_TOP
            192 RETURN_VALUE
        >>  194 PUSH_EXC_INFO
            196 WITH_EXCEPT_START
            198 POP_JUMP_FORWARD_IF_TRUE     4 (to 208)
            200 RERAISE                  2
        >>  202 COPY                     3
            204 POP_EXCEPT
            206 RERAISE                  1
        >>  208 POP_TOP
            210 POP_EXCEPT
            212 POP_TOP
            214 POP_TOP
            216 LOAD_CONST               0 (None)
            218 RETURN_VALUE

The co_varnames (names of arguments and locals) of the original function are ('x', 'y', 'z', 'a'). co_argcount is 2, signifying that the argument names are x and y.

A graph break occurs at offset 100, since we attempt to call print. At the time of the graph break, the stack contains:

  • (top) The string 'hello'
  • The function print
  • NULL
  • The function torch.sin
  • NULL
  • The exit function for the with torch.no_grad() statement.

When we run the unsupported operation, the top 3 values will be popped and the result of print("hello"), which is None, will be pushed to the top of the stack.

The (annotated) bytecode of the continuation function is:

 14           0 RESUME                   0
# Restore with torch.no_grad():
# Result is that the exit function for the context will be on the stack
              2 LOAD_FAST                0 (___stack0)
              4 LOAD_CONST               4 (False)
              6 PUSH_NULL
              8 SWAP                     3
             10 SWAP                     2
             12 PRECALL                  1
             16 CALL                     1
             26 BEFORE_WITH
             28 POP_TOP
# Push the NULL that is 2nd from the bottom of the stack (below torch.sin)
             30 PUSH_NULL
# Push torch.sin
             32 LOAD_FAST                1 (___stack1)
# Push None, the result of print("hello")
             34 LOAD_FAST                2 (___stack2)
# Jump to the instruction immediately following the instruction we graph broke on (offset 202)
             36 JUMP_FORWARD            82 (to 202)
# Handle cleanup for with torch.no_grad()
             38 NOP
             40 LOAD_CONST               0 (None)
             42 LOAD_CONST               0 (None)
             44 LOAD_CONST               0 (None)
             46 PRECALL                  2
             50 CALL                     2
             60 POP_TOP
             62 JUMP_FORWARD            11 (to 86)
        >>   64 PUSH_EXC_INFO
             66 WITH_EXCEPT_START
             68 POP_JUMP_FORWARD_IF_TRUE     4 (to 78)
             70 RERAISE                  2
        >>   72 COPY                     3
             74 POP_EXCEPT
             76 RERAISE                  1
        >>   78 POP_TOP
             80 POP_EXCEPT
             82 POP_TOP
             84 POP_TOP
        >>   86 NOP
             88 LOAD_CONST               0 (None)
             90 RAISE_VARARGS            1
# The bytecode below is the original function's bytecode
             92 RESUME                   0
             94 LOAD_CONST               1 (1)
             96 STORE_GLOBAL             0 (v)
             98 LOAD_GLOBAL              3 (NULL + torch)
            110 LOAD_ATTR                2 (no_grad)
            120 PRECALL                  0
            124 CALL                     0
            134 BEFORE_WITH
            136 POP_TOP
            138 LOAD_FAST                4 (x)
            140 LOAD_FAST                5 (y)
            142 BINARY_OP                0 (+)
            146 STORE_FAST               3 (z)
            148 LOAD_CONST               2 (2)
            150 STORE_GLOBAL             3 (w)
            152 LOAD_GLOBAL              3 (NULL + torch)
            164 LOAD_ATTR                4 (sin)
            174 LOAD_GLOBAL             11 (NULL + print)
            186 LOAD_CONST               3 ('hello')
            188 PRECALL                  1
            192 CALL                     1
        >>  202 JUMP_IF_TRUE_OR_POP      1 (to 206)
            204 LOAD_FAST                3 (z)

 13     >>  206 PRECALL                  1
            210 CALL                     1
            220 STORE_FAST               6 (a)

 16         222 LOAD_GLOBAL              3 (NULL + torch)
            234 LOAD_ATTR                6 (relu)
            244 LOAD_FAST                6 (a)
            246 PRECALL                  1
            250 CALL                     1

 10         260 SWAP                     2
            262 LOAD_CONST               0 (None)
            264 LOAD_CONST               0 (None)
            266 LOAD_CONST               0 (None)
            268 PRECALL                  2
            272 CALL                     2
            282 POP_TOP
            284 RETURN_VALUE
            286 PUSH_EXC_INFO
            288 WITH_EXCEPT_START
            290 POP_JUMP_FORWARD_IF_TRUE     4 (to 300)
            292 RERAISE                  2
        >>  294 COPY                     3
            296 POP_EXCEPT
            298 RERAISE                  1
        >>  300 POP_TOP
            302 POP_EXCEPT
            304 POP_TOP
            306 POP_TOP
            308 LOAD_CONST               0 (None)
            310 RETURN_VALUE

The co_varnames of the continuation function are ('___stack0', '___stack1', '___stack2', 'z', 'x', 'y', 'a'). co_argcount is 4, signifying that the argument names are ___stack0, ___stack1, ___stack2, and z.
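The rule used for both functions (the first co_argcount entries of co_varnames are the argument names) can be checked on ordinary Python functions. Below, split_varnames is a hypothetical helper, and resume_stand_in is a hand-written function that merely reuses the continuation function's argument names; it is not Dynamo output, and its remaining locals differ from the real continuation function.

```python
def split_varnames(func):
    """Split co_varnames into (argument names, remaining local names)."""
    code = func.__code__
    return code.co_varnames[: code.co_argcount], code.co_varnames[code.co_argcount :]


def f(x, y):
    z = x + y
    a = -z
    return a


def resume_stand_in(___stack0, ___stack1, ___stack2, z):
    # Hypothetical body: call the restored callable on the restored value
    a = ___stack1(___stack2 or z)
    return a


if __name__ == "__main__":
    print(split_varnames(f))
    print(split_varnames(resume_stand_in))
```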

The modified bytecode does a number of additional things (this is why the level of indirection is needed). During a graph break, the modified bytecode:

  • calls the backend-compiled function
  • reconstructs the variables that Dynamo has on its simulated stack, and pushes them onto the real stack
  • performs any side effects and updates any locals
  • restores, if necessary, contexts at the time of the graph break
  • runs the unsupported operation
  • pushes locals onto the stack and makes a call to the continuation function with the entire stack as arguments, and
  • returns the result.
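Ignoring the with torch.no_grad() context, the global side effects, and the NULL bookkeeping, the overall shape of this transformation can be sketched in plain Python. Everything here is hypothetical: compiled_part stands in for __compiled_fn_0, resume_at_print for __resume_at_110_2, and math.sin and max(a, 0.0) replace torch.sin and torch.relu so the sketch runs on floats.

```python
import math


def original(x, y):
    z = x + y
    a = math.sin(print("hello") or z)
    return max(a, 0.0)


def compiled_part(x, y):
    """Stands in for the backend-compiled graph: everything before the break."""
    return (x + y,)


def resume_at_print(stack0, stack1, z):
    """Continuation: receives the rebuilt stack values and the live local z."""
    # stack0 is the pending callable (sin), stack1 the result of print("hello")
    a = stack0(stack1 or z)
    return max(a, 0.0)


def modified(x, y):
    (z,) = compiled_part(x, y)         # call the compiled graph, update locals
    stack = [math.sin]                 # rebuild the stack below the unsupported call
    stack.append(print("hello"))       # run the unsupported operation
    return resume_at_print(*stack, z)  # pass stack + locals to the continuation


if __name__ == "__main__":
    print(modified(1.0, 2.0), original(1.0, 2.0))
```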

The (annotated) modified bytecode of the above example:

  6           0 RESUME                   0
# Run the backend-compiled function
              2 PUSH_NULL
              4 LOAD_GLOBAL             18 (__compiled_fn_0)
             16 LOAD_FAST                0 (x)
             18 LOAD_FAST                1 (y)
             20 PRECALL                  2
             24 CALL                     2
             34 STORE_FAST               4 (___graph_out_0)
# Restore the stack at the time of the graph break, before the unsupported function
             36 LOAD_GLOBAL             14 (__import_torch)
             48 LOAD_ATTR                8 (set_grad_enabled)
             58 PUSH_NULL
             60 LOAD_GLOBAL              2 (torch)
             72 LOAD_ATTR                4 (sin)
             82 PUSH_NULL
             84 LOAD_GLOBAL             10 (print)
             96 LOAD_CONST               3 ('hello')
# Perform side effects and update locals
             98 LOAD_FAST                4 (___graph_out_0)
            100 LOAD_CONST               4 (0)
            102 BINARY_SUBSCR
            112 LOAD_CONST               1 (1)
            114 LOAD_CONST               2 (2)
            116 STORE_GLOBAL             3 (w)
            118 STORE_GLOBAL             0 (v)
            120 STORE_FAST               2 (z)
# Restore contexts active at the time of the graph break
            122 LOAD_GLOBAL             14 (__import_torch)
            134 LOAD_ATTR                8 (set_grad_enabled)
            144 LOAD_CONST               5 (False)
            146 PUSH_NULL
            148 SWAP                     3
            150 SWAP                     2
            152 PRECALL                  1
            156 CALL                     1
            166 STORE_FAST               5 (___context_manager_0_1)
            168 LOAD_FAST                5 (___context_manager_0_1)
            170 LOAD_METHOD             10 (__enter__)
            192 PRECALL                  0
            196 CALL                     0
            206 POP_TOP
            208 NOP
# Run the unsupported instructions
            210 PRECALL                  1
            214 CALL                     1
# Clean up contexts that we set up
            224 NOP
            226 LOAD_FAST                5 (___context_manager_0_1)
            228 LOAD_METHOD             11 (__exit__)
            250 LOAD_CONST               0 (None)
            252 COPY                     1
            254 COPY                     1
            256 PRECALL                  3
            260 CALL                     3
            270 POP_TOP
            272 JUMP_FORWARD            28 (to 330)
        >>  274 PUSH_EXC_INFO
            276 LOAD_FAST                5 (___context_manager_0_1)
            278 LOAD_METHOD             11 (__exit__)
            300 LOAD_CONST               0 (None)
            302 COPY                     1
            304 COPY                     1
            306 PRECALL                  3
            310 CALL                     3
            320 POP_TOP
            322 RERAISE                  0
        >>  324 COPY                     3
            326 POP_EXCEPT
            328 RERAISE                  1
# Load locals, call the continuation function with the entire stack
        >>  330 NOP
            332 SWAP                     2
            334 SWAP                     3
            336 LOAD_CONST               6 (<function PyCodegen.pop_null.<locals>.<lambda> at 0x7f0b14b66660>)
            338 PRECALL                  0
            342 CALL                     0
            352 POP_TOP
            354 PUSH_NULL
            356 SWAP                     4
            358 SWAP                     3
            360 SWAP                     2
            362 LOAD_GLOBAL             24 (__resume_at_110_2)
            374 SWAP                     4
            376 SWAP                     3
            378 SWAP                     2
            380 LOAD_FAST                2 (z)
            382 PRECALL                  4
            386 CALL                     4
            396 RETURN_VALUE

What’s up with NULL?

You may have noticed the presence of NULL through the PUSH_NULL or the LOAD_GLOBAL instructions in the bytecode samples above. NULL refers to C NULL, not Python None. This is an important distinction because None is a valid Python object while NULL is not.

PUSH_NULL is a new instruction introduced in Python 3.11 that pushes a C NULL to the top of the stack. LOAD_GLOBAL also behaves slightly differently in 3.11: if the low bit of its oparg is set, it pushes a NULL before pushing the requested global value (dis shows this as NULL + name). These changes were introduced to make the instruction sequence for calling functions consistent between methods and non-methods.
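In CPython 3.11, the low bit of the LOAD_GLOBAL oparg is the push-NULL flag and the remaining bits index co_names, so the oparg can be decoded by hand. The helper below (a hypothetical name) does exactly that; the sample opargs are taken from the listings in this post.

```python
def decode_load_global(oparg):
    """Split a CPython 3.11 LOAD_GLOBAL oparg into (co_names index, pushes NULL?)."""
    return oparg >> 1, bool(oparg & 1)


if __name__ == "__main__":
    # opargs as they appear in the listings above
    samples = [(1, "NULL + print"), (11, "NULL + print"),
               (3, "NULL + torch"), (4, "__compiled_fn_0")]
    for oparg, shown_as in samples:
        print(oparg, shown_as, decode_load_global(oparg))
```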

For method calls, the function call sequence expects these 2 + nargs values to be on top of the stack, where nargs is the number of positional and named arguments:

  • unbound method
  • self
  • remaining positional arguments
  • (top) named arguments.

For non-method calls, the top 2 + nargs stack values are expected to be:

  • NULL
  • function
  • positional arguments
  • (top) named arguments.
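The two layouts can be modelled with a small helper that pops nargs + 2 values and works out what to call. This is a sketch of the 3.11 layout described above, using a sentinel object for NULL and a hypothetical helper name; later CPython versions changed this layout, so check the dis documentation for the version you target.

```python
NULL = object()  # stand-in for the C NULL


def pop_call_operands(stack, nargs):
    """Pop nargs + 2 values and return (callable, args) for a 3.11-style CALL."""
    operands = stack[-(nargs + 2):]
    del stack[-(nargs + 2):]
    first, second, *rest = operands
    if first is NULL:
        # Non-method call: NULL, function, positional args..., named args...
        return second, rest
    # Method call: unbound method, self, remaining args...
    return first, [second, *rest]


class Foo:
    def bar(self, n):
        return n * 10


if __name__ == "__main__":
    fn, args = pop_call_operands([NULL, abs, -3], 1)      # like abs(-3)
    print(fn(*args))  # 3
    fn, args = pop_call_operands([Foo.bar, Foo(), 1], 1)  # like Foo().bar(1)
    print(fn(*args))  # 10
```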

For example:

def g(f):
    f(1)
    print(1)
    # assume foo is an object with method bar
    foo.bar(1)

gives this (annotated) bytecode:

 20           0 RESUME                   0

# non-method call on non-global function
 21           2 PUSH_NULL
              4 LOAD_FAST                0 (f)
              6 LOAD_CONST               1 (1)
              8 PRECALL                  1
             12 CALL                     1
             22 POP_TOP

# non-method call on global function
 22          24 LOAD_GLOBAL              1 (NULL + print)
             36 LOAD_CONST               1 (1)
             38 PRECALL                  1
             42 CALL                     1
             52 POP_TOP

# method call
 24          54 LOAD_GLOBAL              2 (foo)
# pops foo, then pushes the unbound method foo.bar followed by foo (as self)
             66 LOAD_METHOD              2 (bar)
             88 LOAD_CONST               1 (1)
             90 PRECALL                  1
             94 CALL                     1
            104 POP_TOP
            106 LOAD_CONST               0 (None)
            108 RETURN_VALUE

Compare to Python 3.10.12:

# non-method call on non-global function
  2           0 LOAD_FAST                0 (f)
              2 LOAD_CONST               1 (1)
              4 CALL_FUNCTION            1
              6 POP_TOP

# non-method call on global function
  3           8 LOAD_GLOBAL              0 (print)
             10 LOAD_CONST               1 (1)
             12 CALL_FUNCTION            1
             14 POP_TOP

# method call
  5          16 LOAD_GLOBAL              1 (foo)
             18 LOAD_METHOD              2 (bar)
             20 LOAD_CONST               1 (1)
             22 CALL_METHOD              1
             24 POP_TOP
             26 LOAD_CONST               0 (None)
             28 RETURN_VALUE

The CALL_FUNCTION_EX instruction (used for calls such as f(*args, **kwargs)) expects a slightly different stack layout, but it still requires a NULL on the stack below the callable.

Challenges and workarounds

When should Dynamo push NULL?

The first challenge is determining when Dynamo should codegen a PUSH_NULL instruction (or a LOAD_GLOBAL instruction that additionally pushes NULL).
The function call sequence begins when an instruction loads the callable to the stack. It ends when a CALL or CALL_FUNCTION_EX instruction runs on the callable.
When a NULL is required, CPython does not care when it is pushed to the stack – it just needs to be present at the right place in the stack at the end of the call sequence.
This gives two ways of getting NULL to the right place below the callable: push it at the beginning of the sequence, or push it at the end:

  1. We can PUSH_NULL before loading the function (or use LOAD_GLOBAL with NULL), as CPython-generated bytecode does. However, if we do this, we need to know that the loaded object will eventually be called.
  2. We can PUSH_NULL and move the NULL to the right place in the stack using a sequence of SWAPs, immediately before generating PRECALL + CALL. If we do this, we need to ensure that there is not already a NULL before the callable in the stack.

As an example of the second way, suppose we wish to call f(x, y). Then we can generate:

LOAD_FAST 0 (f)
LOAD_FAST 1 (x)
LOAD_FAST 2 (y)
PUSH_NULL
SWAP 4
SWAP 3
SWAP 2
# stack is now NULL, f, x, y
PRECALL 2
CALL 2
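The SWAP pattern generalizes to any number of arguments: after PUSH_NULL, swapping with positions nargs + 2, nargs + 1, ..., 2 slides the NULL under the callable while keeping the callable and arguments in order. The sketch below uses hypothetical helper names and models SWAP(i) as exchanging the top of the stack with the i-th item from the top; it generates the sequence and replays it on a Python list.

```python
NULL = object()  # stand-in for the C NULL


def null_then_swaps(nargs):
    """PUSH_NULL followed by the SWAPs that move the NULL below the callable."""
    return [("PUSH_NULL", None)] + [("SWAP", i) for i in range(nargs + 2, 1, -1)]


def replay(stack, instructions):
    """Apply PUSH_NULL / SWAP instructions to a Python list used as a stack."""
    for opname, arg in instructions:
        if opname == "PUSH_NULL":
            stack.append(NULL)
        elif opname == "SWAP":
            stack[-1], stack[-arg] = stack[-arg], stack[-1]
        else:
            raise NotImplementedError(opname)
    return stack


if __name__ == "__main__":
    print(null_then_swaps(2))  # PUSH_NULL, SWAP 4, SWAP 3, SWAP 2
    print(replay(["f", "x", "y"], null_then_swaps(2)))
```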

The beginning and the end of a function call sequence can each occur in either Dynamo-generated bytecode or user bytecode. If both occur in user bytecode, there is no concern: user bytecode is compiled from source, so its NULL-pushing logic is already correct. If both occur in Dynamo-generated bytecode, the Dynamo developer simply needs to push NULL at either the beginning or the end of the call sequence, using one of the two approaches above. We added push_null arguments to several codegen functions so that developers can do this. Some codegen functions with a push_null argument push NULL at the beginning of the call sequence (such as create_load_global), while others push it at the end (such as create_call_function); it is usually easy to tell which from the instructions the function generates. Generally, developers should push NULL at the beginning of the call sequence to avoid the additional SWAP sequence.
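To make the trade-off concrete, the sketch below builds f(x, y) both ways and checks each sequence on a symbolic stack: at every CALL, the slot below the callable must hold NULL, and no stray NULL may be left over (for example from pushing NULL both at load time and at call time). load_global and call_function are hypothetical toy helpers loosely modelled on the push_null arguments described above; they are not Dynamo's codegen functions and do not share their signatures. Only non-method calls are modelled, and PRECALL is omitted.

```python
NULL = object()  # stand-in for the C NULL


def load_global(name, push_null=False):
    """Toy LOAD_GLOBAL; push_null=True models the 'NULL + name' form."""
    return [("LOAD_GLOBAL", (name, push_null))]


def load_fast(name):
    return [("LOAD_FAST", name)]


def call_function(nargs, push_null=False):
    """Toy CALL; push_null=True first slides a fresh NULL under the callable."""
    insts = []
    if push_null:
        insts.append(("PUSH_NULL", None))
        insts += [("SWAP", i) for i in range(nargs + 2, 1, -1)]
    return insts + [("CALL", nargs)]


def check_calls(instructions):
    """Symbolically run the sequence; False if any CALL or final stack is wrong."""
    stack = []
    for opname, arg in instructions:
        if opname == "LOAD_GLOBAL":
            name, push_null = arg
            if push_null:
                stack.append(NULL)
            stack.append(name)
        elif opname == "LOAD_FAST":
            stack.append(arg)
        elif opname == "PUSH_NULL":
            stack.append(NULL)
        elif opname == "SWAP":
            stack[-1], stack[-arg] = stack[-arg], stack[-1]
        elif opname == "CALL":
            if len(stack) < arg + 2 or stack[-(arg + 2)] is not NULL:
                return False  # missing NULL below the callable
            del stack[-(arg + 2):]
            stack.append("<result>")
        else:
            raise NotImplementedError(opname)
    return all(item is not NULL for item in stack)  # no stray NULLs


# f(x, y), pushing NULL at the beginning vs. at the end of the call sequence
at_load = load_global("f", push_null=True) + load_fast("x") + load_fast("y") + call_function(2)
at_call = load_global("f") + load_fast("x") + load_fast("y") + call_function(2, push_null=True)
forgot = load_global("f") + load_fast("x") + load_fast("y") + call_function(2)

if __name__ == "__main__":
    print(check_calls(at_load), len(at_load))  # True 4
    print(check_calls(at_call), len(at_call))  # True 8
    print(check_calls(forgot))                 # False
```

The end-of-sequence version costs a PUSH_NULL plus nargs + 1 SWAPs, which is the overhead the paragraph above recommends avoiding when the callable is known at load time.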

Next, consider the case where Dynamo bytecode calls a function loaded by user bytecode. This could only happen if Dynamo generates a function from original user bytecode (in most cases, a continuation function) and control jumps from the user bytecode section to the Dynamo-generated section. That only happens when contexts are cleaned up, and the context cleanup bytecode that Dynamo generates only calls context functions that Dynamo itself pushed. So in practice, Dynamo bytecode never calls a function loaded by user bytecode.

Finally, consider the case where user bytecode calls a function loaded by Dynamo bytecode. This happens in continuation functions when Dynamo reconstructs the stack: Dynamo may load a function in the modified bytecode that the continuation function’s user bytecode will eventually call. This case is handled correctly because, when Dynamo simulates the bytecode and the stack, any instruction that pushes NULL results in Dynamo pushing a NullVariable onto its simulated stack. So when reconstructing the stack, Dynamo knows exactly where the NULLs are and does not need to push any additional ones. That is, if a function on the reconstructed stack requires a NULL, that NULL is already on the simulated stack, and Dynamo reconstructs it too.

In summary, there should never be a case in Dynamo where we don’t know whether we should push NULL or not. Every function call has a corresponding NULL push and vice versa. The main thing to note for developers is the additional push_null argument in several codegen functions. The following procedure should be used to correctly generate bytecode that pushes NULL properly:

  1. Find where in the Dynamo source code the function is loaded to the stack, and where it is called.
  2. If the codegen function that generates the load instruction has a push_null argument, set it to True.
  3. Otherwise, manually create a PUSH_NULL instruction (you need to check that the Python version is 3.11+, i.e., sys.version_info >= (3, 11)). Consider adding a push_null argument to the codegen function if it is frequently used to load functions.
  4. If for some reason it is not a good idea to push NULL when the function is loaded (e.g. what is loaded might not be called in some cases), then push NULL at the end of the call sequence – set the push_null argument of create_call_function to True.
  5. Make sure to run your new tests in 3.11+, and make sure your tests are correct when graph breaks occur.

Example:

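# generates torch.as_tensor(self)
self.extend_output(
    [
        # push_null=True: push NULL before loading torch (start of the call sequence)
        self.create_load_python_module(torch, True),
        # replace torch on the stack with torch.as_tensor
        self.create_load_attr("as_tensor"),
    ]
)
# push the argument
self.extend_output(arg.load(self))
# push_null=False: the NULL is already below the callable
self.extend_output(create_call_function(1, False))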

NULLs on the CPython stack

As stated before, NULL differs from None in that the latter is a valid Python object while the former is not. This can cause issues when NULL is present in the stack and CPython attempts to treat it as a valid Python object. For example, we cannot run POP_TOP when the top of stack is NULL because CPython will attempt to decrease the reference count of the top of the stack, resulting in a NULL dereference.

One particularly annoying implementation detail is that when CALL runs, CPython pops nargs + 2 values and decrements the reference count of each one, except possibly the first (bottommost) value if it is NULL. This is problematic during a graph break, where Dynamo reconstructs the stack in the modified bytecode and passes the entire stack as arguments to the continuation function. If the stack contains a NULL, Dynamo would be passing NULL as a function argument, and CPython would dereference it when decrementing its reference count.

Thankfully, through the use of NullVariable, we know at the time of a graph break the locations of any NULL in the stack. So we can pop any NULLs on the stack in the modified bytecode before calling the continuation function, and in the continuation function, we can push NULLs back to the correct spots. We can see in the modified bytecode and continuation function examples above that Dynamo indeed generates bytecode that does this.
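The bookkeeping can be sketched as follows, assuming the simulated stack marks NULLs with a sentinel standing in for NullVariable. Both helpers are hypothetical: one drops the NULLs to get the continuation's stack arguments, the other emits a simplified continuation prologue that re-creates each slot with PUSH_NULL or LOAD_FAST ___stackN. For the example above, the real prologue also re-enters torch.no_grad() using ___stack0 rather than simply reloading it.

```python
NULL = object()  # stand-in for a NullVariable on Dynamo's simulated stack


def stack_arguments(simulated_stack):
    """Values to pass to the continuation function: every slot except NULLs."""
    return [value for value in simulated_stack if value is not NULL]


def continuation_prologue(simulated_stack):
    """Instructions that rebuild the stack, NULLs included, from ___stackN args."""
    insts = []
    next_arg = 0
    for value in simulated_stack:
        if value is NULL:
            insts.append(("PUSH_NULL", None))
        else:
            insts.append(("LOAD_FAST", f"___stack{next_arg}"))
            next_arg += 1
    return insts


# Stack after print("hello") in the example: exit fn, NULL, torch.sin, None
stack = ["exit_fn", NULL, "torch.sin", None]

if __name__ == "__main__":
    print(stack_arguments(stack))
    print(continuation_prologue(stack))
```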

One interesting issue is that we can’t pop NULL off the stack normally with POP_TOP, as described above. A clever workaround is to move the NULL to the top of the stack with SWAPs, push a function that does nothing, call it with 0 arguments (CALL consumes both the function and the NULL below it), and POP_TOP the result (which is None).
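The sequence used in the modified bytecode above (SWAP 2, SWAP 3, a no-op lambda constant, CALL 0, POP_TOP) can be replayed on a toy stack to confirm that it removes exactly the NULL that sits third from the top and leaves the other values in their original order. This is a sketch: NULL is a sentinel object, CALL only models the NULL-below-callable layout, and the strings are placeholders for the real stack values.

```python
NULL = object()  # stand-in for the C NULL


def run(stack, instructions):
    """Replay a few 3.11-style instructions on a Python list used as a stack."""
    for opname, arg in instructions:
        if opname == "SWAP":
            stack[-1], stack[-arg] = stack[-arg], stack[-1]
        elif opname == "LOAD_CONST":
            stack.append(arg)
        elif opname == "CALL":
            args = stack[len(stack) - arg:]
            del stack[len(stack) - arg:]
            func = stack.pop()
            if stack.pop() is not NULL:
                raise RuntimeError("expected NULL below the callable")
            stack.append(func(*args))
        elif opname == "POP_TOP":
            if stack[-1] is NULL:
                raise RuntimeError("POP_TOP on NULL would crash CPython")
            stack.pop()
        else:
            raise NotImplementedError(opname)
    return stack


POP_NULL_THIRD_FROM_TOP = [
    ("SWAP", 2),                   # SWAP 2 and SWAP 3 rotate the NULL to the top
    ("SWAP", 3),                   # while keeping the values above it in order
    ("LOAD_CONST", lambda: None),  # no-op callable sits directly on the NULL
    ("CALL", 0),                   # consumes the callable and the NULL, pushes None
    ("POP_TOP", None),             # discard that None
]

if __name__ == "__main__":
    # ['set_grad_enabled', 'torch.sin', None]
    print(run(["set_grad_enabled", NULL, "torch.sin", None], POP_NULL_THIRD_FROM_TOP))
```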
