What is the mental model for torch.compile to deal with stateful objects in general?

I’m trying to get a clear mental model for torch.compile to deal with stateful objects.

Traditionally, JIT compilers such as numba only compile functions over arrays and a handful of basic Python data structures. It is easy to figure out when they re-compile: just check the input metadata, e.g. dtype, shape, and type.
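To make that concrete, here is a toy sketch of a numba-style cache (my own illustration, not numba's actual dispatch machinery): compiled variants are keyed purely by input metadata, so the recompile condition is trivially visible.

```python
# Toy sketch of metadata-keyed JIT caching (illustration only).
import numpy as np

_cache = {}

def metadata_key(arg):
    # For arrays: (type, dtype, ndim); for everything else: just the type.
    if isinstance(arg, np.ndarray):
        return (np.ndarray, arg.dtype.str, arg.ndim)
    return type(arg)

def jit(fn):
    def wrapper(*args):
        key = (fn, tuple(metadata_key(a) for a in args))
        if key not in _cache:
            _cache[key] = fn          # "compile" stands in for real codegen
            wrapper.compiles += 1
        return _cache[key](*args)
    wrapper.compiles = 0
    return wrapper

@jit
def add(x, y):
    return x + y

add(np.zeros(3), np.zeros(3))   # first call: compiles
add(np.ones(5), np.ones(5))     # same dtype/ndim: cache hit, no recompile
add(np.zeros(3, dtype=np.int64), np.zeros(3, dtype=np.int64))  # new dtype: recompiles
print(add.compiles)             # → 2
```

The point is that the entire recompile condition fits in `metadata_key`: it is inspectable and predictable.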

In the case of torch.compile, the situation is more complicated. The promise of torch.compile is essentially that it works for any function f(*args, **kwargs) with arbitrary inputs, because it understands Python bytecode. This makes the re-compile condition extremely difficult to figure out. Currently, it is something of a black box.

The re-compile condition becomes even harder to reason about when some arguments are stateful, e.g. dict, list, and notably nn.Module. Recently I ran into such a problem and had to set a global flag, export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1, to achieve what I wanted.

That’s why I’m asking: does torch.compile have a mental model for the re-compile condition (essentially, the guard system)? For example, what does it guard on: object type? object id? object content (for stateful objects)? For stateful objects, how can we guard on content at all? Python is very flexible and dynamic, and the content of a stateful object can change quite dramatically.
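To show what I mean by "guard on type / id / content", here is a toy guard system in plain Python (my own mental model, not Dynamo's real implementation): each compiled entry carries guard predicates captured from the example inputs, and a cache hit requires every guard to pass.

```python
# Toy sketch of a guard-based compilation cache (illustration only,
# not Dynamo's actual guard machinery).

def type_guard(example):
    t = type(example)
    return lambda v: type(v) is t

def content_guard(example):
    # For a stateful object like a dict, guard on a snapshot of its
    # keys: a cheap proxy for "the structure has not changed".
    keys = tuple(example)
    return lambda v: tuple(v) == keys

def compile_with_guards(fn, guard_factories):
    entries = []  # list of (guards, compiled_fn)
    def wrapper(arg):
        for guards, compiled in entries:
            if all(g(arg) for g in guards):
                return compiled(arg)      # all guards pass: cache hit
        wrapper.compiles += 1             # some guard failed: recompile
        entries.append(([make(arg) for make in guard_factories], fn))
        return fn(arg)
    wrapper.compiles = 0
    return wrapper

f = compile_with_guards(lambda d: sum(d.values()),
                        [type_guard, content_guard])

f({"a": 1, "b": 2})   # first call: compile #1
f({"a": 5, "b": 6})   # same type, same key set: cache hit
f({"a": 1, "c": 2})   # key set changed: guard fails, compile #2
print(f.compiles)     # → 2
```

The question is which of these guard kinds (type, id, content snapshot) Dynamo actually installs for a given stateful object, and how deep the content snapshot goes.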

cc @jansel @anijain2305

The handling here is specific to torch.nn.Module() instances. Depending on the config, torch.compile() might specialize on the id() of an nn.Module. This allows additional optimizations (freezing, copy-free cudagraphs, etc.) since the module parameters are known at compile time.
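A toy sketch of what id() specialization buys you, and its hazard (my illustration of the idea above, not PyTorch's actual code): keying the cache by id(mod) lets the compiler bake the parameters in as constants, but in-place mutation of the same object then goes unnoticed.

```python
# Toy sketch of id()-specialized compilation (illustration only).
class TinyModule:
    """Stand-in for an nn.Module with one 'parameter'."""
    def __init__(self, scale):
        self.scale = scale

def compile_fn(fn):
    cache = {}
    def wrapper(mod, x):
        if id(mod) not in cache:          # guard on object identity
            wrapper.compiles += 1
            scale = mod.scale             # parameter "frozen" in as a constant
            cache[id(mod)] = lambda x: x * scale
        return cache[id(mod)](x)
    wrapper.compiles = 0
    return wrapper

run = compile_fn(lambda mod, x: x * mod.scale)

m1 = TinyModule(2)
m2 = TinyModule(2)
run(m1, 10)           # compiles for m1 → 20
run(m1, 11)           # same id: cache hit → 22
run(m2, 10)           # different object, identical content: recompiles
print(run.compiles)   # → 2

m1.scale = 3
print(run(m1, 10))    # → 20, not 30: the frozen constant is stale
```

The last line shows why id specialization and mutation interact badly, which is exactly where the "mark dynamic on mutation" behavior below comes in.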

While I know this specific issue relates to nn.Module instances, I’m trying to find a general mental model for how Dynamo places guards on stateful objects.

Dynamo guards on nn.Module id()s unless you mark the module dynamic (which happens automatically if you mutate self) or disable that behavior via config.