In the last few months, with the help of @jansel @ezyang @Chillee, I developed a tool for users to understand and adapt to the pytorch compiler `torch.compile`. It reveals the working internals of Dynamo and Inductor, so that users can understand what `torch.compile` does to their code, and can change their code so that `torch.compile` works better.
For example, the following code can produce many output artifacts:
```python
import torch

@torch.compile(backend="inductor")
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

import depyf
with depyf.prepare_debug(toy_example, "./dump_src_debug_function_aot"):
    for _ in range(100):
        toy_example(torch.randn(10, requires_grad=True), torch.randn(10, requires_grad=True))
        toy_example(torch.randn(8, requires_grad=True), torch.randn(8, requires_grad=True))
with depyf.debug():
    toy_example(torch.randn(10, requires_grad=True), torch.randn(10, requires_grad=True))
```
These artifacts include:
```text
__compiled_fn_0 AFTER POST GRAD 0.py   __compiled_fn_4 Captured Graph 0.py   __compiled_fn_8 Captured Graph 0.py
__compiled_fn_0 Backward graph 0.py    __compiled_fn_4 Forward graph 0.py    __compiled_fn_8 Forward graph 0.py
__compiled_fn_0 Captured Graph 0.py    __compiled_fn_4 Joint graph 0.py      __compiled_fn_8 Joint graph 0.py
__compiled_fn_0 Forward graph 0.py     __compiled_fn_4 kernel 0.py           __compiled_fn_8 kernel 0.py
__compiled_fn_0 Joint graph 0.py       __compiled_fn_4 kernel 1.py           __compiled_fn_9 AFTER POST GRAD 0.py
__compiled_fn_0 kernel 0.py            __compiled_fn_7 AFTER POST GRAD 0.py  __compiled_fn_9 AFTER POST GRAD 1.py
__compiled_fn_3 AFTER POST GRAD 0.py   __compiled_fn_7 AFTER POST GRAD 1.py  __compiled_fn_9 Backward graph 0.py
__compiled_fn_3 Backward graph 0.py    __compiled_fn_7 Backward graph 0.py   __compiled_fn_9 Captured Graph 0.py
__compiled_fn_3 Captured Graph 0.py    __compiled_fn_7 Captured Graph 0.py   __compiled_fn_9 Forward graph 0.py
__compiled_fn_3 Forward graph 0.py     __compiled_fn_7 Forward graph 0.py    __compiled_fn_9 Joint graph 0.py
__compiled_fn_3 Joint graph 0.py       __compiled_fn_7 Joint graph 0.py      __compiled_fn_9 kernel 0.py
__compiled_fn_3 kernel 0.py            __compiled_fn_7 kernel 0.py           __compiled_fn_9 kernel 1.py
__compiled_fn_4 AFTER POST GRAD 0.py   __compiled_fn_7 kernel 1.py           full_code.py
__compiled_fn_4 AFTER POST GRAD 1.py   __compiled_fn_8 AFTER POST GRAD 0.py
__compiled_fn_4 Backward graph 0.py    __compiled_fn_8 Backward graph 0.py
```
They reveal details of:

- Dynamo-transformed bytecode, with decompiled source code and guards (in `full_code.py`)
- the captured graph, joint graph, forward graph, and backward graph from AOT Autograd (in `__compiled_fn_{n} {graph name}.py`; an illustrative sketch follows this list)
- lowered and compiled kernels from Inductor (in `__compiled_fn_{n} kernel {k}.py`)
- (dynamic) shape information of each tensor (in `__compiled_fn_{n} {graph name}.py`)
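For intuition, a dumped graph file contains a plain-python `forward` function roughly in the following shape (illustrative only, not verbatim depyf output; actual variable names and captured operations depend on the pytorch version). For `toy_example` above, Dynamo captures the code up to the data-dependent branch, so the first captured graph computes `x` and the branch condition:

```python
import torch

# illustrative shape of a "__compiled_fn_{n} Captured Graph 0.py" file,
# not verbatim output: the graph ends at the data-dependent if-statement
def forward(self, L_a_: torch.Tensor, L_b_: torch.Tensor):
    abs_1 = torch.abs(L_a_)
    add = abs_1 + 1
    x = L_a_ / add
    sum_1 = L_b_.sum()
    lt = sum_1 < 0  # the condition that causes the graph break
    return (x, lt)
```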
It works for three backends: `"eager"` / `"aot_eager"` / `"inductor"`. For each backend, we can set breakpoints in the corresponding files and use a debugger to step through the code:

- the `"eager"` backend usually ends up in `__compiled_fn_{n} Captured Graph 0.py`
- the `"aot_eager"` backend usually ends up in `__compiled_fn_{n} Forward graph 0.py` and `__compiled_fn_{n} Backward graph 0.py`
- the `"inductor"` backend usually ends up in `__compiled_fn_{n} kernel 0.py`
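For example, with the `"eager"` backend a debugging session looks roughly like this (a sketch reusing `toy_example` from above; the dump directory name `./dump_src_debug_eager` is just an illustrative choice):

```python
import torch
import depyf

@torch.compile(backend="eager")
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

# first run under prepare_debug to dump the source files to disk
with depyf.prepare_debug(toy_example, "./dump_src_debug_eager"):
    toy_example(torch.randn(10), torch.randn(10))

# then set a breakpoint in "__compiled_fn_{n} Captured Graph 0.py" inside
# the dump directory, and run again under depyf.debug() to step through it
with depyf.debug():
    toy_example(torch.randn(10), torch.randn(10))
```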
Plus: I also tried to reveal the details of `"inductor"` (lowering, decomposition, and the kernel fusion plan), but they seem quite intricate.
Since the tool interacts with the pytorch compiler's internals, it relies on many implementation details of pytorch. Therefore, I want to discuss whether the pytorch team can ensure backward compatibility for some of these internal details.
Currently, the tool relies on the following internal details of pytorch (sketches of how several of them are used follow this list):

- the bytecode hook registration API `torch._dynamo.convert_frame.register_bytecode_hook`; I use it to decompile transformed bytecode.
- `torch._dynamo.eval_frame.innermost_fn` and `torch._dynamo.eval_frame._debug_get_cache_entry_list` can extract cache entries from a compiled function.
- all the guarding conditions are stored in the `code_parts` attribute, in python source code format.
- the current compiled function name relies on `torch._dynamo.bytecode_transformation._unique_id_counter`. I cannot use the `torch._dynamo.bytecode_transformation.unique_id` function because it would increase the counter.
- compiled functions are named `__compiled_fn_{next(_unique_id_counter)}`, and resume functions are named `__resume_xxx`.
- I hijack `torch.fx.graph_module._exec_with_source` so that an fx graph's `forward` function has source code in files and can be stepped through by debuggers.
- I replace `torch._dynamo.utils.lazy_format_graph_code.__code__` with another code object, so that I can capture all related fx graphs.
- I hijack `torch._inductor.codecache.PyCodeCache.load_by_key_path` so that I can get the triton/openmp code and the `call` function for each compiled CPU/GPU kernel.
- I hijack `torch.fx.Interpreter.boxed_run` so that it runs the `forward` function rather than running the fx graph node by node (used for the `aot_eager` backend).
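To make these concrete: a minimal sketch of the bytecode hook usage, assuming the hook receives the original and the Dynamo-transformed code objects and that returning `None` keeps the transformed bytecode (the real tool additionally decompiles `new_code` back to source):

```python
import types
from typing import Optional

from torch._dynamo.convert_frame import register_bytecode_hook

def bytecode_hook(code: types.CodeType,
                  new_code: types.CodeType) -> Optional[types.CodeType]:
    # inspect what Dynamo did to this frame; depyf decompiles new_code here
    print(f"Dynamo transformed {code.co_name} "
          f"({code.co_filename}:{code.co_firstlineno})")
    return None  # keep the transformed bytecode unchanged

handle = register_bytecode_hook(bytecode_hook)
# ... run torch.compile'd code here ...
handle.remove()  # assumed to behave like other removable pytorch hooks
```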
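A similar sketch for reading guards from the Dynamo cache, assuming each cache entry exposes its guard source via `check_fn.code_parts` (the exact attribute layout may differ across pytorch versions):

```python
import torch
from torch._dynamo.eval_frame import innermost_fn, _debug_get_cache_entry_list

@torch.compile
def f(x):
    return x + 1

f(torch.randn(3))  # populate the cache with one compiled entry

inner = innermost_fn(f)  # unwrap decorator layers to reach the original function
for entry in _debug_get_cache_entry_list(inner.__code__):
    # each guarding condition is stored as python source text
    for part in entry.check_fn.code_parts:
        print(part)
```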
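For the naming convention, one way to peek at the counter without advancing it (a sketch; `itertools.count` objects support copying, and calling `next` on the copy does not affect the original):

```python
import copy
from torch._dynamo.bytecode_transformation import _unique_id_counter

# next() on a copy reads the upcoming value without advancing the real counter
current = next(copy.copy(_unique_id_counter))
print(f"the next compiled function will be named __compiled_fn_{current}")
```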
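Finally, the "hijack" items all follow the same monkey-patching pattern, sketched here for `PyCodeCache.load_by_key_path` (arguments are passed through untouched since the signature may vary across versions; the real tool also dumps the generated kernel source to files):

```python
from torch._inductor.codecache import PyCodeCache

# capture the original (bound) classmethod before replacing the attribute
_original_load = PyCodeCache.load_by_key_path

def patched_load_by_key_path(*args, **kwargs):
    module = _original_load(*args, **kwargs)
    # the returned module contains inductor's generated source,
    # including the `call` entry point of the compiled kernel
    print("inductor kernel module:", getattr(module, "__file__", "<in-memory>"))
    return module

PyCodeCache.load_by_key_path = patched_load_by_key_path
```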
How stable are these internal details?
For APIs like `torch._dynamo.convert_frame.register_bytecode_hook` / `torch._dynamo.eval_frame.innermost_fn` / `torch._dynamo.eval_frame._debug_get_cache_entry_list`, I'm quite confident that they should remain stable. (But not that confident, as these APIs are very private, with many leading underscores.)
For implementation details like the `code_parts` attribute of guards, and `torch._dynamo.bytecode_transformation._unique_id_counter` behind the `__compiled_fn` and `__resume` naming, I suppose they will be kept as conventions, but I'm not quite sure.
For hijacked functions like `torch.fx.graph_module._exec_with_source` / `torch._dynamo.utils.lazy_format_graph_code` / `torch._inductor.codecache.PyCodeCache.load_by_key_path` / `torch.fx.Interpreter.boxed_run`, I need their function signatures to stay backward compatible and their functionality to remain unchanged. E.g. fx graphs' `forward` functions are compiled via `_exec_with_source`, important graphs are logged with `lazy_format_graph_code`, and inductor-generated kernels are produced by `load_by_key_path`. These might easily break.
The purpose of this post is to make the pytorch team aware that someone relies on these internal details. It would be better if some of them could be turned into stable APIs.