Call for backward compatibility to enable users to understand and adapt to the PyTorch compiler

In the last few months, with the help of @jansel @ezyang @Chillee, I developed a tool (depyf) for users to understand and adapt to the PyTorch compiler torch.compile.

It reveals the working internals of Dynamo and Inductor, so that users can understand what torch.compile does to their code, and can change their code so that torch.compile works better.

For example, the following code produces many output artifacts:

import torch

@torch.compile(backend="inductor")
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:  # data-dependent control flow causes a graph break
        b = b * -1
    return x * b

import depyf
# dump all compilation artifacts into the given directory
with depyf.prepare_debug(toy_example, "./dump_src_debug_function_aot"):
    # warm up with two input shapes, which triggers recompilation with dynamic shapes
    for _ in range(100):
        toy_example(torch.randn(10, requires_grad=True), torch.randn(10, requires_grad=True))
        toy_example(torch.randn(8, requires_grad=True), torch.randn(8, requires_grad=True))

# run under depyf.debug() so that debuggers can step into the dumped source files
with depyf.debug():
    toy_example(torch.randn(10, requires_grad=True), torch.randn(10, requires_grad=True))

These artifacts include:

__compiled_fn_0 AFTER POST GRAD 0.py
__compiled_fn_0 Backward graph 0.py
__compiled_fn_0 Captured Graph 0.py
__compiled_fn_0 Forward graph 0.py
__compiled_fn_0 Joint graph 0.py
__compiled_fn_0 kernel 0.py
__compiled_fn_3 AFTER POST GRAD 0.py
__compiled_fn_3 Backward graph 0.py
__compiled_fn_3 Captured Graph 0.py
__compiled_fn_3 Forward graph 0.py
__compiled_fn_3 Joint graph 0.py
__compiled_fn_3 kernel 0.py
__compiled_fn_4 AFTER POST GRAD 0.py
__compiled_fn_4 AFTER POST GRAD 1.py
__compiled_fn_4 Backward graph 0.py
__compiled_fn_4 Captured Graph 0.py
__compiled_fn_4 Forward graph 0.py
__compiled_fn_4 Joint graph 0.py
__compiled_fn_4 kernel 0.py
__compiled_fn_4 kernel 1.py
__compiled_fn_7 AFTER POST GRAD 0.py
__compiled_fn_7 AFTER POST GRAD 1.py
__compiled_fn_7 Backward graph 0.py
__compiled_fn_7 Captured Graph 0.py
__compiled_fn_7 Forward graph 0.py
__compiled_fn_7 Joint graph 0.py
__compiled_fn_7 kernel 0.py
__compiled_fn_7 kernel 1.py
__compiled_fn_8 AFTER POST GRAD 0.py
__compiled_fn_8 Backward graph 0.py
__compiled_fn_8 Captured Graph 0.py
__compiled_fn_8 Forward graph 0.py
__compiled_fn_8 Joint graph 0.py
__compiled_fn_8 kernel 0.py
__compiled_fn_8 kernel 1.py
__compiled_fn_9 AFTER POST GRAD 0.py
__compiled_fn_9 AFTER POST GRAD 1.py
__compiled_fn_9 Backward graph 0.py
__compiled_fn_9 Captured Graph 0.py
__compiled_fn_9 Forward graph 0.py
__compiled_fn_9 Joint graph 0.py
__compiled_fn_9 kernel 0.py
__compiled_fn_9 kernel 1.py
full_code.py

These files reveal:

  • the Dynamo-transformed bytecode, together with decompiled source code and guards (in full_code.py)
  • the captured graph, joint graph, forward graph, and backward graph from AOT Autograd (in __compiled_fn_{n} {graph name}.py); a rough sketch of a captured-graph file follows this list
  • the lowered and compiled kernels from Inductor (in __compiled_fn_{n} kernel {k}.py)
  • the (dynamic) shape information of each tensor (in __compiled_fn_{n} {graph name}.py)
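
To make this concrete, here is an illustrative sketch of what "__compiled_fn_0 Captured Graph 0.py" roughly contains for toy_example above (hand-written for illustration, not an exact dump; variable names and the exact op set differ across PyTorch versions). Dynamo captures everything up to the data-dependent branch on b.sum() < 0:

import torch

class GraphModule(torch.nn.Module):
    def forward(self, L_a_: torch.Tensor, L_b_: torch.Tensor):
        # x = a / (torch.abs(a) + 1)
        abs_1 = torch.abs(L_a_)
        add = abs_1 + 1
        x = L_a_ / add
        # b.sum() < 0 is data-dependent, so Dynamo ends the graph here
        # and returns the condition for the resume bytecode to branch on
        sum_1 = L_b_.sum()
        lt = sum_1 < 0
        return (x, lt)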

It works for three backends: "eager", "aot_eager", and "inductor". For each backend, we can set breakpoints in the corresponding files and use a debugger to step through the code (a debugger sketch follows this list):

  • "eager" backend usually ends in __compiled_fn_{n} Captured Graph 0.py.
  • "aot_eager" backend usually ends in __compiled_fn_{n} Forward graph 0.py and __compiled_fn_{n} Backward graph 0.py.
  • "inductor" backend usually ends in __compiled_fn_{n} kernel 0.py.

In addition, I also tried to reveal the internals of "inductor" itself (lowering, decomposition, and the kernel fusion plan), but they are quite intricate.

Since the tool interacts with the PyTorch compiler's internals, it relies on many implementation details of PyTorch. Therefore, I want to discuss whether the PyTorch team can ensure backward compatibility for some of these internal details.

Currently, the tool relies on the following internal details of PyTorch (a code sketch of the first few items follows this list):

  • The bytecode hook registration API torch._dynamo.convert_frame.register_bytecode_hook, which I use to decompile the transformed bytecode.
  • torch._dynamo.eval_frame.innermost_fn and torch._dynamo.eval_frame._debug_get_cache_entry_list, which extract cache entries from a compiled function.
  • All guarding conditions are stored in the code_parts attribute, as Python source strings.
  • The name of the current compiled function is derived from torch._dynamo.bytecode_transformation._unique_id_counter. I cannot use the torch._dynamo.bytecode_transformation.unique_id function, because calling it would advance the counter.
  • Compiled functions are named __compiled_fn_{next(_unique_id_counter)}, and resume functions are named __resume_xxx.
  • I hijack torch.fx.graph_module._exec_with_source so that an fx graph's forward function has its source code in a file and can be stepped through by debuggers.
  • I replace torch._dynamo.utils.lazy_format_graph_code.__code__ with another code object, so that I can capture all related fx graphs.
  • I hijack torch._inductor.codecache.PyCodeCache.load_by_key_path so that I can get the Triton/OpenMP code and the call function for each compiled CPU/GPU kernel.
  • I hijack torch.fx.Interpreter.boxed_run so that it runs the forward function rather than executing the fx graph node by node (used for the aot_eager backend).
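
As a sketch of how the first few items are used (the hook signature, the removable-handle return value, and the copyability of the counter are assumptions based on my reading of current PyTorch source, and may change across versions):

import copy
import types

import torch._dynamo.convert_frame as convert_frame
from torch._dynamo.bytecode_transformation import _unique_id_counter

def bytecode_hook(code: types.CodeType, new_code: types.CodeType):
    # called with the original and the Dynamo-transformed code objects;
    # returning None keeps the transformed bytecode unchanged
    print(f"Dynamo transformed {code.co_name!r}")
    return None

handle = convert_frame.register_bytecode_hook(bytecode_hook)
# ... run torch.compile'd code here ...
handle.remove()

# peek at the next compiled-function id without advancing PyTorch's counter:
# itertools.count supports copying (via its pickle support), whereas calling
# unique_id() would consume a value
next_id = next(copy.copy(_unique_id_counter))
expected_name = f"__compiled_fn_{next_id}"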

How stable are these internal details?

For APIs like torch._dynamo.convert_frame.register_bytecode_hook / torch._dynamo.eval_frame.innermost_fn / torch._dynamo.eval_frame._debug_get_cache_entry_list, I’m fairly confident they will remain stable (though not entirely confident, since these APIs are very private, with many leading underscores).

For implementation details like the code_parts attribute of guards, and torch._dynamo.bytecode_transformation._unique_id_counter for naming __compiled_fn and __resume functions, I suppose they will be kept as conventions, but I am not sure.

For hijacked functions like torch.fx.graph_module._exec_with_source / torch._dynamo.utils.lazy_format_graph_code / torch._inductor.codecache.PyCodeCache.load_by_key_path / torch.fx.Interpreter.boxed_run, I need their signatures to remain backward compatible and their functionality to stay the same: fx graphs’ forward functions are compiled via _exec_with_source, important graphs are logged with lazy_format_graph_code, and Inductor-generated kernels are produced by load_by_key_path. These could easily break. One mitigation on my side is to fail loudly when a hijacked target changes, as sketched below.
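
A minimal sketch of that mitigation, checking a hijacked target's signature at import time (the expected parameter names here are an assumption for one PyTorch version and would need to be tracked per release):

import inspect
from torch.fx import graph_module

def check_signature(fn, expected_prefix):
    # fail loudly if the target's leading parameters change
    actual = list(inspect.signature(fn).parameters)
    if actual[: len(expected_prefix)] != expected_prefix:
        raise RuntimeError(
            f"{fn.__qualname__} signature changed to {actual}; "
            "depyf needs updating for this PyTorch version"
        )

# parameter names assumed for illustration; adjust per version
check_signature(graph_module._exec_with_source, ["src", "globals"])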

The purpose of this post is to raise the PyTorch team’s awareness that someone depends on these internal details. It would be even better if some of these internals could be turned into stable APIs.


Adding backwards compatibility contracts comes with a price: they can slow down the development velocity of PyTorch. ML is a fast-moving field, and we need to be nimble to keep up. If I had to pick between delivering more performance to users and breaking BC on some compiler internals, I’d pick the performance. That said, I think many of the things you mentioned are unlikely to change. One exception is the guards-related stuff, which @anijain2305 is working on moving to C++.

We have a similar dependency on CPython. TorchDynamo relies on many of the implementation details of CPython bytecode, but CPython bytecode changes in every version of Python and there is zero BC contract. If we were to ask the Python maintainers to make bytecode backwards compatible, it would basically freeze development of CPython and make many types of improvements impossible. So instead, we do the work required to add TorchDynamo support for each new Python version. This is work we signed up for by relying on internal details of CPython.

I’d suggest adding CI to the depyf repo that automatically tests PyTorch releases and nightly builds for compatibility. That will give you an early signal if something changes that requires an update.


Yes, I just added CI to depyf to test against the PyTorch nightly build. I will come back to the PyTorch team if any new commit in nightly breaks depyf, and see if we can find a workaround that keeps depyf working. Thanks, Jason!
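
For reference, the core of such a CI check can be as simple as a smoke test that the relied-upon internals still exist (a hedged sketch, not the actual depyf test suite):

def test_pytorch_internals_still_exist():
    import torch._dynamo.bytecode_transformation as bt
    import torch._dynamo.convert_frame as convert_frame
    import torch._dynamo.eval_frame as eval_frame
    from torch._inductor.codecache import PyCodeCache
    from torch.fx import Interpreter, graph_module

    # each of these is an internal detail depyf depends on
    assert hasattr(convert_frame, "register_bytecode_hook")
    assert hasattr(eval_frame, "innermost_fn")
    assert hasattr(eval_frame, "_debug_get_cache_entry_list")
    assert hasattr(bt, "_unique_id_counter")
    assert hasattr(graph_module, "_exec_with_source")
    assert hasattr(PyCodeCache, "load_by_key_path")
    assert hasattr(Interpreter, "boxed_run")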