[Solved] Overwriting max() leads to AttributeError

Hello,

I ran into an AttributeError while working on my tensor wrapper subclass. If you are curious, you can read here about the ultimate goal.

The class below is a minimal example that reproduces the error. Basically, I use a tensor wrapper subclass that stores the actual data in the elem property and does all the computation on it. Furthermore, each node may have two optional attributes, max_grad and min_grad. They don't make any sense in this example, but they are something I need for my actual goal.

The tensor I am writing is only meant to track gradients and is therefore only used during backprop, e.g. loss.backward(gradient=MyTensor(...)).
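
For reference, this is what the gradient argument to backward() does with plain tensors; in my case the seed passed in would be a MyTensor instead (purely illustrative):

import torch

x = torch.ones(3, 3, requires_grad=True)
y = x * 2                                  # some non-scalar output
y.backward(gradient=torch.ones_like(y))    # seed gradient fed into backprop
print(x.grad)                              # every entry is 2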

I have a wrap function and an unwrap function. Since many functions that are called during backprop, e.g. squeeze() or view(), are not ones I override but still affect the underlying tensors, I have to apply them to elem, max_grad and min_grad alike, simply to keep the shapes consistent.
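
To make that concrete with plain tensors (purely illustrative, not the actual implementation):

import torch

elem     = torch.ones(3, 1, 3)
max_grad = torch.full((3, 1, 3), 2.0)
min_grad = torch.full((3, 1, 3), -2.0)

# the same op has to hit all three so their shapes stay in sync
elem, max_grad, min_grad = (t.squeeze(1) for t in (elem, max_grad, min_grad))
print(elem.shape, max_grad.shape, min_grad.shape)   # torch.Size([3, 3]) each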

So if args contains an element of type MyTensor, unwrap() inserts a 3-tuple in its place, e.g. (MyTensor, A, B) => ((elem, None, None), A, B) or (MyTensor, A, B) => ((elem, max_grad, min_grad), A, B).
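
Here is a standalone sketch of that behaviour; a plain tensor stands in for the MyTensor instance, so this runs without the class:

import torch
from torch.utils._pytree import tree_map

def unwrap(t):
    if isinstance(t, torch.Tensor):    # stands in for isinstance(t, MyTensor)
        return (t, None, None)         # (elem, max_grad, min_grad)
    return t

args = (torch.ones(3, 3), 1, True)     # e.g. (tensor, dim, keepdim)
print(tree_map(unwrap, args))
# ((tensor([...]), None, None), 1, True)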

One has to take special care of this, as you can see in the implementation below.

Sometimes the general approach of run_with_usual_semantic() doesn't work because the return value of func or the structure of args is too complex. In such a case I just intercept the dispatch call and handle it separately, e.g. torch.ops.aten.max.dim.
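
For example, aten.max.dim returns two tensors (values and indices) rather than a single one, which is why the generic wrapping doesn't fit. A quick check:

import torch

x = torch.arange(6.0).reshape(2, 3)
values, indices = torch.ops.aten.max.dim(x, 1)
print(values)    # tensor([2., 5.])
print(indices)   # tensor([2, 2])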

Note: You can completely remove all logic related to min_grad and max_grad and still get the error. I chose to keep it because it’s way closer to my actual project. Feel free to also give feedback on that. See Appendix.

MyTensor.py:

import torch
from torch.utils._pytree import tree_map
from functools import partial
from torch.return_types import max as MaxReturnType


class MyTensor(torch.Tensor):
    def __new__(cls, t, max_grad=None, min_grad=None, verbose_level=0):
        # We would need to define autograd on this ring, inception!
        assert t.requires_grad == False

        res = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,  # New class: MyTensor
            size=t.size(),
            strides=t.stride(),
            storage_offset=0,
            dtype=t.dtype,
            layout=t.layout,
            device=t.device,
            requires_grad=False,
        )

        cls.verbose_level = verbose_level
        cls.log_dict = {}

        return res

    def __init__(self, t, max_grad=None, min_grad=None, verbose_level=0):
        self.elem = t
        self.max_grad = max_grad
        self.min_grad = min_grad

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if cls.verbose_level >= 1:
            print(f"Function called: {func.__name__}")

        return super().__torch_function__(func, types, args, kwargs)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t):
            if isinstance(t, cls):
                return t.elem, t.max_grad, t.min_grad
            else:
                return t

        def wrap(t, max_grad=None, min_grad=None):
            if isinstance(t, torch.Tensor):
                wrapped = cls(
                    t,
                    max_grad=max_grad,
                    min_grad=min_grad,
                    verbose_level=cls.verbose_level,
                )
                return wrapped
            else:
                return t

        def run_with_usual_semantic():
            args_org = tree_map(unwrap, args)
            kwargs_ = tree_map(unwrap, kwargs)

            # Assumption: Only the first argument is ever of type MyTensor.
            # This is just for this minimal example.
            if isinstance(args[0], cls) and not isinstance(args[0], torch.Tensor):
                t, t_max, t_min = args_org[0]
            else:
                t = args_org[0]
                t_max, t_min = None, None

            # Run func for t
            args_org = (t, *args_org[1:])
            res = func(*args_org, **kwargs_)

            # Run func for t_max
            res_max = None
            if t_max is not None:
                args_ = (t_max, *args_org[1:])
                res_max = func(*args_, **kwargs_)

            # Run func for t_min
            res_min = None
            if t_min is not None:
                args_ = (t_min, *args_org[1:])
                res_min = func(*args_, **kwargs_)

            wrap_partial = partial(wrap, max_grad=res_max, min_grad=res_min)

            return tree_map(wrap_partial, res)

        # Overwrite torch.max with dim argument
        if func is torch.ops.aten.max.dim:
            if cls.verbose_level >= 1:
                print(f"Function dispat: {func.__name__}")

            # Assumption: kwargs never hold a MyTensor instance
            kwargs = tree_map(unwrap, kwargs)

            if isinstance(args[0], cls):
                args_tmp = tree_map(unwrap, args)
                t, t_max, t_min = args_tmp[0]
            else:
                args_tmp = tree_map(unwrap, args)
                t = args_tmp[0]
                t_max, t_min = None, None

            args_ = (t, *args[1:])

            res = func(*args_, **kwargs)

            res_max, res_min = None, None
            if t_max is not None and t_min is not None:
                args_max = (t_max, *args[1:])
                args_min = (t_min, *args[1:])

                res_max = torch.ops.aten.max(*args_max, **kwargs)[0]
                res_min = torch.ops.aten.max(*args_min, **kwargs)[0]

            res = wrap(res, max_grad=res_max, min_grad=res_min)

            return MaxReturnType((res[0], res[1]))

        if cls.verbose_level >= 1:
            print(f"Function dispat: {func.__name__}")

        return run_with_usual_semantic()

    def __repr__(self):
        return f"MyTensor({self.elem})"

(thanks to albanD for the example code)

I use it like this:

run.py:

import torch
from MyTensor import MyTensor

A = MyTensor(torch.ones((3, 3, 3)))
A.max_grad = torch.ones((3, 3, 3)) * 2
A.min_grad = torch.ones((3, 3, 3)) * -2


print("Test 1: torch.max(A, 1).values")
print(torch.max(A, 1).values)


print("Test 2: torch.max(A, 1, keepdim=True).values")
print(torch.max(A, 1, keepdim=True).values)

and I get

$ python run.py
Test 1: torch.max(A, 1).values
Traceback (most recent call last):
  File "<path>/fo.py", line 10, in <module>
    print(torch.max(A, 1)).values
  File "<path>/MyTensor.py", line 130, in __repr__
    return f"MyTensor({self.elem})"
AttributeError: 'MyTensor' object has no attribute 'elem'

To me it looks like the MyTensor object is created but never initialized; why else would it not have the .elem attribute? But that doesn't make much sense to me, and if it is the case, it must have something to do with the underlying torch implementation.
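
A quick way to see what a created-but-not-initialized instance would look like (just to illustrate my suspicion, not a claim about where PyTorch does this):

import torch
from MyTensor import MyTensor

obj = MyTensor.__new__(MyTensor, torch.ones(2, 2))   # __new__ runs, __init__ does not
print(hasattr(obj, "elem"))                          # False
# repr(obj) would now raise the exact AttributeError from the traceback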

Btw, I was able to override a lot of functions like mm, bmm, add, sum etc. using the general tensor subclass wrapper code above (without the max special case, of course) and it worked.

Environment:

  • I’m on an M1 Mac
  • Python 3.10.14
  • torch 2.5.0

Thanks in advance! :slight_smile:

Appendix:

MyTensor without min/max grad:

import torch
from torch.utils._pytree import tree_map
from functools import partial
from torch.return_types import max as MaxReturnType


class MyTensor(torch.Tensor):
    def __new__(cls, t, max_grad=None, min_grad=None, verbose_level=0):
        # We would need to define autograd on this ring, inception!
        assert t.requires_grad == False

        res = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,  # New class: MyTensor
            size=t.size(),
            strides=t.stride(),
            storage_offset=0,
            dtype=t.dtype,
            layout=t.layout,
            device=t.device,
            requires_grad=False,
        )

        cls.verbose_level = verbose_level
        cls.log_dict = {}

        return res

    def __init__(self, t, verbose_level=0):
        self.elem = t

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if cls.verbose_level >= 1:
            print(f"Function called: {func.__name__}")

        return super().__torch_function__(func, types, args, kwargs)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t):
            if isinstance(t, cls):
                return t.elem
            else:
                return t

        def wrap(t):
            if isinstance(t, torch.Tensor):
                wrapped = cls(
                    t,
                    verbose_level=cls.verbose_level,
                )
                return wrapped
            else:
                return t

        def run_with_usual_semantic():
            args_org = tree_map(unwrap, args)
            kwargs_ = tree_map(unwrap, kwargs)

            res = func(*args_org, **kwargs_)

            return tree_map(wrap, res)

        # Overwrite torch.max with dim argument
        if func is torch.ops.aten.max.dim:
            if cls.verbose_level >= 1:
                print(f"Function dispat: {func.__name__}")

            # Assumption: kwargs never hold a MyTensor instance
            kwargs = tree_map(unwrap, kwargs)
            args_ = tree_map(unwrap, args)

            res = func(*args_, **kwargs)

            res = wrap(res)

            return MaxReturnType((res[0], res[1]))

        if cls.verbose_level >= 1:
            print(f"Function dispat: {func.__name__}")

        return run_with_usual_semantic()

    def __repr__(self):
        return f"MyTensor({self.elem})"

Solved. This was older code that I wrote some time ago and recently picked up again. Back then I was experimenting with the whole topic of subclassing and somehow ended up using __torch_function__ together with __torch_dispatch__.

Solution: Remove the __torch_function__ override, that's it.
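
For completeness, a minimal sketch of what the fixed class looks like; everything else stays as in the code above. Explicitly disabling __torch_function__ is just one common way to spell "handle everything in __torch_dispatch__", simply deleting the override (which is what I did) works as well:

import torch

class MyTensor(torch.Tensor):
    # either delete the __torch_function__ override entirely, or disable it
    # explicitly; both leave all the work to __torch_dispatch__
    __torch_function__ = torch._C._disabled_torch_function_impl

    # __new__, __init__, __torch_dispatch__ and __repr__ stay exactly as above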