I’m trying to get a clear mental model for how `torch.compile` deals with stateful objects.
Traditionally, JIT compilers such as numba only compile functions over arrays and a few basic Python data structures. It is easy to figure out when they re-compile: just check the metadata of the inputs, e.g. dtype, shape, and type.
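For example, with numba the re-compile boundary is visible directly on the dispatcher object (a minimal sketch; the array contents are just placeholders):

```python
import numpy as np
from numba import njit

@njit
def add_one(x):
    # numba compiles one specialization per input type signature
    return x + 1

add_one(np.arange(3, dtype=np.int64))    # compiles for int64 arrays
add_one(np.arange(3, dtype=np.float64))  # new dtype -> new compilation
add_one(np.arange(5, dtype=np.float64))  # same signature -> reuses compiled code

print(add_one.signatures)  # lists the compiled signatures, e.g. (int64[:],), (float64[:],)
```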
In the case of `torch.compile`, the situation is more complicated. The promise of `torch.compile` is essentially that it works for any function `f(*args, **kwargs)` with arbitrary inputs, because it understands Python bytecode. This makes the re-compile condition extremely difficult to figure out; currently it is something of a black box.
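To be concrete, the only way I have found to peek inside that black box is the logging hooks (a sketch; the exact log output differs across versions):

```python
import torch

# Ask Dynamo to print the guards it installs and the reason for each re-compile.
# Equivalent env var: TORCH_LOGS="guards,recompiles"
torch._logging.set_logs(guards=True, recompiles=True)

def f(x, n):
    return x * n

compiled = torch.compile(f)
compiled(torch.randn(3), 2)  # first compilation; installed guards are printed
compiled(torch.randn(3), 3)  # the Python int changed -> a guard fails -> recompile reason is printed
```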
The re-compile condition becomes even harder to reason about when some arguments are stateful, e.g. `dict`, `list`, and notably `nn.Module`. Recently I ran into such a problem and had to set a global flag, `export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1`, to achieve what I wanted.
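For illustration, here is the kind of state-dependent recompilation I mean (a hedged sketch, not an authoritative description of the guard system; the `Scaler` module is made up, and the exact guards depend on the PyTorch version and on `TORCHDYNAMO_INLINE_INBUILT_NN_MODULES`):

```python
import torch
import torch.nn as nn

class Scaler(nn.Module):  # hypothetical module, just for illustration
    def __init__(self):
        super().__init__()
        self.scale = 2  # plain Python attribute used in control flow

    def forward(self, x):
        if self.scale > 1:       # Dynamo appears to specialize on this value
            return x * self.scale
        return x

m = Scaler()
compiled = torch.compile(m)

x = torch.randn(4)
compiled(x)   # first compilation
m.scale = 3   # mutate module state
compiled(x)   # whether (and why) this re-compiles is exactly what I'm unsure about
```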
That’s why I’m asking: is there a clear mental model for the re-compile condition of `torch.compile` (essentially, its guard system)? E.g. what does it guard on? Object type? Object id? Object content (for stateful objects)? And for stateful objects, how can it guard on the content at all? Python is very flexible and dynamic, so the content of a stateful object can change quite dramatically.