Dynamo uses PEP 523 to hook into the CPython interpreter, but it does not allow arbitrary hook functions. Instead, it turns the frame evaluation API into a just-in-time compiler API. Here is a minimal working example of using dynamo standalone:
```python
import inspect
from typing import Dict, Optional

from torch._dynamo.bytecode_transformation import is_generator
from torch._dynamo.eval_frame import set_eval_frame
from torch._dynamo.types import (
    CacheEntry,
    DynamoFrameType,
    FrameState,
    GuardedCode,
    GuardFn,
)


def full_function_name(frame):
    # Get the function name
    function_name = frame.f_code.co_name

    # Check if we are inside a class
    try:
        class_name = frame.f_locals["self"].__class__.__name__
        function_name = class_name + "." + function_name
    except KeyError:
        pass

    # Get the module name
    module = inspect.getmodule(frame)
    module_name = module.__name__ if module is not None else "empty module"

    return module_name + "." + function_name


class Always(GuardFn):
    def __init__(self, hit=True):
        self.hit = hit

    def __call__(self, f_locals: Dict) -> bool:
        print(f"guard called with {len(f_locals)} args.")
        return self.hit


class MyCallback:
    def __init__(self, skip: bool, hit: bool = False):
        self.count = 0
        self.skip = skip
        self.hit = hit

    def __call__(
        self,
        frame: DynamoFrameType,
        cache: Optional[CacheEntry],
        frame_state: FrameState,
    ):
        self.count += 1
        # Walk the linked list of existing cache entries
        cache_len = 0
        while cache is not None:
            cache_len += 1
            cache = cache.next
        print(f"calling the {self.count}-th function: {full_function_name(frame)}")
        print(f"the function has {cache_len} cache entries now.")
        # frame_state carries data between calls of the same function. Because
        # cache_len is rebound to the stored value before being written back,
        # the stored value never changes after the first call.
        if "cache_len" in frame_state:
            cache_len = frame_state["cache_len"]
            print(f"number of caches in the last call: {cache_len}")
        frame_state["cache_len"] = cache_len
        # Returning None marks the code as skipped; generators are always skipped
        if self.skip or is_generator(frame.f_code):
            return None
        # Reuse the original bytecode, protected by a guard that always hits or always misses
        return GuardedCode(code=frame.f_code, check_fn=Always(hit=self.hit))


def f(x):
    return x + 1


def g(x):
    return x + 2


def eval_frame_mwe(skip, hit=False):
    callback = MyCallback(skip=skip, hit=hit)
    prior = set_eval_frame(callback)
    x = 1
    x = f(x)
    x = g(x)
    x = f(x)
    x = g(x)
    x = f(x)
    set_eval_frame(prior)


def test_always_skip_callback():
    eval_frame_mwe(skip=True)


def test_always_hit_callback():
    eval_frame_mwe(skip=False, hit=True)


def test_always_miss_callback():
    eval_frame_mwe(skip=False, hit=False)
```
The core API is to provide a `callback` function to `set_eval_frame`:

- If `callback` is `None`, or the code is marked as `SKIP`, dynamo does nothing: the frame is evaluated as usual.
- Otherwise:
  - Dynamo pulls the code caches out of the frame's code object, checks the existing cache entries one by one, and runs the first one whose guard function returns `True` (a cache hit).
  - If no cache entry is hit, this is a cache miss, and what happens next depends on the value of `callback`:
    - If `callback` is `False`, dynamo runs the original bytecode.
    - If `callback` is a function, dynamo calls it to create a new cache entry, and runs that cache entry.
The `callback` function currently accepts a frame, the head of the cache entry list (which may be `None` if no cache entry exists), and a dict named `frame_state` for storing additional data. It has two choices:

- return `None` to indicate that the code can be skipped the next time;
- return a `GuardedCode` instance containing a new bytecode object and a check function (guard) that decides whether that bytecode can be run.
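To make these rules concrete, here is a minimal pure-Python model of the decision dynamo makes for each frame. It is a sketch under stated assumptions, not dynamo's implementation (the real logic lives in C): `SimCode`, `SimCacheEntry`, `SimGuardedCode` and `sim_eval_frame` are hypothetical names, the callback receives a `SimCode` instead of a real frame, and prepending new entries to the cache list is an assumption about ordering.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional

Locals = Dict[str, Any]


@dataclass
class SimGuardedCode:
    # Hypothetical stand-in for GuardedCode: code to run plus its guard.
    code: Callable[[Locals], Any]
    check_fn: Callable[[Locals], bool]


@dataclass
class SimCacheEntry:
    # Hypothetical stand-in for CacheEntry: a singly linked list node.
    guarded: SimGuardedCode
    next: Optional["SimCacheEntry"] = None


@dataclass
class SimCode:
    # Hypothetical per-code-object state: the original function, the head
    # of its cache list, a SKIP flag, and the frame_state dict.
    fn: Callable[[Locals], Any]
    cache: Optional[SimCacheEntry] = None
    skip: bool = False
    frame_state: Dict[str, Any] = field(default_factory=dict)


def sim_eval_frame(code: SimCode, f_locals: Locals, callback: Any) -> Any:
    # No callback, or code marked SKIP: plain evaluation.
    if callback is None or code.skip:
        return code.fn(f_locals)
    # Cache lookup: run the first entry whose guard returns True.
    entry = code.cache
    while entry is not None:
        if entry.guarded.check_fn(f_locals):
            return entry.guarded.code(f_locals)
        entry = entry.next
    # Cache miss with callback False: run the original code.
    if callback is False:
        return code.fn(f_locals)
    # Cache miss with a real callback: ask it for a new entry.
    result = callback(code, code.cache, code.frame_state)
    if result is None:
        code.skip = True  # never intercepted again
        return code.fn(f_locals)
    # Prepend the new entry (ordering is an assumption) and run it.
    code.cache = SimCacheEntry(result, code.cache)
    return result.code(f_locals)
```

For example, with a callback that returns a `SimGuardedCode` guarded by `x > 0`, the first call with a positive `x` goes through the callback, later positive calls hit the cache without calling it, and a negative `x` with `callback=False` falls back to the original function.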
The minimal working example uses the JIT hook as a hack to count how many times each function's frame is evaluated.

`test_always_skip_callback` captures each function only at its first occurrence: the callback returns `None`, so the code is skipped afterwards. Example output:

```
calling the 1-th function: __main__.f
the function has 0 cache entries now.
calling the 2-th function: __main__.g
the function has 0 cache entries now.
```
`test_always_hit_callback` intercepts every occurrence. Starting from the second occurrence of each function, the guard function runs first, and since it always returns `True`, `callback` is not called and dynamo just runs the cached bytecode. Example output:

```
calling the 1-th function: __main__.f
the function has 0 cache entries now.
calling the 2-th function: __main__.g
the function has 0 cache entries now.
guard called with 1 args.
guard called with 1 args.
guard called with 1 args.
```
`test_always_miss_callback` also intercepts every occurrence. Each time, it first runs all the existing guards, and since they all return `False`, it creates a new cache entry. This is not useful (and very inefficient), but it helps in understanding dynamo. Example output:

```
calling the 1-th function: __main__.f
the function has 0 cache entries now.
calling the 2-th function: __main__.g
the function has 0 cache entries now.
guard called with 1 args.
calling the 3-th function: __main__.f
the function has 1 cache entries now.
number of caches in the last call: 0
guard called with 1 args.
calling the 4-th function: __main__.g
the function has 1 cache entries now.
number of caches in the last call: 0
guard called with 1 args.
guard called with 1 args.
calling the 5-th function: __main__.f
the function has 2 cache entries now.
number of caches in the last call: 0
```
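The three traces can be cross-checked with a small pure-Python tally of how often the callback and the guards run in each mode. `count_calls` is a hypothetical helper that models only this bookkeeping (it assumes each callback that does not return `None` adds exactly one cache entry, and that in hit mode the first guard tried returns `True`); it does not use dynamo.

```python
from collections import Counter
from typing import List, Tuple


def count_calls(sequence: List[str], mode: str) -> Tuple[int, int]:
    """Return (callback calls, guard calls) for a sequence of function calls.

    mode is "skip", "hit" or "miss", mirroring the three tests.
    Hypothetical helper: it models only the counting, not dynamo itself.
    """
    cache_len = Counter()  # number of cache entries per function
    skipped = set()        # functions whose code was marked SKIP
    callbacks = guards = 0
    for name in sequence:
        if name in skipped:
            continue  # plain evaluation: no guards, no callback
        if mode == "hit" and cache_len[name] > 0:
            guards += 1  # the first guard returns True: cache hit
            continue
        guards += cache_len[name]  # every existing guard runs and misses
        callbacks += 1             # cache miss: the callback is called
        if mode == "skip":
            skipped.add(name)      # callback returned None
        else:
            cache_len[name] += 1   # callback returned a new GuardedCode
    return callbacks, guards


calls = ["f", "g", "f", "g", "f"]  # the call sequence in eval_frame_mwe
for mode in ("skip", "hit", "miss"):
    print(mode, count_calls(calls, mode))
# skip (2, 0)
# hit (2, 3)
# miss (5, 4)
```

These counts match the traces above. In always-miss mode the k-th call of a function runs k - 1 guards, so n calls of one function run n(n - 1)/2 guards in total (4950 for n = 100), which is why this mode is so inefficient.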
Now we can play around with a real function in `eval_frame_mwe` to see which functions are called:

```python
def eval_frame_mwe(skip, hit=False):
    callback = MyCallback(skip=skip, hit=hit)
    from torchvision.models.resnet import resnet50
    prior = set_eval_frame(callback)
    net = resnet50()
    set_eval_frame(prior)


test_always_skip_callback()
```

Warning: the output might be very long!
Note that the signatures of `callback` and guards might change in the future. If you encounter any problems, check the C code in `torch/csrc/dynamo/eval_frame.c`.