I am trying to use a custom op implementation with both compiled autograd and autocast enabled.
The op's forward and backward are decorated with the torch.cuda.amp decorators custom_fwd and custom_bwd.
I get the following error:
torch._dynamo.exc.InternalTorchDynamoError: dtype must be a torch.dtype (got type)

from user code:
  File "<eval_with_key>.8", line 16, in forward
    call_backward = torch__dynamo_external_utils_call_backward(getitem_5, (getitem_3,), mm);  getitem_5 = getitem_3 = mm = None
  File "/home/tdomagala/.venvs/foo/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 77, in call_backward
    grads = backward_fn(FakeContext(saved_tensors), *args)
  File "/home/tdomagala/.venvs/foo/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 141, in decorate_bwd
    with autocast(enabled=args[0]._fwd_used_autocast, dtype=args[0]._dtype):
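From the traceback it looks like compiled autograd invokes my backward through torch._dynamo.external_utils.call_backward, which hands the user backward a FakeContext built only from the saved tensors, so the _fwd_used_autocast/_dtype attributes that custom_fwd stores on the real ctx during forward are not available when decorate_bwd tries to read them. A rough paraphrase of the two frames above (the exact signature is my guess from the graph line, not the actual library source):

def call_backward(backward_fn, saved_tensors, *args):
    # the ctx is rebuilt from saved_tensors only; attributes set on the real ctx in
    # forward (e.g. _fwd_used_autocast and _dtype from custom_fwd) are not carried over,
    # so decorate_bwd's args[0]._dtype lookup does not return a torch.dtype
    grads = backward_fn(FakeContext(saved_tensors), *args)
    return grads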
Example code:
import torch
from torch._dynamo.utils import maybe_enable_compiled_autograd
from torch.cuda.amp import custom_bwd, custom_fwd
class CustomOp(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, i):
        result = i @ i
        print(f'{i.dtype=}, {result.dtype=}')
        ctx.label_smoothing = 0.5
        ctx.save_for_backward(result)
        return result

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        label_smoothing = ctx.label_smoothing
        if label_smoothing > 0:
            result *= 2.0
        else:
            result *= 3.0
        return grad_output * result


def custom_function(x):
    x = CustomOp.apply(x)
    return x


def fn(x):
    x = x * x
    y = custom_function(x)
    res = x @ y
    return res


x = torch.randn((1000, 1000), dtype=torch.float32, device="cpu").requires_grad_(True)

with torch.autocast(dtype=torch.bfloat16, device_type="cpu", enabled=True):
    with maybe_enable_compiled_autograd(True):
        fn = torch.compile(fn)
        r = fn(x)
        print("r.dtype", r.dtype)
        print(r)
        l = torch.sum(r)
        l.backward()
When compiled autograd is enabled, the backward seems to run against a FakeContext that is missing the dynamic attributes set on the real ctx in forward (the _fwd_used_autocast/_dtype fields written by custom_fwd, and presumably my own ctx.label_smoothing as well). I do not observe this problem without compiled autograd.
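In case it is useful, a minimal sketch of a workaround that sidesteps the failing lookup, assuming that only saved_tensors reach the compiled backward (which is what the traceback suggests): drop the amp decorators, handle any casting manually, and route scalars through save_for_backward instead of ad-hoc ctx attributes. CustomOpNoAmpCtx is a hypothetical name, and this loses custom_fwd/custom_bwd's automatic autocast handling, so it is not a real fix for the underlying problem.

class CustomOpNoAmpCtx(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i @ i
        # stash the scalar as a tensor so it travels through saved_tensors,
        # the only state the compiled backward appears to see
        ctx.save_for_backward(result, torch.tensor(0.5))
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, label_smoothing = ctx.saved_tensors
        scale = 2.0 if label_smoothing.item() > 0 else 3.0  # .item() may cause a graph break under compile
        return grad_output * (result * scale)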