Signalboosting that torch.compile
is now compatible with Python 3.13! You can try it out today with the nightly PyTorch binaries. 3.13 support will be included in the next PyTorch release, 2.6. (Note that we currently still DO NOT support the no-GIL build, Python 3.13t.)
Below, we describe the significant technical challenges that we encountered while adding Python 3.13 support. Previous posts on torch.compile/Dynamo support for new Python versions are here:
Manually building frame local dicts
Recall that Dynamo needs access to the frame’s f_locals dict (we will refer to this as the frame’s “locals”, though note that it actually contains local, cell, and free variables) in order to evaluate guards and to trace bytecode. Frames in Python store their locals in a raw C buffer, and this buffer is converted into a Python dict when necessary. In Python 3.10 and below, we could call the C function PyFrame_FastToLocalsWithError
to perform this conversion. Python 3.11 made things difficult by making this function private, so our workaround was to copy-paste this C function and all of its child functions. The copy-pasted code also mutates the frame’s state, which caused memory leaks that we later addressed.
In 3.13, PEP 667 introduced changes to how a frame’s locals are viewed and generated. In particular, it made PyFrame_FastToLocalsWithError
a no-op: there was no longer a reference implementation in 3.13, and the code we had copied from 3.12 was not fully compatible. We thus had to manually write a C++ function that converts the frame’s locals C buffer into a Python dict. This can be found in torch/csrc/dynamo/framelocals_mapping.[h/cpp].
Fortunately, our C++ conversion function is simpler than the previous copy-pasted implementations. In particular, we no longer needed to mutate the state of the frame, which removed the need for our memory leak bug fix from earlier. We backported our manual frame local construction to previous Python versions as well, though we note that frame local dict construction differs somewhat significantly between Python 3.10 and 3.11.
Writing our own locals dict conversion also helped us better understand the layout of the frame’s locals. We were thus able to write utilities for our C/C++ code that access the raw frame locals C buffer without the Python dict conversion. This enabled a guard evaluation optimization where we only convert a frame’s locals to a dict when absolutely necessary.
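To illustrate the kind of access pattern involved (from Python rather than C++, and with a hypothetical guard helper that is not Dynamo's actual API), the sketch below evaluates a simple type guard against a caller frame's locals mapping. On 3.12 and below, f_locals is a dict snapshot; on 3.13+, per PEP 667, it is a write-through proxy view, but the read-only mapping access used here behaves the same on both:

```python
import sys

def type_guard(name, expected_type):
    # Hypothetical illustration, not Dynamo's real guard machinery:
    # look up a variable in the calling frame's locals mapping and
    # check its type. frame.f_locals is a dict on <=3.12 and a
    # FrameLocalsProxy (PEP 667) on 3.13+; both support indexing.
    frame = sys._getframe(1)
    return isinstance(frame.f_locals[name], expected_type)

def fn(x):
    return type_guard("x", int)

assert fn(3) is True
assert fn("a") is False
```
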
Relevant PRs:
- [3.12, 3.13, dynamo] simplified construction for frame f_locals/localsplus by williamwen42 · Pull Request #129185 · pytorch/pytorch · GitHub
- [dynamo] switch to get_framelocals_mapping for 3.11 by williamwen42 · Pull Request #139950 · pytorch/pytorch · GitHub
- [dynamo] switch to get_framelocals_mapping for 3.10 and below by williamwen42 · Pull Request #140037 · pytorch/pytorch · GitHub
Function call convention change
TL;DR: If you write Dynamo codegen code, use add_push_null
around bytecode sequences generating callables that will be called.
3.11 added the convention of pushing NULL
onto the stack before a callable. 3.13 swapped the order of the callable and the NULL:
a seemingly simple change that required a fairly significant Dynamo codegen refactor.
Consider the following Dynamo codegen for <= 3.12 (note the following code snippets are representations and do not exactly match the Dynamo format):
# torch.relu(x)
create_load_global("torch", push_null=True)
create_load_attr("relu")
create_instruction("LOAD_FAST", "x")
create_call_function(1, False)
Ideally, we want to push NULL
before we load any other objects - otherwise, we need to push NULL
after loading the callable/arguments and then move it to the correct place, costing a number of instructions. Unfortunately, for 3.13 it is not enough to push the NULL
right after the LOAD_GLOBAL
- we can only push the NULL
after the entire bytecode sequence that generates the callable has executed - in this case, LOAD_GLOBAL followed by LOAD_ATTR:
# torch.relu(x)
create_load_global("torch", push_null=False)
create_load_attr("relu")
create_instruction("PUSH_NULL")
create_instruction("LOAD_FAST", "x")
create_call_function(1, False)
In general, it is not easy to directly transform pre-callable convention bytecode to a post-callable convention; for example:
# NULL pre-callable
# f(x.a).b.c(y.d, z.e)
PUSH_NULL
PUSH_NULL
LOAD_GLOBAL f
LOAD_GLOBAL x
LOAD_ATTR a
CALL 1
LOAD_ATTR b
LOAD_ATTR c
LOAD_GLOBAL y
LOAD_ATTR d
LOAD_GLOBAL z
LOAD_ATTR e
CALL 2
Is there an easy rule to determine where the PUSH_NULLs should move for 3.13?
# NULL post-callable
# f(x.a).b.c(y.d, z.e)
LOAD_GLOBAL f
PUSH_NULL
LOAD_GLOBAL x
LOAD_ATTR a
CALL 1
LOAD_ATTR b
LOAD_ATTR c
PUSH_NULL
LOAD_GLOBAL y
LOAD_ATTR d
LOAD_GLOBAL z
LOAD_ATTR e
CALL 2
In order to generate function calls in a Python-version-agnostic way, we need some notion of marking that a sequence of bytecode is intended to generate a callable. Then we can simply prepend or append PUSH_NULL
to this sequence of bytecode depending on Python version. We introduce the bytecode utility function add_push_null
that accepts a list of bytecode or a function that generates a list of bytecode. A PUSH_NULL
instruction will then be inserted at the right place, depending on Python version. Using add_push_null
in Dynamo codegen essentially signifies that the provided bytecode is intended to be a callable.
For example:
# torch.relu(x)
add_push_null([
    # we will call "torch.relu"
    create_load_global("torch", push_null=False),
    create_load_attr("relu"),
])
create_instruction("LOAD_FAST", "x")
create_call_function(1, False)
Additionally, some bytecode instructions such as LOAD_GLOBAL
and LOAD_ATTR
have a “push NULL
bit” in their argument - if set, a NULL
will be pushed to the stack, before or after the loaded value depending on Python version. Previous codegen APIs like create_load_global
respect this by providing a push_null
argument. The push_null
argument complicates things because create_load_global(..., push_null=True)
would, in 3.13, push the NULL
immediately after loading the global, which we established above should not happen in general. Our fix was to remove all push_null
arguments from Dynamo codegen functions and let add_push_null
determine whether to set the push-NULL
bit of the first or the last provided instruction, depending on version.
So the final Dynamo codegen code looks like (note that this code is not dependent on Python version):
# torch.relu(x)
add_push_null([
    create_instruction("LOAD_GLOBAL", "torch"),
    create_instruction("LOAD_ATTR", "relu"),
])
# NULL bit of LOAD_GLOBAL is set on <=3.12
# NULL bit of LOAD_ATTR is set on >=3.13
create_instruction("LOAD_FAST", "x")
create_call_function(1, False)
We had to change almost every place in Dynamo where we codegen a function to use the new add_push_null
API. If you write bytecode generation code in Dynamo, you should use this new API if you codegen functions that will be called.
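The core version dispatch can be sketched in a few lines. This is a deliberately simplified standalone model, not Dynamo's actual add_push_null (which operates on instruction objects and can set the push-NULL bit on an existing instruction instead of emitting a separate PUSH_NULL); the python_version parameter is an assumption made so the behavior is explicit:

```python
def add_push_null_sketch(insts, python_version):
    # Simplified model of the version-agnostic helper: the caller marks
    # the whole sequence that produces the callable, and we place
    # PUSH_NULL before it (3.11/3.12 convention) or after it (3.13+
    # convention, where NULL sits above the callable on the stack).
    if python_version >= (3, 13):
        return insts + ["PUSH_NULL"]
    return ["PUSH_NULL"] + insts

callable_seq = ["LOAD_GLOBAL torch", "LOAD_ATTR relu"]
assert add_push_null_sketch(callable_seq, (3, 12)) == \
    ["PUSH_NULL", "LOAD_GLOBAL torch", "LOAD_ATTR relu"]
assert add_push_null_sketch(callable_seq, (3, 13)) == \
    ["LOAD_GLOBAL torch", "LOAD_ATTR relu", "PUSH_NULL"]
```

The key point the sketch captures is that the insertion point depends on the whole callable-producing sequence, not on any single instruction within it.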
Bytecode templates
There are a few places in Dynamo where we generate long sequences of hardcoded bytecode. One place in particular is resume_execution.py
, where we have sequences of hardcoded instructions that restore context managers in continuation functions. Different Python versions have different bytecode operations/semantics and can generate different bytecode sequences for the same Python source code. As a result, we need different hardcoded bytecode sequences for different Python versions. This is problematic for several reasons:
- We end up with multiple sequences of hardcoded bytecode that aim to do the same thing - repetition that should be avoided.
- Manually setting up jump/exception table targets is error prone.
- We need to check when updating Python versions that previously generated bytecode is still compatible. If not, we need to add a new case.
We noted that our long sequences of hardcoded bytecode for the most part describe fairly ordinary Python procedures - such as dictionary iteration, a try
block, or a with
block - and otherwise only need a few modifications.
Consider our example of restoring context managers in continuation functions. We need to generate bytecode that does:
context_obj = context_class(context init args)
with context_obj:
    (rest of the bytecode)
where context_class
is the class of the context manager to be restored.
The with
block generates a lot of bytecode that we would like to avoid hardcoding:
import dis

def fn():
    with x:
        y
    z

dis.dis(fn)
1 0 RESUME 0
2 2 LOAD_GLOBAL 0 (x)
12 BEFORE_WITH
14 POP_TOP
3 16 LOAD_GLOBAL 2 (y)
26 POP_TOP
2 28 LOAD_CONST 0 (None)
30 LOAD_CONST 0 (None)
32 LOAD_CONST 0 (None)
34 CALL 2
42 POP_TOP
4 44 LOAD_GLOBAL 4 (z)
54 POP_TOP
56 RETURN_CONST 0 (None)
2 >> 58 PUSH_EXC_INFO
60 WITH_EXCEPT_START
62 POP_JUMP_IF_TRUE 1 (to 66)
64 RERAISE 2
>> 66 POP_TOP
68 POP_EXCEPT
70 POP_TOP
72 POP_TOP
4 74 LOAD_GLOBAL 4 (z)
84 POP_TOP
86 RETURN_CONST 0 (None)
>> 88 COPY 3
90 POP_EXCEPT
92 RERAISE 1
ExceptionTable:
14 to 26 -> 58 [1] lasti
58 to 66 -> 88 [3] lasti
We introduce a new bytecode generation utility, bytecode_from_template
, that can generate bytecode from template functions. bytecode_from_template
can also:
- remap variable names of the template function
- remove prefix instructions (instructions before and including RESUME)
- replace returns with jumps
so that the generated bytecode can be easily used with other bytecode sequences.
If we apply bytecode_from_template
to our example above:
def _template(ctx, dummy):
    with ctx:
        dummy

# suppose ctx_var_name == "__stack0"
insts = bytecode_from_template(_template, varname_map={"ctx": ctx_var_name}, noreturn=True, noprefix=True)
print("\n".join(f"{i.opname} {i.arg} {i.argval}" for i in insts))
LOAD_FAST None __stack0
BEFORE_WITH None None
POP_TOP None None
LOAD_FAST None dummy
POP_TOP None None
LOAD_CONST None None
LOAD_CONST None None
LOAD_CONST None None
CALL 2 2
POP_TOP None None
LOAD_CONST None None
JUMP_FORWARD None <class 'torch._dynamo.bytecode_transformation._NotProvided'>
PUSH_EXC_INFO None None
WITH_EXCEPT_START None None
POP_JUMP_IF_TRUE 1 38
RERAISE 2 2
POP_TOP None None
POP_EXCEPT None None
POP_TOP None None
POP_TOP None None
LOAD_CONST None None
JUMP_FORWARD None <class 'torch._dynamo.bytecode_transformation._NotProvided'>
COPY 3 3
POP_EXCEPT None None
RERAISE 1 1
NOP None <class 'torch._dynamo.bytecode_transformation._NotProvided'>
We can see that the generated template bytecode matches fairly well with the actual bytecode generated from a with
block. In order to fill in (rest of the bytecode)
, we can look for the LOAD_FAST dummy
instruction and replace that with the rest of our code. (Code pointer to the actual code this example is based off of)
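The essence of the template approach can be approximated with the standard dis module. The sketch below is a much-reduced stand-in for bytecode_from_template (the real utility also strips the prefix, rewrites returns into jumps, and produces real instruction objects rather than tuples); the helper name and tuple output are assumptions for illustration:

```python
import dis

def bytecode_from_template_sketch(template_fn, varname_map):
    # Reduced illustration of the template idea: let CPython compile an
    # ordinary Python function, then read back its instruction stream,
    # remapping local variable names so the template's locals line up
    # with the names used at the splice site.
    insts = []
    for inst in dis.get_instructions(template_fn):
        argval = inst.argval
        if isinstance(argval, str):
            argval = varname_map.get(argval, argval)
        insts.append((inst.opname, argval))
    return insts

def _template(ctx, dummy):
    with ctx:
        dummy

insts = bytecode_from_template_sketch(_template, {"ctx": "__stack0"})
# The template's "ctx" local now appears under the remapped name,
# and "dummy" survives as the placeholder to splice over.
assert ("LOAD_FAST", "__stack0") in insts
assert ("LOAD_FAST", "dummy") in insts
```

Because the template is regular Python source, jump targets and exception table entries come from the compiler for free, instead of being maintained by hand per version.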
Applying bytecode_from_template
to resume_execution.py
allowed us to remove the version-dependent hardcoded bytecode and made it easier to understand what kind of bytecode we’re generating. We have ideas for further improving bytecode generation through templates, with the goal of replacing more bytecode generation sites in Dynamo with more natural, Pythonic code.
If you’re writing bytecode generation code in Dynamo, consider giving bytecode_from_template
a try!
Relevant PRs: