[Fatal Bug] Changing nn.Module.training does not trigger recompilation

:bug: Describe the bug

Here is the sample code:

import torch
from torch import nn

from torch._functorch.aot_autograd import aot_module_simplified
from functorch.compile import make_boxed_func

def graph_processing_pytorch(gm, example_inputs):
    # graph transform (optimization) hook for the Dynamo-captured graph
    print("captured graph in pytorch")
    gm.print_readable()

def graph_processing_aot_forward(gm, example_inputs):
    # graph transform (optimization) hook for the AOTAutograd forward graph
    print("captured graph in aot autograd forward")
    gm.print_readable()

def graph_processing_aot_backward(gm, example_inputs):
    # graph transform (optimization) hook for the AOTAutograd backward graph
    print("captured graph in aot autograd backward")
    gm.print_readable()

def forward_compiler(gm, example_inputs):
    graph_processing_aot_forward(gm, example_inputs)
    return make_boxed_func(gm.forward)

def backward_compiler(gm, example_inputs):
    graph_processing_aot_backward(gm, example_inputs)
    return make_boxed_func(gm.forward)

def custom_backend(gm, example_inputs):
    graph_processing_pytorch(gm, example_inputs)
    return aot_module_simplified(
        gm, example_inputs,
        fw_compiler=forward_compiler,
        bw_compiler=backward_compiler
    )

class BackboneModel(nn.Module):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.drop1 = nn.Dropout(0.5)
        self.bn1 = nn.BatchNorm2d(6)

    def forward(self, x):
        x = self.bn1(self.drop1(x))
        return x


model = BackboneModel()

opt_model = torch.compile(model, backend=custom_backend)

input = torch.randn(64, 6, 32, 32)

output1 = opt_model(input)

# calling .eval() on the whole model works
# opt_model.eval()

# calling .eval() on submodules does not trigger recompilation
opt_model.drop1.eval()
opt_model.bn1.eval()

output2 = opt_model(input)
output3 = opt_model(input)

print((output2 - output3).abs().max().item()) # huge number, nondeterministic!
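
To see the missing recompilation directly, rather than inferring it from the nondeterministic outputs, a backend that only counts its invocations can be swapped in. This is a minimal sketch that continues the script above; counting_backend and compile_count are names introduced here purely for illustration.

compile_count = 0

def counting_backend(gm, example_inputs):
    # Dynamo invokes the backend once per (re)compilation, so the counter
    # reveals whether flipping .training caused a new trace.
    global compile_count
    compile_count += 1
    return gm.forward  # run the captured graph unchanged

model2 = BackboneModel()
opt_model2 = torch.compile(model2, backend=counting_backend)

opt_model2(input)        # initial compilation: compile_count == 1
opt_model2.drop1.eval()
opt_model2.bn1.eval()
opt_model2(input)        # on 2.0.1 compile_count stays 1: no recompilation
print(compile_count)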

It seems that only calling .eval() on the whole model works. The training flag on submodules is not respected, so even after I call .eval() on every submodule, the output is still nondeterministic.
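
As a stopgap, one can either call .eval() on the top-level compiled module (which propagates the flag to every submodule and, as noted above, works), or clear Dynamo's compilation cache after flipping the submodule flags so that the next call re-traces with the new .training values. The sketch below uses torch._dynamo.reset(); note that it is a blunt instrument that drops every cached compilation in the process, and this is only a sketch of the idea against the 2.0.1 behavior described here, not a verified fix.

import torch._dynamo

opt_model.drop1.eval()
opt_model.bn1.eval()
torch._dynamo.reset()   # drop all cached graphs; the next call re-traces

output2 = opt_model(input)
output3 = opt_model(input)
print((output2 - output3).abs().max().item())  # should now print 0.0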

Versions

PyTorch 2.0.1

Note: I opened an issue for this at pytorch/pytorch#105653 ([Fatal Bug] changed nn.Module.training does not trigger recompilation). I have opened several issues recently but got no response, so I am bringing it here for dev support :frowning: