Understanding the lifecycle of compiled code cache entries

I hit a corner case where the compiled code cache entry is gone after the function returns.

So I want to ask a general question: when can we expect a compiled code cache entry to be removed, and how is that removal implemented? Python's garbage collector is notorious for giving no guarantee about the order of object destruction.

Because compilation is expensive, I want to keep as many compiled cache entries as possible. Understanding this would help me avoid accidentally releasing compiled cache entries.

For example, given the code in the issue:

import torch
from torch import nn

class Model(nn.Module):
    """Toy module: a single 128 -> 128 linear layer followed by ``+ 1``."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(128, 128)

    def forward(self, x):
        """Apply the linear layer to ``x`` and add 1 elementwise."""
        return self.linear(x) + 1

# Code object of the uncompiled Model.forward; the hook below compares
# against it by identity.
code = Model.forward.__code__

def bytecode_hook(old, new):
    # Presumably invoked by Dynamo with the original code object (`old`) and
    # the bytecode it generated (`new`), judging by the parameter names.
    # It only asserts that the code being transformed is Model.forward.
    # NOTE(review): this would fail if Dynamo also transformed other frames
    # (e.g. resume functions after a graph break), and `assert` is stripped
    # entirely under `python -O`.
    assert old is code

# Register the hook with Dynamo's frame converter.
torch._dynamo.convert_frame.register_bytecode_hook(bytecode_hook)

def main():
    """Compile a fresh ``Model`` and run one forward pass on random data.

    Everything created here (the module, its compiled wrapper, the input and
    the output) is a local, so nothing refers to the compiled module once this
    function returns. The ``__main__`` section contrasts this with keeping the
    module alive at module scope.
    """
    # Model() is constructed before torch.randn, as before, so the global RNG
    # is consumed in the same order (weight init first, then the input).
    model = torch.compile(model=Model())
    data = torch.randn(64, 128)
    loss = model(data)  # result intentionally unused

if __name__ == '__main__':
    # Private Dynamo debug helper, used here to print the cache entries
    # attached to a code object.
    from torch._dynamo.eval_frame import _debug_get_cache_entry_list

    # cannot find cache entry: the compiled module existed only as a local of
    # main(). The compiled code may specialize on the module's id (guards like
    # ___check_obj_id(L['self'], ...)), and a weakref callback removes the
    # cache entry once that module is deleted.
    main()
    print(_debug_get_cache_entry_list(Model.forward.__code__))

    # can find cache entry: here `model` is a module-level name, so the
    # compiled module is still alive when the cache is inspected below.
    model = Model()
    model = torch.compile(model=model)
    data = torch.randn(64, 128)
    loss = model(data)

    print(_debug_get_cache_entry_list(Model.forward.__code__))

I think it should be fine to keep the code cache entry, even after the function returns.

cc @jansel @anijain2305

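The compiled code may have specialized on `id(model)`, so when `model` is deleted the cache entry becomes invalid. Specifically, a weakref callback removes the cache entry when `model` is deleted. In your example, `model` is a local of `main()`, so it becomes unreachable as soon as `main()` returns. In the second half it is bound at module scope, so it is still alive when you inspect the cache.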

Makes sense, so it is `nn.Module` specialization again :frowning:

Here are the guards I captured (from a larger model than the example above); they are extremely verbose. Every `training` attribute of every `nn.Module` is checked, even when it does not matter for the computation, e.g. `nn.SiLU().training`.


# Note: the following variables are used inside the guard function.
# NOTE(review): in this dump they are string placeholders holding the repr of
# the real runtime objects (guard `check` methods, a builtin, a module), so the
# guard below is illustrative and not executable as written.
___check_global_state = '''<built-in method check of torch._C._dynamo.guards.GlobalStateGuard object at 0x16827e690>'''
___check_obj_id = '''<built-in function check_obj_id>'''
___check_tensors = '''<built-in method check of torch._C._dynamo.guards.TensorGuards object at 0x1680fca90>'''
tensor_check_names = '''["L['x']"]'''
utils_device = '''<module 'torch.utils._device' from '/Users/youkaichao/anaconda3/envs/py310/lib/python3.10/site-packages/torch/utils/_device.py'>'''
def __guard_0_for_torch_dynamo_resume_in_forward_at_23(L, G, **___kwargs_ignored):
    """Guard deciding whether cached code for the resume function may be reused.

    ``L`` holds the frame's locals by name and ``G`` its globals. All of the
    following must hold:

    - the global-state guard passes, and ``x`` has no
      ``_dynamo_dynamic_indices`` attribute;
    - ``self`` and every submodule used after the graph break (act2, drop,
      conv2, norm2) are the very same objects, compared by ``id``;
    - each of their ``.training`` attributes is one particular object (all
      compared against the same id, presumably the ``True``/``False``
      singleton -- confirm);
    - ``utils_device.CURRENT_DEVICE`` is None;
    - ``x`` passes the tensor guards.

    Because ``self`` is pinned by ``id``, this entry is only valid for that
    one module instance.
    """
    return (___check_global_state()) \
        and (hasattr(L['x'], '_dynamo_dynamic_indices') == False) \
        and (___check_obj_id(L['self'], 4396891824)) \
        and (___check_obj_id(L['self'].training, 4380193976)) \
        and (utils_device.CURRENT_DEVICE == None) \
        and (___check_obj_id(L['self'].act2, 4396887840)) \
        and (___check_obj_id(L['self'].act2.training, 4380193976)) \
        and (___check_obj_id(L['self'].drop, 4396887696)) \
        and (___check_obj_id(L['self'].drop.training, 4380193976)) \
        and (___check_obj_id(L['self'].conv2, 4396887600)) \
        and (___check_obj_id(L['self'].conv2.training, 4380193976)) \
        and (___check_obj_id(L['self'].norm2, 4396890048)) \
        and (___check_obj_id(L['self'].norm2.training, 4380193976)) \
        and (___check_tensors(L['x'], tensor_check_names=tensor_check_names))

# Note: please refer to the graph code in __compiled_fn_5*.py.
# Captured Graph: Dynamo generated graph (debuggable when using eager backend).
# Joint graph: joint forward+backward graph from aot autograd.
# Forward graph: forward graph from aot autograd (debuggable when using aot_eager backend).
# Backward graph: backward graph from aot autograd (debuggable when using aot_eager backend).
# AFTER XXX: graph processed by inductor (not debuggable).
def __compiled_fn_5(*args, **kwargs):
    """Placeholder for the compiled graph of the resume function.

    The dump does not inline the graph itself; the stages listed above are
    written to separate ``__compiled_fn_5*.py`` files.
    """
    pass

def __transformed_code_0_for_torch_dynamo_resume_in_forward_at_23(___stack0, self, x):
    """Code run in place of ``__resume_at_50_3`` when its guard passes.

    The whole continuation is a single call to the compiled graph. The graph
    takes only ``x`` and returns a one-element sequence that is unpacked
    here. ``___stack0`` and ``self`` are not used in this body.
    """
    out = None # this line helps Python to generate bytecode with at least the same number of local variables as the original function
    __temp_6, = __compiled_fn_5(x)
    return __temp_6


# Note: if there is a transformed version below, this function might well not be executed directly. Please check the transformed version if possible.
def __resume_at_50_3(___stack0, self, x):
    """Continuation of ``forward`` after the graph break at its ``print`` call.

    The body mirrors the statements that follow ``print`` in ``forward``.
    ``___stack0`` receives the result of the breaking ``print(...)`` call
    and is not used.
    """
    x = self.norm2(x)
    x = self.act2(x)
    x = self.drop(x)
    x = self.conv2(x)
    out = (x * x).mean()
    return out

def transformed___resume_at_50_3(___stack0, self, x):
    """Show how a call to ``__resume_at_50_3`` is dispatched.

    If the cached entry's guard passes on the current locals/globals, the
    compiled version runs. Otherwise execution falls back to the original
    Python code.
    """
    __local_dict = {"___stack0": ___stack0, "self": self, "x": x}
    __global_dict = globals()
    if __guard_0_for_torch_dynamo_resume_in_forward_at_23(__local_dict, __global_dict):
        return __transformed_code_0_for_torch_dynamo_resume_in_forward_at_23(___stack0, self, x)
    # Note: this function might well not be executed directly. It might well be transformed again, i.e. adding one more guards and transformed code.
    return __resume_at_50_3(___stack0, self, x)

#============ end of __resume_at_50_3 ============#

# Note: the following variables are used inside the guard function.
# NOTE(review): as above, these are string placeholders holding the repr of
# the real runtime objects; the guard is illustrative, not executable here.
___check_global_state = '''<built-in method check of torch._C._dynamo.guards.GlobalStateGuard object at 0x1669d2ab0>'''
___check_obj_id = '''<built-in function check_obj_id>'''
___check_tensors = '''<built-in method check of torch._C._dynamo.guards.TensorGuards object at 0x1680fc7b0>'''
tensor_check_names = '''["L['x']"]'''
utils_device = '''<module 'torch.utils._device' from '/Users/youkaichao/anaconda3/envs/py310/lib/python3.10/site-packages/torch/utils/_device.py'>'''
def __guard_0_for_forward(L, G, **___kwargs_ignored):
    """Guard deciding whether cached code for ``forward`` may be reused.

    It has the same structure as the resume-function guard, but pins the
    submodules used before the graph break (act1, pool, conv1, norm1) and
    their ``.training`` attributes by ``id``. It also checks that the
    builtin ``print`` is still the same object, because ``forward`` calls
    ``print``. As ``self`` is pinned by ``id``, this entry is only valid
    for that one module instance.
    """
    return (___check_global_state()) \
        and (hasattr(L['x'], '_dynamo_dynamic_indices') == False) \
        and (___check_obj_id(L['self'], 4396891824)) \
        and (___check_obj_id(L['self'].training, 4380193976)) \
        and (utils_device.CURRENT_DEVICE == None) \
        and (___check_obj_id(G['__builtins_dict___1']['print'], 4381267088)) \
        and (___check_obj_id(L['self'].act1, 4396891632)) \
        and (___check_obj_id(L['self'].act1.training, 4380193976)) \
        and (___check_obj_id(L['self'].pool, 4396891728)) \
        and (___check_obj_id(L['self'].pool.training, 4380193976)) \
        and (___check_obj_id(L['self'].conv1, 4396891680)) \
        and (___check_obj_id(L['self'].conv1.training, 4380193976)) \
        and (___check_obj_id(L['self'].norm1, 4396891776)) \
        and (___check_obj_id(L['self'].norm1.training, 4380193976)) \
        and (___check_tensors(L['x'], tensor_check_names=tensor_check_names))

# Note: please refer to the graph code in __compiled_fn_2*.py.
# Captured Graph: Dynamo generated graph (debuggable when using eager backend).
# Joint graph: joint forward+backward graph from aot autograd.
# Forward graph: forward graph from aot autograd (debuggable when using aot_eager backend).
# Backward graph: backward graph from aot autograd (debuggable when using aot_eager backend).
# AFTER XXX: graph processed by inductor (not debuggable).
def __compiled_fn_2(*args, **kwargs):
    """Placeholder for the compiled graph of ``forward`` up to the graph break.

    The dump does not inline the graph itself; the stages listed above are
    written to separate ``__compiled_fn_2*.py`` files.
    """
    pass

def __transformed_code_0_for_forward(self, x):
    """Code run in place of ``forward`` when its guard passes.

    Steps:
    1. Run the compiled graph for the statements before the graph break.
    2. Take the new ``x`` from the graph output.
    3. Call ``print`` in Python.
    4. Hand the ``print`` result, ``self`` and ``x`` to ``__resume_at_50_3``.

    ``x.shape`` in the ``print`` message appears here as the literal
    ``__import_torch.Size((64, 256, 32, 32))``.
    """
    out = None # this line helps Python to generate bytecode with at least the same number of local variables as the original function
    graph_out_0 = __compiled_fn_2(x)
    x = graph_out_0[0]
    # Graph break: the print call runs as ordinary Python, and the rest of
    # forward continues in the resume function.
    return __resume_at_50_3(__builtins_dict___1['print'](
        'will trigger graph break, x.shape:', __import_torch.Size((64, 256, 32,
        32))), self, x)


# Note: if there is a transformed version below, this function might well not be executed directly. Please check the transformed version if possible.
def forward(self, x):
    """Original ``forward`` of the larger model these guards were captured from.

    The ``print`` call is where Dynamo breaks the graph (as its message
    says). Statements before it become ``__compiled_fn_2``; statements after
    it become ``__resume_at_50_3``.
    """
    x = self.norm1(x)
    x = self.act1(x)
    x = self.conv1(x)
    x = self.pool(x)
    print('will trigger graph break, x.shape:', x.shape)
    x = self.norm2(x)
    x = self.act2(x)
    x = self.drop(x)
    x = self.conv2(x)
    out = (x * x).mean()
    return out

def transformed_forward(self, x):
    """Show how a call to ``forward`` is dispatched.

    If the cached entry's guard passes, the transformed code (compiled graph
    plus graph break plus resume) runs. Otherwise execution falls back to
    the original ``forward``.
    """
    __local_dict = {"self": self, "x": x}
    __global_dict = globals()
    if __guard_0_for_forward(__local_dict, __global_dict):
        return __transformed_code_0_for_forward(self, x)
    # Note: this function might well not be executed directly. It might well be transformed again, i.e. adding one more guards and transformed code.
    return forward(self, x)

#============ end of forward ============#