Debugging story: The case of the flaky Dynamo export tests

Keep asking why.

For months, Dynamo export tests have been intermittently failing with “AssertionError: whole graph export entails exactly one call”. Flaky tests that you can’t reproduce are pretty difficult to debug, since you can’t easily make a change and see if it fixes the problem or not. So when I noticed that my local pytest runs were reliably failing in a similar way (so reliably that I initially thought they were failing deterministically), I figured I should stop what I was doing and get to the bottom of the problem, because it seemed pretty unlikely anyone else would be able to figure it out otherwise.

When trying to debug a tricky problem, it pays to reduce experiment iteration time. In my case, it took a few minutes to run the full test suite: I wanted the flaky test to trigger within the first few runs. I initially wanted to do this by running all the tests in a fixed order, and then delta debugging by removing tests until the flakiness went away. But it turns out pytest doesn’t really have any built-in way to do this; stock unittest does, but in fact the tests stop failing when you run them with unittest instead of pytest (for those guessing at home, this is an important clue!) However, pytest does support boolean operators in the -k flag, e.g., you can say -k 'foo and bar' to select only tests whose names contain both foo and bar. So I used this to successively pare the tests down to something that could be run in a few seconds: PYTHONUNBUFFERED=1 pytest test/dynamo/test_dynamic_shapes.py -k 'Unspec and export and not dupes and not reorder' -v -x -s. Great.

Next, I wanted to better understand what exactly was going on when things failed. A good start is to turn on DEBUG logging and look through the logs. Unfortunately, turning on logging also makes the bug go away. So it’s back to good old-fashioned print debugging. After a bit of triangulating, I finally determine that in the buggy case, we’re arriving at this block of code:

            on_enter()
            prior = set_eval_frame(callback)
            backend_ctx = backend_ctx_ctor()
            backend_ctx.__enter__()
            dynamic_ctx = enable_dynamic(self.dynamic, self.export)
            dynamic_ctx.__enter__()
            try:
                return fn(*args, **kwargs)
            finally:
                set_eval_frame(prior)
                dynamic_ctx.__exit__(None, None, None)
                backend_ctx.__exit__(None, None, None)

We successfully set up the Dynamo frame callback, we call fn, and… the frame callback just doesn’t get called.

OK, so this sounds like something has gone wrong with the callback setup. Can we log that? Yes we can! torch/csrc/dynamo/eval_frame.c has an undocumented TORCHDYNAMO_DEBUG preprocessor macro that, if defined, turns on a bunch of debug logging that looks like this:

TRACE[_custom_eval_frame:650] begin __del__ /data/users/ezyang/a/pytorch/torch/multiprocessing/reductions.py 37 -1 0 0
TRACE[_custom_eval_frame:679] skip __del__
TRACE[_custom_eval_frame:650] begin _free_weak_ref /data/users/ezyang/a/pytorch/torch/storage.py 766 -1 0 0
TRACE[_custom_eval_frame:679] skip _free_weak_ref
TRACE[_custom_eval_frame:650] begin <genexpr> /data/users/ezyang/a/pytorch/torch/_dynamo/eval_frame.py 733 14 0 0
TRACE[set_eval_frame_py:831] python enabled=0 and is run_only=0

(There is in fact a very important clue in this log, though I didn’t figure it out until later.) I also added some more debug logs for when we modify the eval frame callback, to see if we were unsetting the custom frame callback and then forgetting to reset it. And indeed, this is precisely what was happening:

  if (result == NULL) {
    // internal exception, returning here will leak the exception into user code
    // this is useful for debugging -- but we dont want it to happen outside of
    // testing
    return NULL;

Compare this with the success case right below, which re-enables the custom callback:

  } else if (result != Py_None) {
    DEBUG_TRACE("create cache %s", name(frame));
    extra = create_cache_entry(extra, result);
    Py_DECREF(result);
    set_extra(frame->f_code, extra);
    // Re-enable custom behavior
    eval_frame_callback_set(callback);
    return eval_custom_code(tstate, frame, extra->code, throw_flag);

Modifying the internal exception path to reset the Dynamo callback indeed “fixed” the test failures. But there were some unresolved mysteries. Why was the callback failing with an internal error? Why did this fail nondeterministically? Was it possible that we were intentionally keeping the callback disabled in case of error?

Keep asking why.

I decided I wanted to understand why the test was failing nondeterministically. What was different about this test versus another? Here was an example of a test that was succeeding:

TRACE[_custom_eval_frame:650] begin __init__ /home/ezyang/local/a/pytorch-env/lib/python3.9/contextlib.py 86 -1 0 0
TRACE[_custom_eval_frame:679] skip __init__
TRACE[_custom_eval_frame:650] begin __enter__ /home/ezyang/local/a/pytorch-env/lib/python3.9/contextlib.py 114 -1 0 0
TRACE[_custom_eval_frame:679] skip __enter__
TRACE[_custom_eval_frame:650] begin enable_dynamic /data/users/ezyang/a/pytorch/torch/_dynamo/eval_frame.py 162 -1 0 0
TRACE[_custom_eval_frame:679] skip enable_dynamic
TRACE[_custom_eval_frame:650] begin func /data/users/ezyang/a/pytorch/test/dynamo/test_export.py 83 -1 0 0
TRACE[_custom_eval_frame:752] create cache func

Each begin log line tells us when Dynamo’s custom frame handler is invoked. We see some context manager stuff, a call to enable_dynamic, and then a call to the actual function to be traced. Everything before the function is skipped, and the function itself creates a cache entry. This all makes sense, and we can line up each log entry with the actual Python code that executes after the eval frame handler is set up.

How about the failing case?

TRACE[_custom_eval_frame:650] begin enable_dynamic /data/users/ezyang/a/pytorch/torch/_dynamo/eval_frame.py 162 -1 0 0
TRACE[_custom_eval_frame:679] skip enable_dynamic
TRACE[_custom_eval_frame:650] begin __call__ /home/ezyang/local/a/pytorch-env/lib/python3.9/weakref.py 586 -1 0 0
TRACE[_custom_eval_frame:760] create skip __call__
TRACE[_custom_eval_frame:650] begin del_ten /data/users/ezyang/a/pytorch/torch/_subclasses/meta_utils.py 118 -1 0 0
TRACE[_custom_eval_frame:760] create skip del_ten
TRACE[_custom_eval_frame:650] begin __del__ /data/users/ezyang/a/pytorch/torch/multiprocessing/reductions.py 37 -1 0 0
TRACE[_custom_eval_frame:760] create skip __del__
TRACE[_custom_eval_frame:650] begin _free_weak_ref /data/users/ezyang/a/pytorch/torch/storage.py 766 -1 0 0
TRACE[_custom_eval_frame:760] create skip _free_weak_ref
TRACE[_custom_eval_frame:650] begin <lambda> /data/users/ezyang/a/pytorch/torch/_dynamo/utils.py 329 -1 0 0
TRACE[_custom_eval_frame:760] create skip <lambda>
TRACE[_custom_eval_frame:650] begin _remove_id /data/users/ezyang/a/pytorch/torch/_dynamo/utils.py 332 -1 0 0
TRACE[_custom_eval_frame:760] create skip _remove_id
TRACE[_custom_eval_frame:650] begin <lambda> /data/users/ezyang/a/pytorch/torch/_dynamo/utils.py 329 -1 0 0
TRACE[_custom_eval_frame:679] skip <lambda>
TRACE[_custom_eval_frame:650] begin _remove_id /data/users/ezyang/a/pytorch/torch/_dynamo/utils.py 332 -1 0 0
TRACE[_custom_eval_frame:679] skip _remove_id
TRACE[_custom_eval_frame:650] begin __del__ /data/users/ezyang/a/pytorch/torch/multiprocessing/reductions.py 37 -1 0 0
TRACE[_custom_eval_frame:679] skip __del__
TRACE[_custom_eval_frame:650] begin _free_weak_ref /data/users/ezyang/a/pytorch/torch/storage.py 766 -1 0 0
TRACE[_custom_eval_frame:679] skip _free_weak_ref
TRACE[_custom_eval_frame:650] begin <genexpr> /data/users/ezyang/a/pytorch/torch/_dynamo/eval_frame.py 733 14 0 0
TRACE[set_eval_frame_py:831] python enabled=0 and is run_only=0

We see some familiar faces: this trace still calls enable_dynamic. But then, there are a bunch of random frames that don’t seem to correspond at all to the Python code that is executing. Where did they come from? There are some clues: there are a lot of calls to __del__, which is responsible for object finalization. And this <genexpr>, where the heck did that come from?

It’s the GC! I quickly check what happens if I disable GC with pytest-gc, and indeed, this also solves the problem. Non-determinism mystery solved: the failure happens nondeterministically because when the GC triggers is nondeterministic; you have to get unlucky enough for a GC to trigger right after Dynamo’s frame handler is installed, but before we actually get to the frame in question. We can also confirm this by installing the frame handler only after entering the context managers. This also explains why the error is easier to reproduce in pytest: pytest has some reference cycles in its implementation, which causes the GC to run a lot more frequently.
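
As an aside, a quick way to test the GC hypothesis in isolation is to pause the cyclic collector around the run, using nothing but the standard gc module. This is a minimal sketch of the idea (my own illustration, not the exact check used in the investigation):

import gc

# Turn off automatic cycle collection so that no __del__ methods or weakref
# callbacks can fire between installing Dynamo's frame handler and evaluating
# the frame we care about.  Reference counting still frees acyclic objects;
# only the cyclic collector is paused.
gc.disable()
try:
    pass  # ... run the flaky test here ...
finally:
    gc.enable()
    gc.collect()  # clean up whatever accumulated while collection was paused

If the failure disappears with the collector paused, that strongly implicates a GC-triggered frame, which is exactly what the traces above show.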

So, we have multiple ways we could fix the problem. It also seems like maybe it’s a good idea not to run Dynamo during GC at all (see the Stack Overflow question “How to detect if frame evaluation is occurring during GC in Python”), though I did some CPython reading and it seems that detecting this is not possible. But does this mean we’re done debugging?

Keep asking why.

Why is it a problem for GC to run while the Dynamo frame handler is installed? Although it’s weird, it seems like the frame handler should still get called as normal. But instead, the callback is failing with an internal error. But how can it be failing with an internal error if we never actually call the callback?

We can force CPython to print out what a raised exception is, and then restore the exception so that it keeps propagating as normal, using the following trick:

PyObject *ptype, *pvalue, *ptraceback;
// Fetch clears the error indicator and gives us owned references (possibly NULL).
PyErr_Fetch(&ptype, &pvalue, &ptraceback);
// Take an extra reference to each: PyErr_Restore steals references and
// PyErr_Print consumes the error it prints, so we need a second set.
Py_XINCREF(ptype);
Py_XINCREF(pvalue);
Py_XINCREF(ptraceback);
PyErr_Restore(ptype, pvalue, ptraceback);  // put the error back for printing
PyErr_Print();
PyErr_Restore(ptype, pvalue, ptraceback);  // restore again so it keeps propagating

(It’s important to use Py_XINCREF, since ptype/pvalue/ptraceback could be NULL).

The printed backtrace shows that we’re failing on the very first line of the callback handler. I can’t even print: before the print runs, I just immediately fail out. Could it be… that there is already a Python error active when I enter the eval frame? A quick check of PyErr_Occurred() before the callback confirms this!

Why is there a Python error active? I’m not really sure, but I’m reminded of a little bit of CPython trivia: when a generator is closed before it is exhausted (which is what happens when the GC finalizes a pending generator), CPython raises a GeneratorExit exception inside its frame. So maybe we’re in the process of unwinding a generator when this happens?
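
To make the trivia concrete, here is a minimal, self-contained sketch (my own illustration, not from the original debugging session) of GeneratorExit being thrown into a suspended generator frame:

def gen():
    try:
        yield 1
        yield 2
    except GeneratorExit:
        # Runs when the generator is closed (or finalized by the GC) while
        # still suspended at a yield.
        print("GeneratorExit thrown into the generator frame")
        raise  # generators must not swallow GeneratorExit

g = gen()
next(g)    # advance to the first yield; the frame is now suspended
g.close()  # throws GeneratorExit at the suspension point and unwinds the frame

A suspended generator that gets garbage collected is closed the same way, which lines up with the <genexpr> frames showing up in the failing traces.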

I submit my first patch, and go to bed.

diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c
index 4a5ca5b6abdb7..1e588f5dcdfd6 100644
--- a/torch/csrc/dynamo/eval_frame.c
+++ b/torch/csrc/dynamo/eval_frame.c
@@ -646,6 +646,17 @@ static PyObject* _custom_eval_frame(
       frame->f_lasti,
       frame->f_iblock,
       frame->f_executing);
+
+  // In obscure situations, we can enter the eval frame with an exception
+  // already set (the most common situation this when we hit a generator
+  // expression which is exiting with GeneratorExit).  In this case, there
+  // isn't really any chance that Dynamo will be able to successfully handle
+  // it.  Immediately propagate it out.
+  if (PyErr_Occurred() != NULL) {
+    DEBUG_TRACE("propagate error %s", name(frame));
+    return NULL;
+  }
+
   CacheEntry* extra = get_extra(frame->f_code);
   if (extra == SKIP_CODE || (callback == Py_False && extra == NULL)) {
     DEBUG_TRACE("skip %s", name(frame));

In the morning, the entire regular test suite is passing… but some tests running under the ‘dynamo’ configuration are failing. Sometimes, fixing bugs in Dynamo can cause previously passing tests to start failing because of latent bugs. I could have skipped those tests and moved on with my life.

Keep asking why.

I resolve that I should at least convince myself that the new failures are not my fault. Fortunately, the failures are deterministic. Oddly, the test in question is xfail’ed, but now we are somehow skipping xfail handling and still erroring with the exception. I use pdb to step through the stack unwinding and confirm that, yes, somehow we’re just not hitting the finally handler. Fortunately, it turns out just throwing an exception and having a non-trivial handler is sufficient to repro the problem.

import torch
import contextlib
import torch._dynamo

import logging
torch._dynamo.config.log_level = logging.DEBUG
torch._dynamo.config.output_code = True

@contextlib.contextmanager
def ctx():
    try:
        yield
    except RuntimeError:
        print("out")

@torch._dynamo.optimize("eager")
def f():
    with ctx():
        h()

def h():
    raise RuntimeError("boof")

f()

I carefully inspect Dynamo’s rewritten bytecode and confirm that, yes, the bytecode is still properly setting up the context manager. It seems like there’s something wrong with my patch. I can think of a few possible alternate ways to write it that might work…
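
For reference, here is a minimal, self-contained sketch (my own illustration, not from the investigation) of inspecting the ordinary CPython bytecode for a with-block using the standard dis module:

import contextlib
import dis

@contextlib.contextmanager
def ctx():
    try:
        yield
    except RuntimeError:
        print("out")

def f():
    with ctx():
        raise RuntimeError("boof")

# On Python 3.9 (the version in the traces above), the disassembly includes
# SETUP_WITH and WITH_EXCEPT_START, the opcodes that ensure ctx().__exit__
# runs while an exception unwinds the frame.
dis.dis(f)

Seeing the equivalent with-block machinery intact in Dynamo’s rewritten bytecode is what rules out the rewriting itself and points the finger back at the patch.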

Keep asking why.

Alban, who at this point I’ve roped into the sordid saga, wants to understand why there is an error active at frame entry. Alban reads through the source code for _PyEval_EvalFrameDefault and notices two things:

  1. You’re NOT supposed to call the eval frame function with an error set; there’s a debug assert guarding against this! (cpython/ceval.c at 53dceb53ade15587b9cfd30c0a0942232517dee9 · python/cpython · GitHub)
  2. BUT generator exit is special-cased to go straight to exit unwinding when throwflag is true. (cpython/ceval.c at 53dceb53ade15587b9cfd30c0a0942232517dee9 · python/cpython · GitHub)

This is consistent with everything we’ve seen so far: the naughty frames are invariably <genexpr> frames, and I do a quick test and confirm that, indeed, throwflag is true when we get here. It also seems clear that returning NULL is different from unwinding, since unwinding causes a bunch of extra finalization code to run. This leads us to our final patch:

diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c
index 4a5ca5b6abdb7..39a8226704919 100644
--- a/torch/csrc/dynamo/eval_frame.c
+++ b/torch/csrc/dynamo/eval_frame.c
@@ -646,6 +646,32 @@ static PyObject* _custom_eval_frame(
       frame->f_lasti,
       frame->f_iblock,
       frame->f_executing);
+
+  if (throw_flag) {
+    // When unwinding generators, eval frame is called with throw_flag ==
+    // true.  Frame evaluation is supposed to continue unwinding by propagating
+    // the exception.  Dynamo doesn't really know how to do this, nor does it
+    // really want to do this, because there's unlikely any code to capture
+    // (you're going to immediately quit out of the frame, perhaps running
+    // some unwinding logic along the way).  So we just run the default
+    // handler in this case.
+    //
+    // NB: A previous version of this patch returned NULL.  This is wrong,
+    // because returning NULL is *different* from unwinding an exception.
+    // In particular, you will not execute things like context manager
+    // __exit__ if you just return NULL.
+    //
+    // NB: It's /conceivable/ that you might want to actually still call the
+    // Dynamo callback when throw_flag == TRUE, to give Dynamo a chance to
+    // do any stack unwinding code.  But this is not really useful because
+    // (1) Dynamo doesn't actually know how to do stack unwinding, so it would
+    // immediately skip the frame, and (2) even if it did, this would only
+    // be profitable if there was tensor code in the unwinding code.  Seems
+    // unlikely.
+    DEBUG_TRACE("throw %s", name(frame));
+    return eval_frame_default(tstate, frame, throw_flag);
+  }
+
   CacheEntry* extra = get_extra(frame->f_code);
   if (extra == SKIP_CODE || (callback == Py_False && extra == NULL)) {
     DEBUG_TRACE("skip %s", name(frame));

This is the end of our debugging story. No stone has been left unturned: we have successfully explained every observed behavior in the original problem. By continually asking why, we have avoided landing an incorrect fix which would have suppressed the initial problem, but opened us up to an even more subtle and more rarely triggered bug. The moral of the story?

Keep asking why.

Final PR: Fix flaky Dynamo export tests by ezyang · Pull Request #96488 · pytorch/pytorch · GitHub

In the spirit of asking why – why aren’t you (or our CI?) running a debug version of Python?

I’m not running debug Python because it’s a pain to set up (I have to spin up an entirely new Python environment and build Python from source). As for CI, we probably should have a version of CI that runs with debug Python.

The challenge is the total execution time of the PyTorch test suite, which takes 3-5x longer to finish with a debug build. But it could be done on demand/periodically. Or one can select a subset of tests to run.

I don’t think that is true? Greg mentioned a debug version of CPython, not PyTorch. I am running a debug version of CPython locally and it is not noticeably slower from what I can see.