I am trying to analyze the training procedure by capturing the joint forward/backward graph with torch.fx and aot_autograd.
However, with PyTorch 2.0 I cannot capture the gradient computation for the weights.
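For context, the capture setup looks roughly like the sketch below: a minimal reconstruction with a placeholder two-layer model matching the shapes in the graphs; print_joint and noop_compiler are my own helper names, and the exact call may differ slightly from my real script.

import torch
from functorch.compile import aot_function, default_partition

def print_joint(joint_module, joint_inputs, **kwargs):
    # inspect the joint forward+backward graph before it is partitioned,
    # then hand it to the default partitioner unchanged
    print(joint_module.code)
    return default_partition(joint_module, joint_inputs, **kwargs)

def noop_compiler(gm, example_inputs):
    # keep the captured torch.fx.GraphModule as-is
    return gm

lin1 = torch.nn.Linear(1024, 4096)
lin2 = torch.nn.Linear(4096, 1024)

def fn(x):
    return lin2(lin1(x))

compiled = aot_function(fn, fw_compiler=noop_compiler, bw_compiler=noop_compiler,
                        partition_fn=print_joint)
out = compiled(torch.randn(7, 13, 1024, requires_grad=True))
out.sum().backward()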
PyTorch 1.11.0 + functorch 0.1.0
As you can see in the code below, mm_2 and mm_4 (followed by the transposes t_1 and t_3) compute the gradients of the two weight matrices.
import torch
import torch.fx._pytree as fx_pytree
import torch.utils._pytree as pytree
from torch.nn import *

class FxModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('_tensor_constant0', torch.empty([1024, 4096], dtype=torch.float32))
        self.register_buffer('_tensor_constant1', torch.empty([4096, 1024], dtype=torch.float32))
        self.register_buffer('_tensor_constant2', torch.empty([1024, 4096], dtype=torch.float32))
        self.register_buffer('_tensor_constant3', torch.empty([4096, 1024], dtype=torch.float32))
        self._param_constant0 = torch.nn.Parameter(torch.empty([4096], dtype=torch.float32))
        self._param_constant1 = torch.nn.Parameter(torch.empty([1024], dtype=torch.float32))
        self.load_state_dict(torch.load(r'0615/state_dict.pt'))

    def forward(self, primals, tangents):
        primals_1, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
        view = torch.ops.aten.view(primals_1, [91, 1024]); primals_1 = None
        _tensor_constant0 = self._tensor_constant0
        mm = torch.ops.aten.mm(view, _tensor_constant0); _tensor_constant0 = None
        _unsafe_view = torch.ops.aten._unsafe_view(mm, [7, 13, 4096]); mm = None
        _param_constant0 = self._param_constant0
        add_ = torch.ops.aten.add_(_unsafe_view, _param_constant0); _unsafe_view = _param_constant0 = None
        view_1 = torch.ops.aten.view(add_, [91, 4096]); add_ = None
        _tensor_constant1 = self._tensor_constant1
        mm_1 = torch.ops.aten.mm(view_1, _tensor_constant1); _tensor_constant1 = None
        _unsafe_view_1 = torch.ops.aten._unsafe_view(mm_1, [7, 13, 1024]); mm_1 = None
        _param_constant0_1 = self._param_constant0
        _param_constant1 = self._param_constant1
        add__1 = torch.ops.aten.add_(_unsafe_view_1, _param_constant1); _unsafe_view_1 = _param_constant1 = None
        sum_1 = torch.ops.aten.sum(tangents_1, [0, 1], True)
        view_2 = torch.ops.aten.view(sum_1, [1024]); sum_1 = None
        view_3 = torch.ops.aten.view(tangents_1, [91, 1024]); tangents_1 = None
        t = torch.ops.aten.t(view_3)
        mm_2 = torch.ops.aten.mm(t, view_1); t = view_1 = None
        t_1 = torch.ops.aten.t(mm_2); mm_2 = None
        _tensor_constant2 = self._tensor_constant2
        mm_3 = torch.ops.aten.mm(view_3, _tensor_constant2); view_3 = _tensor_constant2 = None
        view_4 = torch.ops.aten.view(mm_3, [7, 13, 4096]); mm_3 = None
        sum_2 = torch.ops.aten.sum(view_4, [0, 1], True)
        view_5 = torch.ops.aten.view(sum_2, [4096]); sum_2 = None
        view_6 = torch.ops.aten.view(view_4, [91, 4096]); view_4 = None
        t_2 = torch.ops.aten.t(view_6)
        mm_4 = torch.ops.aten.mm(t_2, view); t_2 = view = None
        t_3 = torch.ops.aten.t(mm_4); mm_4 = None
        _tensor_constant3 = self._tensor_constant3
        mm_5 = torch.ops.aten.mm(view_6, _tensor_constant3); view_6 = _tensor_constant3 = None
        view_7 = torch.ops.aten.view(mm_5, [7, 13, 1024]); mm_5 = None
        return pytree.tree_unflatten([add__1, view_7], self._out_spec)
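To spell out what those nodes compute: each linear layer here is y = x @ W + b, and the weight gradient is x.t() @ grad_y, which is exactly the t -> mm -> t pattern above (mm_2/t_1 for _tensor_constant1, mm_4/t_3 for _tensor_constant0). A standalone check with toy tensors (not the captured graph):

import torch

x = torch.randn(91, 4096)                         # layer input, like view_1 above
W = torch.randn(4096, 1024, requires_grad=True)   # weight, like _tensor_constant1
grad_y = torch.randn(91, 1024)                    # incoming tangent, like view_3 above

y = x @ W
y.backward(grad_y)

manual = x.t() @ grad_y                 # same as t_1 = aten.t(aten.mm(aten.t(grad_y), x))
print(torch.allclose(W.grad, manual))   # True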
PyTorch 2.0
In PyTorch 2.0, however, these two operators are missing from the forward function, and I don't know why.
import torch
from math import inf
from math import nan
NoneType = type(None)
import torch
from torch import device
import torch.fx._pytree as fx_pytree
import torch.utils._pytree as pytree
from torch.nn import *

class FxModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('_tensor_constant0', torch.empty([1024, 4096], dtype=torch.float32))
        self.register_buffer('_tensor_constant1', torch.empty([4096, 1024], dtype=torch.float32))
        self.register_buffer('_tensor_constant2', torch.empty([1024, 4096], dtype=torch.float32))
        self.register_buffer('_tensor_constant3', torch.empty([4096, 1024], dtype=torch.float32))
        self._param_constant0 = torch.nn.Parameter(torch.empty([4096], dtype=torch.float32))
        self._param_constant1 = torch.nn.Parameter(torch.empty([1024], dtype=torch.float32))
        self.load_state_dict(torch.load(r'0615_1/state_dict.pt'))

    def forward(self, primals, tangents):
        primals_1, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
        view_default = torch.ops.aten.view.default(primals_1, [91, 1024]); primals_1 = None
        _param_constant0 = self._param_constant0
        _tensor_constant0 = self._tensor_constant0
        addmm_default = torch.ops.aten.addmm.default(_param_constant0, view_default, _tensor_constant0); _param_constant0 = view_default = _tensor_constant0 = None
        view_default_1 = torch.ops.aten.view.default(addmm_default, [7, 13, 4096]); addmm_default = None
        view_default_2 = torch.ops.aten.view.default(view_default_1, [91, 4096]); view_default_1 = None
        _param_constant0_1 = self._param_constant0
        _param_constant1 = self._param_constant1
        _tensor_constant1 = self._tensor_constant1
        addmm_default_1 = torch.ops.aten.addmm.default(_param_constant1, view_default_2, _tensor_constant1); _param_constant1 = view_default_2 = _tensor_constant1 = None
        view_default_3 = torch.ops.aten.view.default(addmm_default_1, [7, 13, 1024]); addmm_default_1 = None
        is_same_size_default = torch.ops.aten.is_same_size.default(view_default_3, tangents_1)
        view_default_4 = torch.ops.aten.view.default(tangents_1, [91, 1024]); tangents_1 = None
        _tensor_constant2 = self._tensor_constant2
        mm_default = torch.ops.aten.mm.default(view_default_4, _tensor_constant2); view_default_4 = _tensor_constant2 = None
        view_default_5 = torch.ops.aten.view.default(mm_default, [7, 13, 4096]); mm_default = None
        view_default_6 = torch.ops.aten.view.default(view_default_5, [91, 4096]); view_default_5 = None
        _tensor_constant3 = self._tensor_constant3
        mm_default_1 = torch.ops.aten.mm.default(view_default_6, _tensor_constant3); view_default_6 = _tensor_constant3 = None
        view_default_7 = torch.ops.aten.view.default(mm_default_1, [7, 13, 1024]); mm_default_1 = None
        return pytree.tree_unflatten([view_default_3, view_default_7], self._out_spec)
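To double-check that the two weight-gradient matmuls are really absent rather than just renamed, I count the aten ops in the two generated listings. This is a rough sketch; pt111_src and pt20_src are assumed to hold the two module sources above as strings.

import re
from collections import Counter

def aten_ops(src):
    # tally every torch.ops.aten.<op> call appearing in the generated source
    return Counter(re.findall(r"torch\.ops\.aten\.(\w+)", src))

print(aten_ops(pt111_src))   # 6 x mm: forward mms plus weight- and input-gradient mms
print(aten_ops(pt20_src))    # 2 x addmm and 2 x mm: only the input-gradient mms survive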