I built a decompiler that converts bytecode generated by Dynamo into readable source code!

Hi folks, if you also struggle to read the bytecode generated by Dynamo, give this a try!
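
For context, here is a minimal sketch of what reading raw bytecode involves, using only the standard-library `dis` module. The function `plain_example` is a hypothetical, torch-free stand-in with the same kind of control flow as a typical Dynamo input; it is not Dynamo's output, and the exact opcode names depend on the Python version.

```python
import dis


def plain_example(a, b):
    # Torch-free stand-in: arithmetic followed by a data-dependent branch.
    x = a / (abs(a) + 1)
    if b < 0:
        b = b * -1
    return x * b


# Collect opcode names so they can be inspected programmatically.
opnames = [ins.opname for ins in dis.get_instructions(plain_example)]

# Print the human-readable disassembly, one instruction per line.
dis.dis(plain_example)
```

Even for a function this small, the listing is a flat sequence of stack operations and jumps, which is the form a decompiler has to turn back into `if`/`else` source.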

Simple usage with Dynamo:

First, run a PyTorch program compiled by Dynamo (here via torchdynamo.optimize with a custom backend):

from typing import List

import torch
from torch import _dynamo as torchdynamo


def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward  # return a python callable


@torchdynamo.optimize(my_compiler)
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b


for _ in range(100):
    toy_example(torch.randn(10), torch.randn(10))

Second, get the compiled code and the guard code from PyTorch:

from torch._dynamo.eval_frame import _debug_get_cache_entry_list
cache_entries = _debug_get_cache_entry_list(toy_example._torchdynamo_orig_callable.__code__)
guard, code = cache_entries[0]

Third, decompile both to see how they work:

from depyf import decompile

print("guard code:")
print(decompile(guard))

print("compiled code:")
print(decompile(code))

Output on my computer:

guard code:
def guard(L):
    if not getattr(___guarded_code, 'valid'):
        return False
    else:
        _var0 = L['a']
        __temp_1 = hasattr(_var0, '_dynamo_dynamic_indices')
        if not (__temp_1 == False):
            return False
        else:
            _var1 = L['b']
            __temp_2 = hasattr(_var1, '_dynamo_dynamic_indices')
            if not (__temp_2 == False):
                return False
            else:
                __temp_3 = ___is_grad_enabled()
                if not __temp_3:
                    return False
                else:
                    __temp_4 = ___are_deterministic_algorithms_enabled()
                    if __temp_4:
                        return False
                    else:
                        __temp_5 = ___is_torch_function_enabled()
                        if not __temp_5:
                            return False
                        else:
                            if not (getattr(utils_device, 'CURRENT_DEVICE') == None):
                                return False
                            else:
                                __temp_6 = ___check_tensors(_var0, _var1, tensor_check_names=tensor_check_names)
                                if not __temp_6:
                                    return False
                                else:
                                    return True

compiled code:
def toy_example(a, b):
    __temp_1 = __compiled_fn_0(a, b)
    x = __temp_1[0]
    if __temp_1[1]:
        __temp_2 = __resume_at_30_1(b, x)
        return __temp_2
    else:
        __temp_3 = __resume_at_38_2(b, x)
        return __temp_3

Hopefully, with this package, you can now understand the Python bytecode!
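
To make the control flow of the decompiled `toy_example` concrete, here is a plain-Python sketch of the same structure. The three helper functions are hypothetical stand-ins written for illustration (using lists instead of tensors), and the sketch assumes that the second element returned by the compiled graph is the branch condition `b.sum() < 0`.

```python
def compiled_fn_0(a, b):
    # Stand-in for the captured graph: computes x and the branch condition.
    x = [ai / (abs(ai) + 1) for ai in a]
    return x, sum(b) < 0


def resume_at_30_1(b, x):
    # Stand-in for resuming inside the `if` body: negate b, then finish.
    b = [bi * -1 for bi in b]
    return [xi * bi for xi, bi in zip(x, b)]


def resume_at_38_2(b, x):
    # Stand-in for resuming after the `if`: finish without negating b.
    return [xi * bi for xi, bi in zip(x, b)]


def toy_example(a, b):
    # Same shape as the decompiled code: run the graph, then branch in Python.
    result = compiled_fn_0(a, b)
    x = result[0]
    if result[1]:
        return resume_at_30_1(b, x)
    return resume_at_38_2(b, x)


def reference(a, b):
    # The original, unsplit logic, for comparison.
    x = [ai / (abs(ai) + 1) for ai in a]
    if sum(b) < 0:
        b = [bi * -1 for bi in b]
    return [xi * bi for xi, bi in zip(x, b)]
```

Under these assumptions, the split version and `reference` return the same values; the data-dependent `if` is the point where the function is cut into a compiled part and two continuation functions.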

@jansel I built this mainly for the doc we are writing :smile: Would it be OK to add this package to the doc?

Neat! Yeah, feel free to update the docs.

The decompiled output is more readable now that I have eliminated some temp variables:

guard code:
def guard(L):
    if not getattr(___guarded_code, 'valid'):
        return False
    else:
        _var0 = L['a']
        if not hasattr(_var0, '_dynamo_dynamic_indices') == False:
            return False
        else:
            _var1 = L['b']
            if not hasattr(_var1, '_dynamo_dynamic_indices') == False:
                return False
            elif not ___is_grad_enabled():
                return False
            elif ___are_deterministic_algorithms_enabled():
                return False
            elif not ___is_torch_function_enabled():
                return False
            elif not getattr(utils_device, 'CURRENT_DEVICE') == None:
                return False
            elif not ___check_tensors(_var0, _var1, tensor_check_names=
                tensor_check_names):
                return False
            else:
                return True

compiled code:
def toy_example(a, b):
    __temp_1 = __compiled_fn_0(a, b)
    x = __temp_1[0]
    if __temp_1[1]:
        return __resume_at_30_1(b, x)
    else:
        return __resume_at_38_2(b, x)
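
One way to read the guard output above: it is a conjunction of conditions, and it returns True only if every one of them holds. The sketch below shows that equivalence on a reduced, hypothetical set of three checks; the dictionary keys are illustrative names, not Dynamo's.

```python
def guard_nested(checks):
    # Shape of the decompiled output: each failed check returns False early.
    if not checks["valid"]:
        return False
    else:
        if not checks["grad_enabled"]:
            return False
        else:
            if checks["deterministic_algorithms"]:
                return False
            else:
                return True


def guard_flat(checks):
    # Equivalent single expression: every condition must hold.
    return (
        checks["valid"]
        and checks["grad_enabled"]
        and not checks["deterministic_algorithms"]
    )
```

Both functions agree on every combination of inputs, so the nested `if`/`else` ladder can be read top to bottom as "check A and check B and not check C".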

@jansel PR created at [Doc] Update the dynamo deepdive doc by youkaichao · Pull Request #108147 · pytorch/pytorch · GitHub. I also added support for Python 3.7 through 3.11 today. The depyf package should now be good enough for understanding Dynamo bytecode. It handles almost everything except while loops and for loops, which I suppose rarely occur in Dynamo-generated bytecode.
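
As general background on why loops are a separate case for a decompiler, here is a small sketch using the standard-library `dis` module. It only shows which iteration and jump opcodes a `for` loop compiles to; it is not a description of depyf's internals, and opcode names other than `FOR_ITER` vary across Python versions.

```python
import dis


def sum_loop(values):
    # A minimal for loop to inspect.
    total = 0
    for v in values:
        total += v
    return total


# Keep only the iteration and jump opcodes that encode the loop structure.
opnames = [ins.opname for ins in dis.get_instructions(sum_loop)]
loop_ops = [name for name in opnames if name == "FOR_ITER" or "JUMP" in name]
print(loop_ops)
```

The loop shows up as a `FOR_ITER` instruction plus a jump back to the loop head, so recovering `for` or `while` source means recognizing that jump pattern rather than translating instructions one by one.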
