Supporting Dynamo in Python 3.12

With another year comes a new Python version for us to support! Fortunately, enabling Python 3.12 support in Dynamo was not as challenging as supporting Python 3.11 (you can read our technical blog posts on supporting Python 3.11 here and here). 3.11 was particularly difficult because it introduced major changes to frame evaluation and bytecode semantics as part of the Faster CPython effort. 3.12 had fewer such changes, but we still ran into some challenges, which we document in this post.

Frame allocation and deallocation

In Python, function calls are evaluated using frame objects. A frame consists of a code object and an execution context (locals, globals, etc.). Nested function calls are handled using a stack of frame objects.
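
As a small illustration of the frame stack described above, the following sketch inspects live frames from Python. It assumes CPython, since sys._getframe is a CPython implementation detail, and the function names inner and outer are made up for the example.

```python
import sys


def inner():
    frame = sys._getframe()   # frame object for this call to inner()
    caller = frame.f_back     # the frame below it on the stack: outer()
    # Each frame pairs a code object with its own execution context.
    return frame.f_code.co_name, caller.f_code.co_name, "x" in caller.f_locals


def outer():
    x = 1
    return inner()


result = outer()
print(result)
```

The f_back link is what makes nested calls a stack: each frame points at the frame of its caller.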

In Python 3.11, this function is a top-level frame evaluation function. To summarize the code:

  1. Allocate the frame memory and initialize the frame’s contents (cpython/Python/ceval.c at d542a9be51776e8d589363ee15164dec8dbd3a76 · python/cpython · GitHub).
  2. Interpret the bytecode (the interpreter function can be manually set using PEP 523) (cpython/Python/ceval.c at d542a9be51776e8d589363ee15164dec8dbd3a76 · python/cpython · GitHub).
  3. Clean the frame’s contents and deallocate the frame’s memory (cpython/Python/ceval.c at d542a9be51776e8d589363ee15164dec8dbd3a76 · python/cpython · GitHub).

From these observations, we can say that a function’s caller is responsible for both allocating and deallocating the callee frame.

In Python 3.12, the frame evaluation procedure changed slightly:

  1. Allocate the frame memory and initialize the frame’s contents (cpython/Python/ceval.c at 0300e33b223a6cfd691bea186cd413424162d83a · python/cpython · GitHub).
  2. Interpret the bytecode (cpython/Python/ceval.c at 0300e33b223a6cfd691bea186cd413424162d83a · python/cpython · GitHub).

Where did the cleanup code go? It turns out that this code moved into the interpreter function that evaluates the frame (for example). Thus, we observe that in 3.12, the callee is now responsible for cleaning up its own frame.

Dynamo runs optimized bytecode by intercepting the original frame, creating a new “shadow” frame with optimized bytecode, and running the default interpreter on the shadow frame. In 3.11, the following is done:

  1. CPython attempts to execute the original frame.
  2. CPython allocates and initializes the original frame.
  3. CPython calls the interpreter function, which is set to Dynamo.
  4. Dynamo attempts to execute the shadow frame.
  5. Dynamo allocates and initializes the shadow frame ([dynamo, 3.12] Allocate Dynamo shadow frames by mimicking CPython by williamwen42 · Pull Request #122146 · pytorch/pytorch · GitHub).
  6. Dynamo calls the default interpreter function on the shadow frame ([dynamo, 3.12] Allocate Dynamo shadow frames by mimicking CPython by williamwen42 · Pull Request #122146 · pytorch/pytorch · GitHub).
  7. Dynamo cleans up the shadow frame ([dynamo, 3.12] Allocate Dynamo shadow frames by mimicking CPython by williamwen42 · Pull Request #122146 · pytorch/pytorch · GitHub).
  8. CPython cleans up the original frame.

In order to respect the new frame allocation/deallocation convention, what needs to happen is:

  1. CPython attempts to execute the original frame.
  2. CPython allocates and initializes the original frame.
  3. CPython calls Dynamo.
  4. Dynamo attempts to execute the shadow frame.
  5. Dynamo allocates and initializes the shadow frame ([dynamo, 3.12] Allocate Dynamo shadow frames by mimicking CPython by williamwen42 · Pull Request #122146 · pytorch/pytorch · GitHub).
  6. Dynamo default interprets the shadow frame ([dynamo, 3.12] Allocate Dynamo shadow frames by mimicking CPython by williamwen42 · Pull Request #122146 · pytorch/pytorch · GitHub).
  7. The shadow frame is cleaned up by the default interpreter.
  8. Dynamo cleans up the original frame ([dynamo, 3.12] Allocate Dynamo shadow frames by mimicking CPython by williamwen42 · Pull Request #122146 · pytorch/pytorch · GitHub).

Note that Dynamo (1) is responsible for freeing memory it did not allocate and (2) allocates memory that it will not free. Because of this, Dynamo must now allocate and deallocate memory in the same way as CPython, meaning that we must do more copy-pasting (allocating memory, deallocating memory).
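
To make the ownership change concrete, here is a toy model of the two step lists above. It is only a sketch: the "frames" are plain strings, nothing is really allocated, and the names run, alloc, free, default_eval and dynamo_eval are invented for illustration rather than taken from CPython or Dynamo.

```python
def run(convention):
    """Replay the frame lifecycle and record who allocates and frees what."""
    log = []

    def alloc(who, name):
        log.append(f"{who} allocs {name}")
        return name

    def free(who, name):
        log.append(f"{who} frees {name}")

    def default_eval(frame):
        log.append(f"interpreter runs {frame}")
        if convention == "3.12":
            free("interpreter", frame)   # callee cleans up the frame it ran

    def dynamo_eval(original):
        shadow = alloc("Dynamo", "shadow")
        default_eval(shadow)
        if convention == "3.11":
            free("Dynamo", shadow)       # caller of the interpreter cleans up
        else:
            free("Dynamo", original)     # Dynamo is the callee for the original

    original = alloc("CPython", "original")
    dynamo_eval(original)
    if convention == "3.11":
        free("CPython", original)        # caller cleans up
    return log


for convention in ("3.11", "3.12"):
    print(convention, run(convention))
```

In the 3.12 run, the shadow frame is allocated by Dynamo but freed by the interpreter, and the original frame is allocated by CPython but freed by Dynamo, which is why both sides must agree on the allocator.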

We also made an initial implementation mistake: we assumed that it would always be safe to use malloc/free. In reality, CPython may choose different memory allocation/deallocation functions, so it is safer to use _PyObject_VirtualAlloc/_PyObject_VirtualFree, which we also copy-pasted. The PR is here.

A few reflections on this work:

  1. A lot of these changes/requirements on CPython’s end are not documented well. We had to do a bunch of code reading and documentation digging.
  2. It would be great to introduce some basic PEP 523 tests that mimic Dynamo in a minimal way in order to increase PEP 523 stability between Python versions.
  3. We should consider decreasing or entirely removing our reliance on PEP 523.

Block stack and exception tables

Python 3.12 generates bytecode in a way that breaks Dynamo’s implementation of the block stack. However, I realized that I never wrote a post on supporting exception tables and block stack changes in 3.11, so I may as well document that here as well!

In Python 3.10, the block stack was used to keep track of the target instruction to jump to when an exception is encountered. For example, when entering a with/try block, a block pointing to the exception handling bytecode is pushed to the block stack. When we exit the with/try block (either normally or due to an exception), the block is popped.

Python 3.11 removed the block stack and replaced it with an exception table. A code object’s exception table is a jump table: an exception jump target can be specified for each instruction in the code object. Supporting exception tables in Dynamo was by no means an easy task.

Now Dynamo for the most part attempts to mimic CPython behavior – many times, we directly look at CPython code in order to guide Dynamo implementation. One key feature of Dynamo that is not present in CPython is the graph break. In particular, we note that during a graph break, Dynamo produces modified bytecode and a continuation function. The modified bytecode restores active contexts at the time of the graph break before running the unsupported instruction. The continuation function again restores the context at the beginning of the bytecode. (See here for more details on the graph break process.)

In order to restore active contexts during a graph break, Dynamo must know what the currently active contexts are! In 3.10, the block stack doubled as Dynamo’s way to keep track of active contexts. Since 3.11 removed the block stack, how will Dynamo now keep track of active contexts? The solution was for Dynamo to continue to keep track of the block stack in 3.11.

In 3.11, we simulated the block stack in Dynamo by applying a best-effort heuristic on the exception table.

To motivate our heuristic, here is an example function with nested with blocks and the corresponding 3.11 bytecode:

>>> def fn():
...     with a():
...             with b():
...                     c()
... 
>>> dis.dis(fn)
  1           0 RESUME                   0

  2           2 LOAD_GLOBAL              1 (NULL + a)
             14 PRECALL                  0
             18 CALL                     0
             28 BEFORE_WITH
             30 POP_TOP

  3          32 LOAD_GLOBAL              3 (NULL + b)
             44 PRECALL                  0
             48 CALL                     0
             58 BEFORE_WITH
             60 POP_TOP

  4          62 LOAD_GLOBAL              5 (NULL + c)
             74 PRECALL                  0
             78 CALL                     0
             88 POP_TOP

  3          90 LOAD_CONST               0 (None)
             92 LOAD_CONST               0 (None)
             94 LOAD_CONST               0 (None)
             96 PRECALL                  2
            100 CALL                     2
            110 POP_TOP
            112 JUMP_FORWARD            11 (to 136)
        >>  114 PUSH_EXC_INFO
            116 WITH_EXCEPT_START
            118 POP_JUMP_FORWARD_IF_TRUE     4 (to 128)
            120 RERAISE                  2
        >>  122 COPY                     3
            124 POP_EXCEPT
            126 RERAISE                  1
        >>  128 POP_TOP
            130 POP_EXCEPT
            132 POP_TOP
            134 POP_TOP

  2     >>  136 LOAD_CONST               0 (None)
            138 LOAD_CONST               0 (None)
            140 LOAD_CONST               0 (None)
            142 PRECALL                  2
            146 CALL                     2
            156 POP_TOP
            158 LOAD_CONST               0 (None)
            160 RETURN_VALUE
        >>  162 PUSH_EXC_INFO
            164 WITH_EXCEPT_START
            166 POP_JUMP_FORWARD_IF_TRUE     4 (to 176)
            168 RERAISE                  2
        >>  170 COPY                     3
            172 POP_EXCEPT
            174 RERAISE                  1
        >>  176 POP_TOP
            178 POP_EXCEPT
            180 POP_TOP
            182 POP_TOP
            184 LOAD_CONST               0 (None)
            186 RETURN_VALUE
ExceptionTable:
  30 to 58 -> 162 [1] lasti
  60 to 88 -> 114 [2] lasti
  90 to 112 -> 162 [1] lasti
  114 to 120 -> 122 [4] lasti
  122 to 126 -> 162 [1] lasti
  128 to 128 -> 122 [4] lasti
  130 to 134 -> 162 [1] lasti
  162 to 168 -> 170 [3] lasti
  176 to 176 -> 170 [3] lasti

And here is an example function with consecutive with blocks and the corresponding 3.11 bytecode:

>>> def fn():
...     with a():
...             pass
...     with b():
...             pass
... 
>>> dis.dis(fn)
  1           0 RESUME                   0

  2           2 LOAD_GLOBAL              1 (NULL + a)
             14 PRECALL                  0
             18 CALL                     0
             28 BEFORE_WITH
             30 POP_TOP

  3          32 NOP

  2          34 LOAD_CONST               0 (None)
             36 LOAD_CONST               0 (None)
             38 LOAD_CONST               0 (None)
             40 PRECALL                  2
             44 CALL                     2
             54 POP_TOP
             56 JUMP_FORWARD            11 (to 80)
        >>   58 PUSH_EXC_INFO
             60 WITH_EXCEPT_START
             62 POP_JUMP_FORWARD_IF_TRUE     4 (to 72)
             64 RERAISE                  2
        >>   66 COPY                     3
             68 POP_EXCEPT
             70 RERAISE                  1
        >>   72 POP_TOP
             74 POP_EXCEPT
             76 POP_TOP
             78 POP_TOP

  4     >>   80 LOAD_GLOBAL              3 (NULL + b)
             92 PRECALL                  0
             96 CALL                     0
            106 BEFORE_WITH
            108 POP_TOP

  5         110 NOP

  4         112 LOAD_CONST               0 (None)
            114 LOAD_CONST               0 (None)
            116 LOAD_CONST               0 (None)
            118 PRECALL                  2
            122 CALL                     2
            132 POP_TOP
            134 LOAD_CONST               0 (None)
            136 RETURN_VALUE
        >>  138 PUSH_EXC_INFO
            140 WITH_EXCEPT_START
            142 POP_JUMP_FORWARD_IF_TRUE     4 (to 152)
            144 RERAISE                  2
        >>  146 COPY                     3
            148 POP_EXCEPT
            150 RERAISE                  1
        >>  152 POP_TOP
            154 POP_EXCEPT
            156 POP_TOP
            158 POP_TOP
            160 LOAD_CONST               0 (None)
            162 RETURN_VALUE
ExceptionTable:
  30 to 30 -> 58 [1] lasti
  58 to 64 -> 66 [3] lasti
  72 to 72 -> 66 [3] lasti
  108 to 108 -> 138 [1] lasti
  138 to 144 -> 146 [3] lasti
  152 to 152 -> 146 [3] lasti

In the nested with block example, note that there are no gaps in the exception table for the bytecode spanning the outer block. In the consecutive with block example, note that there are gaps in the exception table between the with block body and cleanup sections, and between the two with blocks. This leads us to the following heuristic:

  • If we encounter BEFORE_WITH, push a block.
  • If the current instruction has the same exception jump target as the top block target, do nothing.
  • If the current instruction has a different exception jump target than the top block target, then either (a) the current jump target is the same as the second highest block target, meaning we moved from an inner block to an outer block, so pop a block; or (b) these targets are not the same, meaning we moved into a nested block, so push a block.
  • If the current instruction has no exception jump target, we are not in any block. Make sure that we only have one block on the stack (we shouldn’t be jumping out of more than one block at a time), then pop the block.
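
The heuristic above can be sketched as a small simulation over the nested with example. This is a simplified model, not Dynamo's implementation (which includes the extra modifications mentioned below). The table rows and offsets are copied from the 3.11 dump above; the names TABLE, PATH, lookup and simulate are made up, and taking a new block's target from the instruction following BEFORE_WITH is our assumption.

```python
from typing import List, Optional, Tuple

# (start, end, target) rows from the nested-with ExceptionTable above.
# `end` is inclusive, matching how dis prints the table.
TABLE = [
    (30, 58, 162), (60, 88, 114), (90, 112, 162), (114, 120, 122),
    (122, 126, 162), (128, 128, 122), (130, 134, 162), (162, 168, 170),
    (176, 176, 170),
]
# Offsets visited when no exception is raised (JUMP_FORWARD at 112 goes to 136).
PATH = [0, 2, 14, 18, 28, 30, 32, 44, 48, 58, 60, 62, 74, 78, 88, 90, 92, 94,
        96, 100, 110, 112, 136, 138, 140, 142, 146, 156, 158, 160]
BEFORE_WITH = {28, 58}


def lookup(offset: int) -> Optional[int]:
    """Return the exception jump target covering `offset`, if any."""
    for start, end, target in TABLE:
        if start <= offset <= end:
            return target
    return None


def simulate() -> List[Tuple[int, str, Optional[int]]]:
    stack: List[Optional[int]] = []  # exception targets of simulated blocks
    events = []
    for i, offset in enumerate(PATH):
        target = lookup(offset)
        if target is None:
            # Not in any block: at most one block may be left to pop.
            assert len(stack) <= 1
            if stack:
                events.append((offset, "pop", stack.pop()))
        elif not stack or target != stack[-1]:
            if len(stack) >= 2 and target == stack[-2]:
                # (a) moved from an inner block back to the outer block
                events.append((offset, "pop", stack.pop()))
            else:
                # (b) moved into a nested block
                stack.append(target)
                events.append((offset, "push", target))
        if offset in BEFORE_WITH:
            # Assumption: the block's target is the handler of the next instruction.
            new_target = lookup(PATH[i + 1])
            stack.append(new_target)
            events.append((offset, "push", new_target))
    return events


EVENTS = simulate()
for event in EVENTS:
    print(event)
```

On this input the simulated stack reaches depth two inside the inner with body and unwinds one block at a time, matching the nesting in the source.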

This heuristic (plus a few modifications) worked for 3.11 but no longer works for 3.12.

Consider an example function with nested try blocks and the corresponding 3.12 bytecode:

>>> def fn():
...     try:
...             a()
...             try:
...                     c()
...             except:
...                     d()
...     except:
...             b()
... 
>>> dis.dis(fn)
  1           0 RESUME                   0

  2           2 NOP

  3           4 LOAD_GLOBAL              1 (NULL + a)
             14 CALL                     0
             22 POP_TOP

  4          24 NOP

  5          26 LOAD_GLOBAL              3 (NULL + c)
             36 CALL                     0
             44 POP_TOP
             46 RETURN_CONST             0 (None)
        >>   48 PUSH_EXC_INFO

  6          50 POP_TOP

  7          52 LOAD_GLOBAL              5 (NULL + d)
             62 CALL                     0
             70 POP_TOP
             72 POP_EXCEPT
             74 RETURN_CONST             0 (None)
        >>   76 COPY                     3
             78 POP_EXCEPT
             80 RERAISE                  1
        >>   82 PUSH_EXC_INFO

  8          84 POP_TOP

  9          86 LOAD_GLOBAL              7 (NULL + b)
             96 CALL                     0
            104 POP_TOP
            106 POP_EXCEPT
            108 RETURN_CONST             0 (None)
        >>  110 COPY                     3
            112 POP_EXCEPT
            114 RERAISE                  1
ExceptionTable:
  4 to 22 -> 82 [0]
  26 to 44 -> 48 [0]
  48 to 70 -> 76 [1] lasti
  72 to 72 -> 82 [0]
  76 to 80 -> 82 [0]
  82 to 104 -> 110 [1] lasti

And consider this example function with consecutive try blocks and the corresponding 3.12 bytecode:

>>> def fn():
...     try:
...             a()
...     except:
...             b()
...     try:
...             c()
...     except:
...             d()
... 
>>> dis.dis(fn)
  1           0 RESUME                   0

  2           2 NOP

  3           4 LOAD_GLOBAL              1 (NULL + a)
             14 CALL                     0
             22 POP_TOP

  6     >>   24 NOP

  7          26 LOAD_GLOBAL              5 (NULL + c)
             36 CALL                     0
             44 POP_TOP
             46 RETURN_CONST             0 (None)
        >>   48 PUSH_EXC_INFO

  4          50 POP_TOP

  5          52 LOAD_GLOBAL              3 (NULL + b)
             62 CALL                     0
             70 POP_TOP
             72 POP_EXCEPT
             74 JUMP_BACKWARD           26 (to 24)
        >>   76 COPY                     3
             78 POP_EXCEPT
             80 RERAISE                  1
        >>   82 PUSH_EXC_INFO

  8          84 POP_TOP

  9          86 LOAD_GLOBAL              7 (NULL + d)
             96 CALL                     0
            104 POP_TOP
            106 POP_EXCEPT
            108 RETURN_CONST             0 (None)
        >>  110 COPY                     3
            112 POP_EXCEPT
            114 RERAISE                  1
ExceptionTable:
  4 to 22 -> 48 [0]
  26 to 44 -> 82 [0]
  48 to 70 -> 76 [1] lasti
  82 to 104 -> 110 [1] lasti

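Note that both examples start with essentially the same bytecode. Shown here is the consecutive version; the nested version differs only in line numbers, jump-target markers, name indices, and exception jump targets: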

  1           0 RESUME                   0

  2           2 NOP

  3           4 LOAD_GLOBAL              1 (NULL + a)
             14 CALL                     0
             22 POP_TOP

  6     >>   24 NOP

  7          26 LOAD_GLOBAL              5 (NULL + c)
             36 CALL                     0
             44 POP_TOP
             46 RETURN_CONST             0 (None)

ExceptionTable:
  4 to 22 -> 48 [0]
  26 to 44 -> 82 [0]

Indeed, 3.12 rearranged bytecode order so that the fast path appears sequentially in bytecode, improving locality.
However, nested and sequential try blocks are much harder to distinguish from looking at bytecode.
In particular, this breaks our heuristic since we can no longer differentiate between pushing a block (nested block) and popping a block then pushing another (consecutive block).
(Note that some NOP instructions don’t have an exception jump target and thus aren’t part of a block - Dynamo deals with this by not running the heuristic when dealing with NOPs.)
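
One way to see the ambiguity is to compare what the heuristic can observe along the fast path in each example. The sketch below uses the two 3.12 exception tables above and relabels each jump target by order of first appearance, since the heuristic only ever compares targets for equality. The names NESTED_TRY, CONSECUTIVE_TRY, FAST_PATH, lookup and shape are ours.

```python
from typing import List, Optional, Tuple

Table = List[Tuple[int, int, int]]  # (start, end inclusive, target)

NESTED_TRY: Table = [(4, 22, 82), (26, 44, 48), (48, 70, 76),
                     (72, 72, 82), (76, 80, 82), (82, 104, 110)]
CONSECUTIVE_TRY: Table = [(4, 22, 48), (26, 44, 82), (48, 70, 76),
                          (82, 104, 110)]
# The fast path has the same offsets in both examples.
FAST_PATH = [0, 2, 4, 14, 22, 24, 26, 36, 44, 46]


def lookup(table: Table, offset: int) -> Optional[int]:
    for start, end, target in table:
        if start <= offset <= end:
            return target
    return None


def shape(table: Table) -> List[Optional[int]]:
    """Label each target by first appearance so only equality matters."""
    labels = {}
    result = []
    for offset in FAST_PATH:
        target = lookup(table, offset)
        if target is None:
            result.append(None)
        else:
            result.append(labels.setdefault(target, len(labels)))
    return result


print(shape(NESTED_TRY))
print(shape(CONSECUTIVE_TRY))
```

Both calls produce the same sequence, so a heuristic that only looks at the current instruction's target and the targets already on its stack has no way to tell the two functions apart on this path.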

We could attempt to analyze the exception table more carefully, but we found a simpler solution. We noted that Dynamo doesn’t support graph breaks in try blocks - only in with blocks. So in order to graph break properly, we really only need to keep track of with blocks. We can simply ignore every block that was not pushed by a BEFORE_WITH instruction. In effect, our block stack is now really just a stack of currently active contexts. Fortunately, nested and consecutive with blocks are still distinguishable in 3.12. Our new heuristic is thus:

  • If we encounter BEFORE_WITH, push a block.
  • If the current instruction has a different exception jump target than the top block target and it is the same as the second highest block, then pop the block.
  • If the current instruction has no exception jump target, we are not in any block. Make sure that we only have one block on the stack (we shouldn’t be jumping out of more than one block at a time), then pop the block.
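
A minimal sketch of the with-only heuristic, under the same simplifications as before: it is not Dynamo's code, the helper names are invented, and taking a block's target from the instruction after BEFORE_WITH is an assumption. It is run on the 3.11 nested with example and on the fast path of the 3.12 nested try example from earlier in the post.

```python
from typing import List, Optional, Set, Tuple

Table = List[Tuple[int, int, int]]  # (start, end inclusive, target)

WITH_TABLE: Table = [
    (30, 58, 162), (60, 88, 114), (90, 112, 162), (114, 120, 122),
    (122, 126, 162), (128, 128, 122), (130, 134, 162), (162, 168, 170),
    (176, 176, 170),
]
WITH_PATH = [0, 2, 14, 18, 28, 30, 32, 44, 48, 58, 60, 62, 74, 78, 88, 90,
             92, 94, 96, 100, 110, 112, 136, 138, 140, 142, 146, 156, 158, 160]
TRY_TABLE: Table = [(4, 22, 82), (26, 44, 48), (48, 70, 76),
                    (72, 72, 82), (76, 80, 82), (82, 104, 110)]
TRY_PATH = [0, 2, 4, 14, 22, 24, 26, 36, 44, 46]


def lookup(table: Table, offset: int) -> Optional[int]:
    for start, end, target in table:
        if start <= offset <= end:
            return target
    return None


def simulate(table: Table, path: List[int], before_with: Set[int]):
    stack: List[Optional[int]] = []  # targets of currently active with blocks
    events = []
    for i, offset in enumerate(path):
        target = lookup(table, offset)
        if target is None:
            assert len(stack) <= 1
            if stack:
                events.append((offset, "pop", stack.pop()))
        elif len(stack) >= 2 and target != stack[-1] and target == stack[-2]:
            events.append((offset, "pop", stack.pop()))
        # Any other target change is ignored: it was not caused by a with block.
        if offset in before_with:
            new_target = lookup(table, path[i + 1])
            stack.append(new_target)
            events.append((offset, "push", new_target))
    return events


WITH_EVENTS = simulate(WITH_TABLE, WITH_PATH, {28, 58})
TRY_EVENTS = simulate(TRY_TABLE, TRY_PATH, set())
print(WITH_EVENTS)
print(TRY_EVENTS)
```

The with example still produces two pushes and two pops, while the try example produces no events at all, which is the intended effect of only tracking blocks pushed by BEFORE_WITH.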

Our fixes to the heuristic (which are compatible with 3.11 as well) were made in this PR.

A few reflections on this work:

  1. How bytecode is structured is even more poorly documented than frame evaluation since these are pure CPython implementation details.
  2. Detecting with blocks is brittle and could break in future Python version updates. We are unsure how to remove this dependency since Dynamo at its core is a bytecode interpreter.
  3. Detecting try blocks is quite difficult at the moment. If we decide to support graph breaking in try blocks in the future, we may also need to figure out how to detect try block boundaries from bytecode.

LOAD_SUPER_ATTR bug

PR: [dynamo, 3.12] force LOAD_SUPER_ATTR second bit on by williamwen42 · Pull Request #123686 · pytorch/pytorch · GitHub

This was not a huge blocker in the 3.12 implementation, but it was a funny bug that I encountered.

So I had implemented the LOAD_SUPER_ATTR instruction for 3.12 already but was getting a weird TypeError: super() takes at most 2 arguments (3 given) error in dynamo-wrapped tests.

I was able to minify the repro to:

        import torch

        class Foo(torch.nn.Sequential):
            def __init__(self, layers):
                # Force a graph break so the super() call ends up in a continuation function.
                torch._dynamo.graph_break()
                super().__init__(*layers)

        def fn(x):
            layers = [torch.nn.Linear(3, 3) for _ in range(3)]
            mod = Foo(layers)
            return mod(x)

        opt_fn = torch.compile(fn, backend="eager")
        opt_fn(torch.randn(3, 3))

This was a confusing error - how was super receiving 3 arguments when according to the code, it was receiving 0? Further, the continuation function bytecode that contained the super call (via LOAD_SUPER_ATTR) looked correct:

torch_dynamo_resume_in___init___at_10326
10326           0 COPY_FREE_VARS           1
                2 RESUME                   0
                4 LOAD_FAST                0 (___stack0)
                6 JUMP_FORWARD            31 (to 70)
                8 COPY_FREE_VARS           1
               10 RESUME                   0
               12 LOAD_GLOBAL              0 (torch)
               22 LOAD_ATTR                2 (_dynamo)
               42 LOAD_ATTR                5 (NULL|self + graph_break)
               62 CALL                     0
          >>   70 POP_TOP

10327          72 LOAD_GLOBAL              7 (NULL + super)
               82 LOAD_DEREF               3 (__class__)
               84 LOAD_FAST                1 (self)
               86 LOAD_SUPER_ATTR         16 (__init__)
               90 LOAD_FAST                2 (layers)
               92 CALL_FUNCTION_EX         0
               94 POP_TOP
               96 RETURN_CONST             0 (None)

I spent hours slightly modifying the source code and playing around with super, inspecting the bytecode, breakpointing in CPython, etc. As I took a break from fixing the crash, I suddenly thought of a function in bytecode_transformation.py: explicit_super(). Why was that function needed?

Then it hit me - the second lowest bit of the LOAD_SUPER_ATTR instruction’s argument forces the explicit two-argument super call, but I was simply copying the old bit value. In the original bytecode, it was fine for this bit to be unset, but now that this instruction is in a continuation function, the implicit zero-argument super call no longer accesses the correct __class__ and self variables!

Excited, I skipped back to my laptop and forced the second bit of the LOAD_SUPER_ATTR argument on, which fixed the bug.
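
For readers who want to poke at this, here is a short sketch in two parts. The first part shows, in plain Python, that zero-argument super() depends on the context the code was compiled in, while the two-argument form does not; note that it fails with a RuntimeError, which is not the TypeError from the Dynamo bug above. The second part decodes a 3.12 LOAD_SUPER_ATTR argument following the bit layout described in the dis documentation. The class and helper names are made up for illustration and are not Dynamo functions.

```python
class Base:
    def name(self):
        return "base"


class Child(Base):
    def name(self):
        return "child"

    def explicit(self):
        # Two-argument form: names the class and instance directly.
        return super(Child, self).name()


def detached(self):
    # Defined outside any class body, so there is no __class__ cell for
    # zero-argument super() to find.
    return super().name()


Child.detached = detached
child = Child()
try:
    child.detached()
    zero_arg_error = None
except RuntimeError as exc:
    zero_arg_error = str(exc)


def decode_load_super_attr(oparg):
    """Split a 3.12 LOAD_SUPER_ATTR oparg into (name index, method load, two-arg)."""
    return oparg >> 2, bool(oparg & 1), bool(oparg & 2)


def force_two_arg(oparg):
    """Set the second-lowest bit so the explicit two-argument form is used."""
    return oparg | 2


print(child.explicit(), zero_arg_error)
print(decode_load_super_attr(16), decode_load_super_attr(force_two_arg(16)))
```

The value 16 is the argument shown in the continuation function's bytecode above; it decodes to name index 4 with both flag bits unset, and forcing the second bit gives 18.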
