A minimal working example of standalone usage for dynamo eval_frame

Dynamo uses PEP 523 to hook into the CPython interpreter, but it does not allow arbitrary hook functions. Instead, it turns the frame evaluation API into a just-in-time compiler API. Here is a minimal working example of how to use dynamo standalone:

import inspect
from typing import Dict, Optional

from torch._dynamo.bytecode_transformation import is_generator
from torch._dynamo.eval_frame import set_eval_frame

from torch._dynamo.types import (
    CacheEntry,
    DynamoFrameType,
    FrameState,
    GuardedCode,
    GuardFn,
)


def full_function_name(frame):
    # Get the function name
    function_name = frame.f_code.co_name

    # Check if we are inside a class
    try:
        class_name = frame.f_locals["self"].__class__.__name__
        function_name = class_name + "." + function_name
    except KeyError:
        pass

    # Get the module name
    module = inspect.getmodule(frame)
    module_name = module.__name__ if module is not None else "empty module"

    return module_name + "." + function_name


class Always(GuardFn):
    # A guard that always reports a hit (or always a miss, if hit=False).
    def __init__(self, hit=True):
        self.hit = hit

    def __call__(self, f_locals: Dict) -> bool:
        print(f"guard called with {len(f_locals)} args.")
        return self.hit


class MyCallback:
    def __init__(self, skip: bool, hit: bool = False):
        self.count = 0
        self.skip = skip
        self.hit = hit

    def __call__(
        self,
        frame: DynamoFrameType,
        cache: Optional[CacheEntry],
        frame_state: FrameState,
    ):
        self.count += 1
        # Walk the linked list of cache entries attached to this code object.
        cache_len = 0
        while cache is not None:
            cache_len += 1
            cache = cache.next
        print(f"calling the {self.count}-th function: {full_function_name(frame)}")
        print(f"the function has {cache_len} cache entries now.")

        # frame_state is a dict that persists between calls of this code object.
        if "cache_len" in frame_state:
            cache_len = frame_state["cache_len"]
            print(f"number of caches in the last call: {cache_len}")
        frame_state["cache_len"] = cache_len

        # Skip generator frames, and skip everything when self.skip is set.
        if self.skip or is_generator(frame.f_code):
            return None
        # Reuse the original bytecode, guarded by an always-hit / always-miss guard.
        return GuardedCode(code=frame.f_code, check_fn=Always(hit=self.hit))


def f(x):
    return x + 1


def g(x):
    return x + 2


def eval_frame_mwe(skip, hit=False):
    callback = MyCallback(skip=skip, hit=hit)
    prior = set_eval_frame(callback)

    x = 1
    x = f(x)
    x = g(x)
    x = f(x)
    x = g(x)
    x = f(x)

    set_eval_frame(prior)


def test_always_skip_callback():
    eval_frame_mwe(skip=True)


def test_always_hit_callback():
    eval_frame_mwe(skip=False, hit=True)


def test_always_miss_callback():
    eval_frame_mwe(skip=False, hit=False)

The core API is to provide a callback function to set_eval_frame. Roughly, dynamo then behaves as follows (see the Python sketch after the list):

  • If callback is None, or the code is marked as SKIP, dynamo does nothing: the frame is evaluated as usual.
  • Otherwise:
    • Dynamo pulls the cache entries out of the frame’s code object, checks them one by one, and runs the first one whose guard function returns True (a cache hit).
    • If no cache entry is hit, this is a cache miss. What happens next depends on the value of callback:
      • If callback is False, dynamo runs the original bytecode.
      • If callback is a function, dynamo calls it to create a new cache entry, and runs that cache entry.
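
The same dispatch can be modelled in plain Python. The sketch below only illustrates the logic described above; it is not the real implementation (which lives in C), and the names Entry, dispatch and run_original are made up:

from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class Entry:
    check_fn: Callable[[dict], bool]   # the guard
    code: Callable[[], object]         # stands in for the compiled bytecode
    next: Optional["Entry"] = None     # cache entries form a linked list


def dispatch(f_locals: dict, cache_head: Optional[Entry], callback, run_original):
    # Toy model of how dynamo decides what to run for one frame.
    if callback is None:
        return run_original()          # plain frame evaluation, no dynamo
    entry = cache_head
    while entry is not None:           # look for a cache hit
        if entry.check_fn(f_locals):
            return entry.code()        # guard passed: run the cached code
        entry = entry.next
    # Cache miss.
    if callback is False:
        return run_original()          # no compilation, run the original bytecode
    new_entry = callback()             # in real dynamo, the callback receives (frame, cache_entry, frame_state)
    if new_entry is None:
        return run_original()          # (and the real code object gets marked as skip)
    return new_entry.code()            # run the newly created cache entry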

The callback function currently accepts a frame, the head of the cache entry list (which may be None if no cache entry exists), and a dict named frame_state for storing additional data. It has two choices (see the example after the list):

  • return None to indicate that the code can be skipped the next time it is called.
  • return a GuardedCode instance, with a new code object and a check function (guard) that decides whether the cached code can be run.
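
For example, a guard could inspect f_locals and only report a hit for inputs of a certain type. The sketch below is purely illustrative (the class and function names are made up; GuardFn and GuardedCode are the same types imported in the example above):

from typing import Dict

from torch._dynamo.types import GuardedCode, GuardFn


class AllIntLocals(GuardFn):
    # Hit only when every local variable is an int (illustrative, not real dynamo guard logic).
    def __call__(self, f_locals: Dict) -> bool:
        return all(isinstance(v, int) for v in f_locals.values())


def int_specializing_callback(frame, cache_entry, frame_state):
    # Reuse the original bytecode, but only for all-int locals;
    # any other input is a cache miss and triggers this callback again.
    return GuardedCode(code=frame.f_code, check_fn=AllIntLocals())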

The minimal working example shows a hack that uses this JIT hook to count how often functions (frames) are intercepted:

  • test_always_skip_callback will capture each frame only at its first occurrence. Example output:
calling the 1-th function: __main__.f
the function has 0 cache entries now.
calling the 2-th function: __main__.g
the function has 0 cache entries now.
  • test_always_hit_callback intercepts the frames at every occurrence. Starting from the second occurrence, the guard function runs first, and since it always returns True, the callback is not called and dynamo just runs the cached bytecode. Example output:
calling the 1-th function: __main__.f
the function has 0 cache entries now.
calling the 2-th function: __main__.g
the function has 0 cache entries now.
guard called with 1 args.
guard called with 1 args.
guard called with 1 args.
  • test_always_miss_callback intercepts the frames at every occurrence. For each occurrence, it first runs all the guards, and since they all return False, it creates a new cache entry every time. This is not useful (and very inefficient), but it is good for understanding dynamo. Example output:
calling the 1-th function: __main__.f
the function has 0 cache entries now.
calling the 2-th function: __main__.g
the function has 0 cache entries now.
guard called with 1 args.
calling the 3-th function: __main__.f
the function has 1 cache entries now.
number of caches in the last call: 0
guard called with 1 args.
calling the 4-th function: __main__.g
the function has 1 cache entries now.
number of caches in the last call: 0
guard called with 1 args.
guard called with 1 args.
calling the 5-th function: __main__.f
the function has 2 cache entries now.
number of caches in the last call: 0

Now we can play around with a real function in eval_frame_mwe, to see which functions are called:

def eval_frame_mwe(skip, hit=False):
    callback = MyCallback(skip=skip, hit=hit)
    from torchvision.models.resnet import resnet50
    prior = set_eval_frame(callback)

    net = resnet50()

    set_eval_frame(prior)

test_always_skip_callback()

Warning: the output might be very long!
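
If the interleaved prints are too noisy, a small variant (a sketch, reusing full_function_name from the example above; list_called_functions is a made-up name) collects the distinct function names that get intercepted and prints them once at the end. Since the callback returns None, each code object is only reported at its first occurrence:

from torch._dynamo.eval_frame import set_eval_frame


def list_called_functions():
    seen = []

    def callback(frame, cache_entry, frame_state):
        name = full_function_name(frame)
        if name not in seen:
            seen.append(name)
        return None  # skip: each code object is intercepted only once

    from torchvision.models.resnet import resnet50

    prior = set_eval_frame(callback)
    net = resnet50()
    set_eval_frame(prior)

    for name in seen:
        print(name)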

Note that the signatures of the callback and the guards might change in the future. It’s better to check the C code in torch/csrc/dynamo/eval_frame.c should you encounter any problems.

Thanks… I am also debugging through the whole compiler stack, this helps…

@mayurnewase Glad it helps 🙂