Hello,
I ran into an `AttributeError` while working on my tensor wrapper subclass. If you are curious, you can read about the ultimate goal here.

The code below is a minimal example that reproduces the error. Basically, I use a tensor wrapper subclass that stores the actual data in the `elem` property and does all the computation on it. Furthermore, each node may have two optional attributes, `max_grad` and `min_grad`. They don't make any sense in this example, but they are something I need for my actual goal.

The tensor I wrote is only meant to track gradients and is thus only used during backprop, e.g. `loss.backward(gradient=MyTensor(...))`.
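For context, this is roughly how I intend to feed it into backprop (`model` and `x` are hypothetical placeholders here):

```python
# Hypothetical usage sketch: inject MyTensor as the incoming gradient.
loss = model(x).sum()
g = MyTensor(torch.ones_like(loss.detach()))
loss.backward(gradient=g)  # MyTensor now flows through the backward pass
```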
I have a `wrap` function and an `unwrap` function. Many functions that are called during backprop, e.g. `squeeze()` or `view()`, are ones I don't override but that still affect the underlying tensors, so I have to apply those functions to `elem` as well as to `max_grad` and `min_grad`, simply to keep the shapes consistent. The sketch after this paragraph illustrates the invariant I'm after.
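For instance (a minimal sketch of the desired invariant, assuming the `MyTensor` class defined further below):

```python
import torch
from MyTensor import MyTensor

t = MyTensor(torch.ones((1, 3)))
t.max_grad = torch.full((1, 3), 2.0)
t.min_grad = torch.full((1, 3), -2.0)

# squeeze() is dispatched without me overriding it explicitly; afterwards
# all three underlying tensors should have changed shape in lockstep:
out = t.squeeze()
# desired: out.elem.shape == out.max_grad.shape == out.min_grad.shape == (3,)
```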
So if `args` contains an element of type `MyTensor`, `unwrap()` inserts a 3-tuple `(elem, max_grad, min_grad)` in its place, e.g. `(MyTensor, A, B) => ((elem, None, None), A, B)` or `(MyTensor, A, B) => ((elem, max_grad, min_grad), A, B)`. One has to take special care of this, as you can see in the implementation below.
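To make the pytree behaviour concrete, here is a standalone toy version of that step (`unwrap_demo` is just an illustrative stand-in; the real `unwrap` lives inside `__torch_dispatch__` below):

```python
import torch
from torch.utils._pytree import tree_map

def unwrap_demo(t):
    # Stand-in for the real unwrap: replace a tensor leaf by a 3-tuple.
    # tree_map does not recurse into the returned value, so the tuple
    # simply takes the leaf's place inside args.
    if isinstance(t, torch.Tensor):
        return (t, None, None)
    return t

args = (torch.ones(2), 3, "foo")
print(tree_map(unwrap_demo, args))
# ((tensor([1., 1.]), None, None), 3, 'foo')
```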
Sometimes the general approach of `run_with_usual_semantic()` doesn't work because the return value of `func` or the shape of `args` is too complex. In such cases I also just capture the function dispatch call and handle it separately, e.g. `torch.ops.aten.max.dim`.
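For reference, this is the plain-tensor behaviour that the `aten.max.dim` branch has to mimic: the op returns a `torch.return_types.max` structseq rather than a single tensor:

```python
import torch

x = torch.arange(6.0).reshape(2, 3)  # rows: [0, 1, 2] and [3, 4, 5]
res = torch.max(x, 1)
print(type(res))    # <class 'torch.return_types.max'>
print(res.values)   # tensor([2., 5.])
print(res.indices)  # tensor([2, 2])
```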
Note: you can completely remove all logic related to `min_grad` and `max_grad` and still get the error. I chose to keep it because it's much closer to my actual project. Feel free to also give feedback on that. See the Appendix for the stripped-down version.
`MyTensor.py`:

```python
import torch
from torch.utils._pytree import tree_map
from functools import partial
from torch.return_types import max as MaxReturnType


class MyTensor(torch.Tensor):
    def __new__(cls, t, max_grad=None, min_grad=None, verbose_level=0):
        # We would need to define autograd on this ring, inception!
        assert not t.requires_grad
        res = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,  # new class: MyTensor
            size=t.size(),
            strides=t.stride(),
            storage_offset=0,
            dtype=t.dtype,
            layout=t.layout,
            device=t.device,
            requires_grad=False,
        )
        cls.verbose_level = verbose_level
        cls.log_dict = {}
        return res

    def __init__(self, t, max_grad=None, min_grad=None, verbose_level=0):
        self.elem = t
        self.max_grad = max_grad
        self.min_grad = min_grad

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if cls.verbose_level >= 1:
            print(f"Function called: {func.__name__}")
        return super().__torch_function__(func, types, args, kwargs)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t):
            if isinstance(t, cls):
                return t.elem, t.max_grad, t.min_grad
            else:
                return t

        def wrap(t, max_grad=None, min_grad=None):
            if isinstance(t, torch.Tensor):
                return cls(
                    t,
                    max_grad=max_grad,
                    min_grad=min_grad,
                    verbose_level=cls.verbose_level,
                )
            else:
                return t

        def run_with_usual_semantic():
            args_org = tree_map(unwrap, args)
            kwargs_ = tree_map(unwrap, kwargs)
            # Assumption: only the first argument is ever of type MyTensor.
            # This is just for this minimal example.
            if isinstance(args[0], cls):
                t, t_max, t_min = args_org[0]
            else:
                t = args_org[0]
                t_max, t_min = None, None
            # Run func for t
            args_org = (t, *args_org[1:])
            res = func(*args_org, **kwargs_)
            # Run func for t_max
            res_max = None
            if t_max is not None:
                args_ = (t_max, *args_org[1:])
                res_max = func(*args_, **kwargs_)
            # Run func for t_min
            res_min = None
            if t_min is not None:
                args_ = (t_min, *args_org[1:])
                res_min = func(*args_, **kwargs_)
            wrap_partial = partial(wrap, max_grad=res_max, min_grad=res_min)
            return tree_map(wrap_partial, res)

        # Override torch.max when called with a dim argument
        if func is torch.ops.aten.max.dim:
            if cls.verbose_level >= 1:
                print(f"Function dispatched: {func.__name__}")
            # Assumption: kwargs never hold a MyTensor instance
            kwargs = tree_map(unwrap, kwargs)
            args_tmp = tree_map(unwrap, args)
            if isinstance(args[0], cls):
                t, t_max, t_min = args_tmp[0]
            else:
                t = args_tmp[0]
                t_max, t_min = None, None
            args_ = (t, *args[1:])
            res = func(*args_, **kwargs)
            res_max, res_min = None, None
            if t_max is not None and t_min is not None:
                args_max = (t_max, *args[1:])
                args_min = (t_min, *args[1:])
                res_max = torch.ops.aten.max(*args_max, **kwargs)[0]
                res_min = torch.ops.aten.max(*args_min, **kwargs)[0]
            res = wrap(res, max_grad=res_max, min_grad=res_min)
            return MaxReturnType((res[0], res[1]))

        if cls.verbose_level >= 1:
            print(f"Function dispatched: {func.__name__}")
        return run_with_usual_semantic()

    def __repr__(self):
        return f"MyTensor({self.elem})"
```
(thanks to albanD for the example code)
I use it like this:
`run.py`:

```python
import torch
from MyTensor import MyTensor

A = MyTensor(torch.ones((3, 3, 3)))
A.max_grad = torch.ones((3, 3, 3)) * 2
A.min_grad = torch.ones((3, 3, 3)) * -2

print("Test 1: torch.max(A, 1).values")
print(torch.max(A, 1).values)
print("Test 2: torch.max(A, 1, keepdim=True).values")
print(torch.max(A, 1, keepdim=True).values)
```
and I get:

```
$ python run.py
Test 1: torch.max(A, 1).values
Traceback (most recent call last):
  File "<path>/fo.py", line 10, in <module>
    print(torch.max(A, 1)).values
  File "<path>/MyTensor.py", line 130, in __repr__
    return f"MyTensor({self.elem})"
AttributeError: 'MyTensor' object has no attribute 'elem'
```
To me it looks like the `MyTensor` object is created but never initialized; why else would it not have the `.elem` attribute? But that doesn't make much sense to me, and if that is the case, it must have to do with something in the underlying torch implementation.
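For what it's worth, there is at least one code path that produces exactly such a half-constructed instance: `Tensor.as_subclass` swaps the Python class of a tensor without running `__new__` or `__init__`. I don't know whether that is what happens internally here, but it reproduces the symptom:

```python
import torch
from MyTensor import MyTensor

t = torch.ones(3).as_subclass(MyTensor)
print(type(t))  # <class 'MyTensor.MyTensor'>
t.elem          # AttributeError: 'MyTensor' object has no attribute 'elem'
```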
Btw, I was able to override a lot of ops like `mm`, `bmm`, `add`, `sum`, etc. using the general tensor subclass wrapper code above (without the `max` special case, of course) and it worked.
Environment:
- M1 Mac
- Python 3.10.14
- torch 2.5.0
Thanks in advance!
Appendix:
`MyTensor` without min/max grad:

```python
import torch
from torch.utils._pytree import tree_map
from torch.return_types import max as MaxReturnType


class MyTensor(torch.Tensor):
    def __new__(cls, t, verbose_level=0):
        # We would need to define autograd on this ring, inception!
        assert not t.requires_grad
        res = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,  # new class: MyTensor
            size=t.size(),
            strides=t.stride(),
            storage_offset=0,
            dtype=t.dtype,
            layout=t.layout,
            device=t.device,
            requires_grad=False,
        )
        cls.verbose_level = verbose_level
        cls.log_dict = {}
        return res

    def __init__(self, t, verbose_level=0):
        self.elem = t

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if cls.verbose_level >= 1:
            print(f"Function called: {func.__name__}")
        return super().__torch_function__(func, types, args, kwargs)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t):
            if isinstance(t, cls):
                return t.elem
            else:
                return t

        def wrap(t):
            if isinstance(t, torch.Tensor):
                return cls(t, verbose_level=cls.verbose_level)
            else:
                return t

        def run_with_usual_semantic():
            args_org = tree_map(unwrap, args)
            kwargs_ = tree_map(unwrap, kwargs)
            res = func(*args_org, **kwargs_)
            return tree_map(wrap, res)

        # Override torch.max when called with a dim argument
        if func is torch.ops.aten.max.dim:
            if cls.verbose_level >= 1:
                print(f"Function dispatched: {func.__name__}")
            # Assumption: kwargs never hold a MyTensor instance
            kwargs = tree_map(unwrap, kwargs)
            args_ = tree_map(unwrap, args)
            res = func(*args_, **kwargs)
            res = wrap(res)
            return MaxReturnType((res[0], res[1]))

        if cls.verbose_level >= 1:
            print(f"Function dispatched: {func.__name__}")
        return run_with_usual_semantic()

    def __repr__(self):
        return f"MyTensor({self.elem})"
```