Backward module does not contain the weights' gradient calculation

I am trying to analyze the training procedure using torch.fx and aot_autograd.
However, with PyTorch 2.0, I cannot capture the gradient calculation for the weights.

PyTorch 1.11.0 + functorch 0.1.0

As you can see in the code below, mm_4 and mm_2 compute the gradients of the weights.

import torch
import torch.fx._pytree as fx_pytree
import torch.utils._pytree as pytree
from torch.nn import *
class FxModule(torch.nn.Module):
    """Joint forward/backward graph captured with PyTorch 1.11 + functorch 0.1.0.

    ``forward`` runs two matmul + in-place bias-add stages on the input, then
    propagates the incoming tangent back through matmuls against
    ``_tensor_constant2`` and ``_tensor_constant3``.

    NOTE(review): ``self._in_spec`` and ``self._out_spec`` are read by
    ``forward`` but never assigned in this listing; they must be attached
    (e.g. from ``torch.utils._pytree.tree_flatten``) before calling the module.
    """

    def __init__(self):
        # nn.Module.__init__ must run first: register_buffer and Parameter
        # attribute assignment raise AttributeError on an uninitialized Module.
        super().__init__()
        # _tensor_constant2/3 have the transposed shapes of _tensor_constant1/0;
        # presumably they hold the transposed weights used by the backward
        # matmuls -- TODO confirm.
        self.register_buffer('_tensor_constant0', torch.empty([1024, 4096], dtype=torch.float32))
        self.register_buffer('_tensor_constant1', torch.empty([4096, 1024], dtype=torch.float32))
        self.register_buffer('_tensor_constant2', torch.empty([1024, 4096], dtype=torch.float32))
        self.register_buffer('_tensor_constant3', torch.empty([4096, 1024], dtype=torch.float32))
        self._param_constant0 = torch.nn.Parameter(torch.empty([4096], dtype=torch.float32))
        self._param_constant1 = torch.nn.Parameter(torch.empty([1024], dtype=torch.float32))

    def forward(self, primals, tangents):
        """Run the captured forward computation followed by its backward.

        Args:
            primals: forward inputs; flattened with ``self._in_spec`` to one
                tensor that is viewed as ``[91, 1024]`` (91 * 1024 elements,
                e.g. shape ``[7, 13, 1024]``).
            tangents: incoming gradient; flattened with ``self._in_spec`` to one
                tensor that is viewed as ``[91, 1024]``.

        Returns:
            ``[add__1, view_7]`` unflattened with ``self._out_spec``: the forward
            output (``[7, 13, 1024]``) and the tangent propagated back through
            both matmuls (``[7, 13, 1024]``).
        """
        primals_1, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
        # ---- forward: [91, 1024] @ _tensor_constant0 + _param_constant0
        #               -> @ _tensor_constant1 + _param_constant1
        view = torch.ops.aten.view(primals_1, [91, 1024]);  primals_1 = None
        _tensor_constant0 = self._tensor_constant0
        mm = torch.ops.aten.mm(view, _tensor_constant0);  _tensor_constant0 = None
        _unsafe_view = torch.ops.aten._unsafe_view(mm, [7, 13, 4096]);  mm = None
        _param_constant0 = self._param_constant0
        add_ = torch.ops.aten.add_(_unsafe_view, _param_constant0);  _unsafe_view = _param_constant0 = None
        view_1 = torch.ops.aten.view(add_, [91, 4096]);  add_ = None
        _tensor_constant1 = self._tensor_constant1
        mm_1 = torch.ops.aten.mm(view_1, _tensor_constant1);  _tensor_constant1 = None
        _unsafe_view_1 = torch.ops.aten._unsafe_view(mm_1, [7, 13, 1024]);  mm_1 = None
        _param_constant0_1 = self._param_constant0  # unused; emitted by the tracer
        _param_constant1 = self._param_constant1
        add__1 = torch.ops.aten.add_(_unsafe_view_1, _param_constant1);  _unsafe_view_1 = _param_constant1 = None
        # ---- backward, driven by tangents_1
        # view_2: tangent summed over dims 0 and 1 -> [1024] (bias-shaped).
        sum_1 = torch.ops.aten.sum(tangents_1, [0, 1], True)
        view_2 = torch.ops.aten.view(sum_1, [1024]);  sum_1 = None
        view_3 = torch.ops.aten.view(tangents_1, [91, 1024]);  tangents_1 = None
        # t_1 = (view_3^T @ view_1)^T has _tensor_constant1's shape [4096, 1024].
        t = torch.ops.aten.t(view_3)
        mm_2 = torch.ops.aten.mm(t, view_1);  t = view_1 = None
        t_1 = torch.ops.aten.t(mm_2);  mm_2 = None
        _tensor_constant2 = self._tensor_constant2
        mm_3 = torch.ops.aten.mm(view_3, _tensor_constant2);  view_3 = _tensor_constant2 = None
        view_4 = torch.ops.aten.view(mm_3, [7, 13, 4096]);  mm_3 = None
        # view_5: summed over dims 0 and 1 -> [4096] (bias-shaped).
        sum_2 = torch.ops.aten.sum(view_4, [0, 1], True)
        view_5 = torch.ops.aten.view(sum_2, [4096]);  sum_2 = None
        view_6 = torch.ops.aten.view(view_4, [91, 4096]);  view_4 = None
        # t_3 = (view_6^T @ view)^T has _tensor_constant0's shape [1024, 4096].
        t_2 = torch.ops.aten.t(view_6)
        mm_4 = torch.ops.aten.mm(t_2, view);  t_2 = view = None
        t_3 = torch.ops.aten.t(mm_4);  mm_4 = None
        _tensor_constant3 = self._tensor_constant3
        mm_5 = torch.ops.aten.mm(view_6, _tensor_constant3);  view_6 = _tensor_constant3 = None
        view_7 = torch.ops.aten.view(mm_5, [7, 13, 1024]);  mm_5 = None
        # NOTE: view_2, t_1, view_5 and t_3 are computed above but are not part
        # of the returned values, so the weight/bias-shaped gradients never
        # leave this graph.
        return pytree.tree_unflatten([add__1, view_7], self._out_spec)

PyTorch 2.0

In PyTorch 2.0, however, these two operators (mm_2 and mm_4) have disappeared from the forward function, and I don't know why.

import torch
from math import inf
from math import nan
NoneType = type(None)
import torch
from torch import device
import torch.fx._pytree as fx_pytree
import torch.utils._pytree as pytree

from torch.nn import *
class FxModule(torch.nn.Module):
    """Joint forward/backward graph captured with PyTorch 2.0.

    ``forward`` runs two ``addmm`` (bias + matmul) stages on the input, then
    propagates the incoming tangent back through matmuls against
    ``_tensor_constant2`` and ``_tensor_constant3``.

    NOTE(review): ``self._in_spec`` and ``self._out_spec`` are read by
    ``forward`` but never assigned in this listing; they must be attached
    (e.g. from ``torch.utils._pytree.tree_flatten``) before calling the module.
    """

    def __init__(self):
        # nn.Module.__init__ must run first: register_buffer and Parameter
        # attribute assignment raise AttributeError on an uninitialized Module.
        super().__init__()
        # _tensor_constant2/3 have the transposed shapes of _tensor_constant1/0;
        # presumably they hold the transposed weights used by the backward
        # matmuls -- TODO confirm.
        self.register_buffer('_tensor_constant0', torch.empty([1024, 4096], dtype=torch.float32))
        self.register_buffer('_tensor_constant1', torch.empty([4096, 1024], dtype=torch.float32))
        self.register_buffer('_tensor_constant2', torch.empty([1024, 4096], dtype=torch.float32))
        self.register_buffer('_tensor_constant3', torch.empty([4096, 1024], dtype=torch.float32))
        self._param_constant0 = torch.nn.Parameter(torch.empty([4096], dtype=torch.float32))
        self._param_constant1 = torch.nn.Parameter(torch.empty([1024], dtype=torch.float32))

    def forward(self, primals, tangents):
        """Run the captured forward computation followed by its backward.

        Args:
            primals: forward inputs; flattened with ``self._in_spec`` to one
                tensor that is viewed as ``[91, 1024]`` (91 * 1024 elements,
                e.g. shape ``[7, 13, 1024]``).
            tangents: incoming gradient; flattened with ``self._in_spec`` to one
                tensor that is viewed as ``[91, 1024]``.

        Returns:
            ``[view_default_3, view_default_7]`` unflattened with
            ``self._out_spec``: the forward output (``[7, 13, 1024]``) and the
            tangent propagated back through both matmuls (``[7, 13, 1024]``).
        """
        primals_1, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
        # ---- forward: addmm(_param_constant0, x, _tensor_constant0)
        #               -> addmm(_param_constant1, ., _tensor_constant1)
        view_default = torch.ops.aten.view.default(primals_1, [91, 1024]);  primals_1 = None
        _param_constant0 = self._param_constant0
        _tensor_constant0 = self._tensor_constant0
        addmm_default = torch.ops.aten.addmm.default(_param_constant0, view_default, _tensor_constant0);  _param_constant0 = view_default = _tensor_constant0 = None
        view_default_1 = torch.ops.aten.view.default(addmm_default, [7, 13, 4096]);  addmm_default = None
        view_default_2 = torch.ops.aten.view.default(view_default_1, [91, 4096]);  view_default_1 = None
        _param_constant0_1 = self._param_constant0  # unused; emitted by the tracer
        _param_constant1 = self._param_constant1
        _tensor_constant1 = self._tensor_constant1
        addmm_default_1 = torch.ops.aten.addmm.default(_param_constant1, view_default_2, _tensor_constant1);  _param_constant1 = view_default_2 = _tensor_constant1 = None
        view_default_3 = torch.ops.aten.view.default(addmm_default_1, [7, 13, 1024]);  addmm_default_1 = None
        # ---- backward, driven by tangents_1
        is_same_size_default = torch.ops.aten.is_same_size.default(view_default_3, tangents_1)  # result unused
        view_default_4 = torch.ops.aten.view.default(tangents_1, [91, 1024]);  tangents_1 = None
        _tensor_constant2 = self._tensor_constant2
        mm_default = torch.ops.aten.mm.default(view_default_4, _tensor_constant2);  view_default_4 = _tensor_constant2 = None
        view_default_5 = torch.ops.aten.view.default(mm_default, [7, 13, 4096]);  mm_default = None
        view_default_6 = torch.ops.aten.view.default(view_default_5, [91, 4096]);  view_default_5 = None
        _tensor_constant3 = self._tensor_constant3
        mm_default_1 = torch.ops.aten.mm.default(view_default_6, _tensor_constant3);  view_default_6 = _tensor_constant3 = None
        view_default_7 = torch.ops.aten.view.default(mm_default_1, [7, 13, 1024]);  mm_default_1 = None
        # NOTE: no op in this graph produces a tensor shaped like
        # _tensor_constant0/1 ([1024, 4096] / [4096, 1024]) or like the bias
        # parameters; only the forward output and the input-side backward
        # result are computed and returned.
        # NOTE(review): presumably the weight/bias-gradient ops were dropped as
        # dead code because they would not be returned -- confirm.
        return pytree.tree_unflatten([view_default_3, view_default_7], self._out_spec)