Question about the necessity of some guards in Dynamo

When I compile some very simple function, like:

import torch

@torch.compile
def f(x):
    return x + 1

I see many guards, like:

def __guard_0_for_f(L):
    return (___guarded_code.valid) \
        and (___check_global_state()) \
        and (hasattr(L['x'], '_dynamo_dynamic_indices') == False) \
        and (utils_device.CURRENT_DEVICE == None) \
        and ((___skip_backend_check() or ___current_backend() == ___lookup_backend(5082072448))) \
        and (___check_tensors(L['x'], tensor_check_names=tensor_check_names))

Are they all necessary? It seems some of them are not. For example, what is the purpose of looking up ___current_backend? Do we even have the concept of a current backend? And when would ___guarded_code become invalid?

If I understand correctly, we should store the backend on the code object, and run that backend when we need to recompile.

The current backend check handles the case where you pass different backend=... options to torch.compile. It is needed because the compiled-code cache on a code object is shared between all backends.

___guarded_code.valid allows for “push” style invalidation. For example, when an object we guarded on is freed (we watch it with a weakref), we need to invalidate the compiled code attached to it.

Uh, that makes sense.

If I understand correctly:

backend check is used for:

opt_f1 = torch.compile(f, backend=backend1)
opt_f2 = torch.compile(f, backend=backend2)

___guarded_code.valid can be used for torch._dynamo.reset().

Right?

These scenarios rarely happen in a normal workflow, though.
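
For concreteness, the reset case I have in mind is something like this minimal sketch (the comments describe the behavior I would expect):

import torch

@torch.compile
def f(x):
    return x + 1

f(torch.randn(3))      # compiles f and caches the guarded code
torch._dynamo.reset()  # clears Dynamo's caches; the cached code is invalidated
f(torch.randn(3))      # recompiles from scratch on the next call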

Correct about the backend check.
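
As a concrete illustration, here is a minimal runnable example using the built-in debug backends "eager" and "aot_eager" (any two distinct backends would do):

import torch

def f(x):
    return x + 1

# Both wrappers share the cache attached to f's code object, so the
# backend guard is what keeps their compiled artifacts apart.
opt_f1 = torch.compile(f, backend="eager")
opt_f2 = torch.compile(f, backend="aot_eager")

x = torch.randn(3)
opt_f1(x)  # compiles with the "eager" backend
opt_f2(x)  # backend guard fails on the cached entry, so it compiles again with "aot_eager"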

___guarded_code.valid is used primarily for weakref invalidation. Suppose you have a guard that checks id(self) == 123456, where 123456 is the memory address of some nn.Module instance. If the nn.Module located at 123456 were freed, another object could be allocated at the same address and the guard would pass spuriously. So we install a hook (using weakref.ref) so that if that nn.Module is freed, we set valid = False.
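
Here is a minimal sketch of that mechanism; GuardedCode and on_module_freed are illustrative stand-ins, not Dynamo's actual internals:

import weakref

import torch

class GuardedCode:
    # Illustrative stand-in for Dynamo's guarded-code wrapper.
    def __init__(self):
        self.valid = True

guarded_code = GuardedCode()
mod = torch.nn.Linear(2, 2)
guarded_id = id(mod)  # only meaningful while this exact object is alive

def on_module_freed(ref):
    # "Push" style invalidation: flip the flag the moment the guarded
    # object dies, instead of re-checking identity on every call.
    guarded_code.valid = False

ref = weakref.ref(mod, on_module_freed)

assert guarded_code.valid and id(mod) == guarded_id  # guard passes
del mod                                              # the module is freed
assert not guarded_code.valid                        # compiled code invalidated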

That’s amazing. Managing these guards in a dynamic language like Python is a masterpiece :rofl: