How are guards installed on frames that are transient objects?

Per the video PyTorch 2.0 Live Q&A Series: A Deep Dive on TorchDynamo - YouTube and many other tutorials, I see statements like “guards are installed in the frame”, and the illustration suggests the same:

It confuses me a lot: frames in Python are transient objects that are created and destroyed as functions are called and return. How can we install guards on frames?

For example, the following code explicitly saves the frame; we can see that frames differ across function calls, and data saved on one frame does not survive into the next call.

import inspect

frame = None

def f(x):
    global frame
    if frame is None:
        # save some data in the frame object
        frame = inspect.currentframe()
        frame.f_locals['y'] = 5
    else:
        # this is a new frame, and does not remember the data we saved!
        print(inspect.currentframe().f_locals['y'])
    return x + 1

f(2)

# raises KeyError: 'y' (the new frame never saw the data we saved)
f(3)

cc @ezyang as we are discussing in [Fatal Bug] changed nn.Module.training does not trigger recompilation · Issue #105653 · pytorch/pytorch · GitHub.

Guards are stored using the co_extra field of the PyCodeObject. See PEP 523 – Adding a frame evaluation API to CPython | peps.python.org

So the guards are stored on the code, not the frame. The picture above indicates that the guards check the frame (i.e., guards take the frame as an input); they are not stored on the frame.
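As a rough mental model (an illustrative Python sketch, not TorchDynamo’s actual data structures): the guard cache hangs off the code object, which outlives any single frame, and each check function is evaluated against the locals of whichever frame is currently running.

# Illustrative sketch only; the names and structure here are assumptions,
# not real Dynamo internals.
cache = []  # attached to the *code object*, so it persists across calls

def make_guard(expected_shape):
    # A guard is conceptually a predicate over the frame's local variables.
    def check_fn(f_locals):
        x = f_locals["x"]  # assumes a tensor-like local named "x"
        return tuple(x.shape) == expected_shape
    return check_fn

def lookup(f_locals):
    # Every call creates a fresh frame; its locals are checked
    # against the persistent per-code cache.
    for check_fn, compiled_fn in cache:
        if check_fn(f_locals):
            return compiled_fn
    return None  # cache miss: recompile and append a new entry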

Thanks, that’s nice and clear!
One remaining question: where does PyTorch store the compiled function for the graph? Is it this https://github.com/pytorch/pytorch/blob/main/torch/csrc/dynamo/eval_frame.c#L289 PyCodeObject* code object?

Another question:
the code at https://github.com/pytorch/pytorch/blob/main/torch/csrc/dynamo/eval_frame.c#L319 explicitly converts the third argument to void*, but the C API _PyCode_GetExtra mentioned in PEP 523 – Adding a frame evaluation API to CPython | peps.python.org requires a void** argument. How can this code compile? I had to convert the third argument to void** to make my test example compile, but it still crashes at runtime.

My C++ code:

#include <Python.h>
#include <code.h>
#include <frameobject.h>
#include <stdio.h>

typedef struct cache_entry {
  // check the guards: lambda: <locals of user function>: bool
  PyObject* check_fn;
  // modified user bytecode (protected by check_fn's guards)
  PyCodeObject* code;
  // on a cache miss, linked list of next thing to try
  struct cache_entry* next;
} CacheEntry;

void ignored(void* obj) {}

PyObject* get_cache_list(PyObject* self, PyObject* args) {
    PyObject* my_object;
    if (!PyArg_ParseTuple(args, "O", &my_object)) {
        return NULL;
    }
    PyCodeObject* code = (PyCodeObject*)my_object;

    size_t cache_entry_extra_index = _PyEval_RequestCodeExtraIndex(ignored);
    CacheEntry* current_node = NULL;
    _PyCode_GetExtra((PyObject*)code, cache_entry_extra_index, (void**)&current_node);

    // current_node = (CacheEntry*)code->co_extra;
    if (current_node == NULL) {
        return NULL;
    }

    PyObject* outer_list = PyList_New(0);
    if (!outer_list) {
        return NULL;  // Return NULL if failed to create list
    }
    while (current_node != NULL) {
        // Creating a new Python list for the check_fn and code of current CacheEntry
        PyObject* inner_list = PyList_New(0);
        if (!inner_list) {
            Py_DECREF(outer_list);  // Clean up if failed to create list
            return NULL;
        }

        // Add the check_fn and code to the inner list
        if (PyList_Append(inner_list, current_node->check_fn) < 0) {
            Py_DECREF(outer_list);
            Py_DECREF(inner_list);  // Clean up if failed to append
            return NULL;
        }
        if (PyList_Append(inner_list, (PyObject*)current_node->code) < 0) {
            Py_DECREF(outer_list);
            Py_DECREF(inner_list);  // Clean up if failed to append
            return NULL;
        }

        // Add the inner list to the outer list
        if (PyList_Append(outer_list, inner_list) < 0) {
            Py_DECREF(outer_list);
            Py_DECREF(inner_list);  // Clean up if failed to append
            return NULL;
        }

        // Move to the next node in the linked list
        current_node = current_node->next;
    }
    // Return the outer list
    return outer_list;
}

// Method list
static PyMethodDef DynamoInspectMethods[] = {
    {"get_guards_and_code", get_cache_list, METH_VARARGS,
     "Convert linked list of cache entries into a nested Python list."},
    {NULL, NULL, 0, NULL}  // Sentinel
};

// Module definition
static struct PyModuleDef dynamo_inspect_module = {
    PyModuleDef_HEAD_INIT,
    "dynamo_inspect",   // name of module
    NULL,               // module documentation, may be NULL
    -1,                 // size of per-interpreter state of the module, or -1 if the module keeps state in global variables.
    DynamoInspectMethods
};

// Module initialization function
PyMODINIT_FUNC PyInit_dynamo_inspect(void) {
    return PyModule_Create(&dynamo_inspect_module);
}

My Python test code:

import torch

def f(x):
    return x + 1

opt_f = torch.compile(f)

opt_f(torch.randn(5, 5, 5))

import dynamo_inspect
output = dynamo_inspect.get_guards_and_code(opt_f.__code__)

I’m trying to do a deep dive into TorchDynamo, and I want to extract the cache entry from the compiled function. It is really tough, though.

The PyCodeObject* there contains the modified Python bytecode (inside a code object), which internally calls whatever the backend compiler returned. The result of the backend compiler is stored as a global variable named something like "__compiled_fn0", which you should see if you call globals(). The captured graph is not stored after compilation, though you could easily write a custom backend to store it.
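For instance, a quick way to spot those globals after one compiled call might look like this (the exact name format is an assumption and may vary across PyTorch versions):

# After calling the compiled function once, look for the backend's output
# in the function's module globals (name format may differ by version).
compiled_names = [name for name in f.__globals__ if name.startswith("__compiled_fn")]
print(compiled_names)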

If you are after the captured FX graph, the easiest way to get that would be:

from typing import List

import torch

captured_graphs = []

def my_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    captured_graphs.append(gm)
    # call the default torch.compile backend:
    return torch._inductor.compile(gm, example_inputs)

opt_f = torch.compile(f, backend=my_backend)
opt_f(torch.randn(5, 5, 5))

print(captured_graphs)

That code compiles because it is C, not C++; C has more relaxed pointer conversions (a void* converts implicitly to any object pointer type, including void**, while C++ requires an explicit cast). There are some internal bits of the CPython source code that don’t work from C++.

One problem with your code is that you are calling _PyEval_RequestCodeExtraIndex(), which allocates a brand-new “index”, different from the one Dynamo uses. Each call to _PyEval_RequestCodeExtraIndex() returns a new index. So Dynamo might be using index 0, while your code would be reading index 1, 2, 3, … (you allocate a new index on every call). These indices are dynamically allocated, so you would need to add an API to PyTorch to expose the one it uses.
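To make that failure mode concrete, here is a toy Python model of the allocator (purely illustrative; the real API is the C function discussed above):

# Toy model: the extra-index allocator is just a global counter.
_next_index = 0

def request_code_extra_index():
    global _next_index
    index = _next_index
    _next_index += 1
    return index

# Dynamo would allocate its slot once at startup:
dynamo_index = request_code_extra_index()  # 0

# An extension that allocates a fresh slot on every call
# never reads the slot Dynamo actually wrote to:
print(request_code_extra_index())  # 1
print(request_code_extra_index())  # 2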

If you want to go further down that road, I’d suggest modifying the PyTorch source code to add Python bindings for the CacheEntry object and an API to retrieve it.


Those are very detailed and helpful explanations!

Now I know that the compiled functions are the __compiled_fn_xxx entries in globals(), and I also see the __resume_at_xxx entries for graph breaks.

After switching to a C extension rather than a C++ extension, compilation goes smoothly. And the explanation of _PyEval_RequestCodeExtraIndex is great! I used to wrongly think that code->co_extra stores additional data directly. Now I understand that it can store multiple data entries, accessed via an index.

The indices used by PyTorch (cache_entry_extra_index and dynamic_frame_state_extra_index) are private and cannot be accessed elsewhere. Would you like me to submit a pull request adding a private API, _debug_get_cache_entry, for debugging? I managed to extract the guards and modified bytecode from the compiled function, which I find really useful for understanding and checking the code TorchDynamo captures! Exposing the API in Python makes it much easier to inspect and manipulate things for in-depth development. Of course, I flagged the API with the _debug prefix to indicate it is for debugging only.

One more detailed question:

The signature of _PyEval_RequestCodeExtraIndex suggests it is a global index, while I expected each code object to have its own indices, since each code object has its own co_extra field, as specified by PEP 523 – Adding a frame evaluation API to CPython | peps.python.org.

The Python C API is poorly documented! I googled and asked ChatGPT, but neither gave me any helpful information. Thank you, Jason!

Sure, we would welcome the contribution!

It is indeed global, and I think it changed a bit after PEP 523. It is intended for JIT compilers (like TorchDynamo) that process every PyCodeObject. So when you allocate an index, you get your own slot on every single PyCodeObject. You would typically call _PyEval_RequestCodeExtraIndex once at startup, then use that index for the duration of the program.

I believe co_extra (which is just one pointer) will, under the hood, point to a dynamically allocated data structure with room for every allocated extra index. The set/get extra APIs query (and, if needed, allocate or grow) that data structure.
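A rough Python sketch of that layout (a mental model only, not CPython’s actual implementation):

# Mental model: each code object carries one growable slot array,
# indexed by the globally allocated extra indices.
class FakeCodeObject:
    def __init__(self):
        self._extra = []  # stands in for the single co_extra pointer

    def set_extra(self, index, value):
        # like _PyCode_SetExtra: grow the array on demand
        while len(self._extra) <= index:
            self._extra.append(None)
        self._extra[index] = value

    def get_extra(self, index):
        # like _PyCode_GetExtra: unwritten slots read back as NULL/None
        return self._extra[index] if index < len(self._extra) else None

code = FakeCodeObject()
code.set_extra(0, "dynamo cache entries live here")
print(code.get_extra(0))  # dynamo cache entries live here
print(code.get_extra(3))  # None (slot never written)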

You are a wonderful expert in the low-level details of Python! I have learned a lot from talking with you.
