A proposal to make Dynamo more understandable to users

Hi team, I’m trying to make Dynamo more understandable to users, so that they can debug it and see whether Dynamo is doing what they want.

I opened this discussion to document my progress and the problems I have encountered so far.

The first step is to expose cache entries to users, which has been done in PR1 and PR2. Users can now call torch._dynamo.eval_frame._debug_get_cache_entry_list to retrieve the cache entries of functions optimized by Dynamo. The doc shows a usage example.
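
For reference, here is a minimal sketch of the usage (it mirrors the larger example later in this post; the "eager" backend simply runs the captured graph as-is):

import torch
from torch import _dynamo as torchdynamo
from torch._dynamo.eval_frame import _debug_get_cache_entry_list

def f(x):
    return x + 1

opt_f = torchdynamo.optimize("eager")(f)
opt_f(torch.randn(3))  # trigger compilation so that a cache entry exists

# As discussed later in this post, the cache entries live on the
# *original* function's code object, not on the optimized wrapper.
entries = _debug_get_cache_entry_list(f.__code__)
print(f"f has {len(entries)} cache entries")  # expect 1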

The second step is to explain the artifacts generated by Dynamo to users. They mainly exist as Python bytecode, which is difficult for users to read and understand.

Dynamo generates four categories of artifacts: guards, compiled code, compiled partial graphs, and resume functions. Let’s go through them one by one.

  • Guards: Dynamo actually has the source code for guards, so we can expose it to users directly, sparing them from reading bytecode. This can be achieved by an ongoing PR; after it lands, we can use guard.src to show a guard’s source code to users.

  • Compiled partial graphs: Since we focus on understanding and debugging Dynamo itself, we can use the “eager” backend; the compiled partial graphs are then generated by torch.fx, with readable source code (see the sketch after this list).

  • Resume functions: these functions are parts of the original function, so perhaps we can just print the original source code and tell users where each of them resumes.

  • Compiled code: this is the function that assembles the compiled partial graphs and resume functions into the final code. I assume this function’s bytecode is not very complicated, so maybe my simple decompiler will work.
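
As a concrete illustration of the “readable source code” point above, here is a minimal sketch (inspect_backend is just an illustrative name) that prints the torch.fx-generated source of each compiled partial graph:

from typing import List
import torch
from torch import _dynamo as torchdynamo

def inspect_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    # torch.fx generates readable Python source for each partial graph;
    # printing gm.code shows exactly what was captured.
    print(gm.code)
    return gm.forward  # run the graph eagerly, like the built-in "eager" backend

@torchdynamo.optimize(inspect_backend)
def f(a, b):
    return (a + b).relu()

f(torch.randn(3), torch.randn(3))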

The final goal might look like this:

from typing import List
import torch
from torch import _dynamo as torchdynamo

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward  # return a python callable

@torchdynamo.optimize(my_compiler)
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

for _ in range(100):
    toy_example(torch.randn(10), torch.randn(10))

After we compile the function, we can describe what Dynamo does for users:

torchdynamo.describe(toy_example)

The desired output:

This is a function optimized by Dynamo. It has {n} cache entries.

Cache Entry 1:

Guard:

{Guard Code}

Code:

{Decompiled Source Code}

There are {m} subgraphs in the function:

SubGraph 1: __compiled_fn_{i}
Source code of subgraph function 1:

{Subgraph Source Code}

There are {p} resume functions:

Resume function 1: __resume_at_{offset}_{j}

def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:  <== resume at this line
        b = b * -1
    return x * b

I’m seeking feedback from this forum to see if the proposal is valuable and worth investing in.

I have encountered several confusing things:

1. The function before and after compilation:

from typing import List
import torch
from torch import _dynamo as torchdynamo

counter = 1

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    def compiled_fn(*args, **kwargs):
        global counter
        counter += 1
        return gm.forward(*args, **kwargs)
    return compiled_fn

def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

opt_toy_example = torchdynamo.optimize(my_compiler)(toy_example)
value_before = counter
for _ in range(100):
    opt_toy_example(torch.randn(10), torch.randn(10)) # calls the compiled code, and modifies the `counter`
assert counter > value_before

value_before = counter
toy_example(torch.randn(10), torch.randn(10)) # does not call the compiled code
assert counter == value_before

from torch._dynamo.eval_frame import _debug_get_cache_entry_list
print(f"original function has {len(_debug_get_cache_entry_list(toy_example.__code__))} cache entries") # prints 1
print(f"optimized function has {len(_debug_get_cache_entry_list(opt_toy_example.__code__))} cache entries") # prints 0

It seems that the original function toy_example holds the cache entries, while only calling the optimized function opt_toy_example actually runs the optimized code.

After digging into the code, I found the relationship between toy_example and opt_toy_example in eval_frame.py:

        def _fn(*args, **kwargs):
            if (
                not isinstance(self, DisableContext)
                and torch.fx._symbolic_trace.is_fx_tracing()
            ):
                if config.error_on_nested_fx_trace:
                    raise RuntimeError(
                        "Detected that you are using FX to symbolically trace "
                        "a dynamo-optimized function. This is not supported at the moment."
                    )
                else:
                    return fn(*args, **kwargs)

            on_enter()
            prior = set_eval_frame(callback)
            backend_ctx = backend_ctx_ctor()
            backend_ctx.__enter__()
            dynamic_ctx = enable_dynamic(self.dynamic, self.export)
            dynamic_ctx.__enter__()
            try:
                return fn(*args, **kwargs)
            finally:
                set_eval_frame(prior)
                dynamic_ctx.__exit__(None, None, None)
                backend_ctx.__exit__(None, None, None)

Here toy_example is fn, and opt_toy_example is _fn.

It seems opt_toy_example enables the special eval frame hooks and then runs toy_example, right?

2. Why does the compiled graph function also have this context-switching stuff?

Since opt_toy_example enables the special eval frame hooks, Dynamo will find the cache entries in toy_example and run the cached code if the guarding conditions are satisfied. The guarded code is:

def toy_example(a, b):
    __temp_1 = __compiled_fn_0(a, b)
    x = __temp_1[0]
    if __temp_1[1]:
        return __resume_at_30_1(b, x)
    return __resume_at_38_2(b, x)

What puzzles me is that __compiled_fn_0 is not the compiled graph function itself; the compiled graph function is stored in __compiled_fn_0._torchdynamo_orig_callable. We know __compiled_fn_0._torchdynamo_orig_callable is a compiled graph function, and the guarding conditions of toy_example are already satisfied. Why do we still need to enable the special eval frame hooks to run __compiled_fn_0._torchdynamo_orig_callable? Are there any cases where we install further guards and compiled code into the compiled graph function? If not, maybe we should run the compiled graph function directly, without switching eval frame hooks.

3. It seems resume functions also have cache entries.

I tried to decompile the bytecode of resume functions, but found that they actually have cache entries too: resume functions have their own guards and compiled code. This means decompiling their bytecode makes no sense. We need to decompile the cached code instead, and only inform users where the function resumes.
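
To see this recursive compilation in action, here is a small sketch (counting_backend is just an illustrative name) that counts backend invocations; the resume function after the graph break triggers its own compilation:

from typing import List
import torch
from torch import _dynamo as torchdynamo

graphs = []

def counting_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    graphs.append(gm)  # one entry per compiled partial graph
    return gm.forward

@torchdynamo.optimize(counting_backend)
def f(a, b):
    x = a + 1
    if b.sum() < 0:  # data-dependent branch => graph break
        x = x * -1
    return x + b

f(torch.randn(4), torch.randn(4))
# The code before the branch and the resume function taken after the
# graph break are compiled separately, so we expect more than one graph.
print(len(graphs))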

4. The @functools.wraps(fn) in eval_frame.py causes much trouble when trying to understand Dynamo

I finally found that both the optimized function and the __compiled_fn_0 objects are functions returned from this line. The @functools.wraps(fn) makes inspection very misleading: the bytecode and inspect.getsource differ. I think removing @functools.wraps(fn) would be better.
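
Here is a small self-contained sketch of the mismatch, independent of Dynamo:

import functools
import inspect

def original(x):
    return x + 1

def make_wrapper(fn):
    @functools.wraps(fn)
    def _fn(*args, **kwargs):
        return fn(*args, **kwargs)
    return _fn

wrapped = make_wrapper(original)

# functools.wraps sets __wrapped__, and inspect follows it back to
# `original`, so the reported source is the original function's ...
print(inspect.getsource(wrapped))
# ... while the code object that actually runs is still the wrapper's.
print(wrapped.__code__.co_name)  # prints "_fn"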

It would be best if we could make it a class like OptimizedFunction that prints some useful information for its str and repr.
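
For illustration, here is a purely hypothetical sketch of such a class (OptimizedFunction and its attributes are made-up names, not an existing Dynamo API):

class OptimizedFunction:
    """Hypothetical wrapper with an informative repr (not real Dynamo code)."""

    def __init__(self, fn, context):
        self._fn = fn
        self._context = context

    def __call__(self, *args, **kwargs):
        # Enable the Dynamo eval-frame hooks, call the original function,
        # then restore the previous hooks (all elided in this sketch).
        return self._fn(*args, **kwargs)

    def __repr__(self):
        return f"OptimizedFunction(function={self._fn.__name__}, context={self._context})"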

In Summary

If we wire things up successfully, we only need to decompile the bytecode in cache entries. When a function has no cache entries, inspect.getsource works smoothly (except for guards, where we have to use guard.src to see their source code).

Help needed for clarifying the above confusing points

Since I’m pretty new to Dynamo, I would like to consult the team to confirm the above details. Your help is desperately needed!

  1. _fn is a wrapper created by @torch.compile() to enable torchdynamo. It doesn’t modify the original function; it just enables Dynamo before calling it and disables it afterwards.

  2. I believe the inner wrapper is there to disable Dynamo (the opposite of 1); this is so we don’t try to run Dynamo on the backend compiler.

  3. Resume functions are treated exactly the same as any other function; Dynamo just runs recursively on them. This trick simplifies Dynamo’s implementation: it handles the first subgraph, then recursively calls torch.compile() to lazily handle the resume functions.

  4. I can see why functools.wraps is causing you confusion, though I think for most use cases functools.wraps is actually what we want and is good.

@torch.compile
def descriptive_function_name():
    """very useful docstring"""
    ...

print(descriptive_function_name.__name__)
print(descriptive_function_name.__doc__)

This example will print out descriptive_function_name and very useful docstring. If you removed functools.wraps, it would print out Dynamo internals.

For end users, we want dynamo to be transparent. Most people aren’t trying to examine the internals of dynamo like you are.

Thanks for the explanation. With these clarifications, I have a much better understanding of Dynamo.

I appreciate the effort to make Dynamo transparent to end users. However, I think there are other factors that require consideration:

  1. Too much transparency can make Dynamo look like black magic and might intimidate users, thereby hindering the adoption of PyTorch 2.0. As people try PyTorch 2.0, comments are accumulating on social media, both praising PyTorch 2.0 for being fast and criticizing it for silently doing things users do not want. I’m working on a project in my lab and want to persuade collaborators to use PyTorch 2.0 (that’s why I recently dived so deep into Dynamo). But first I have to convince them that PyTorch 2.0 is sound, to gain their trust. Speed is good, but only after users trust it and use it.

  2. Some large projects (e.g., I work on mmcv and tianshou) require much more than @torch.compile to migrate to PyTorch 2.0, and they have to stay compatible with both PyTorch 2.0 and PyTorch 1.x. Debugging functionality is essential to help them migrate.

  3. I think giving users more information never hurts. For example, if we wire things up and do not wrap function names, the following output would be better in my opinion. We don’t need to present Dynamo internals, just a friendly message telling users that this function runs under Dynamo’s control.

@torch.compile
def descriptive_function_name():
    """very useful docstring"""
    ...

print(descriptive_function_name)
# DynamoRunner(function=<descriptive_function_name>, context=<DynamoContext>)

To sum up: we have to trade off between transparency and exposing internal details, and absolute transparency is usually not the best choice. I believe I’m not alone in thinking this.