Compiled autograd with custom ops error

I'm trying to use a custom op implementation with compiled autograd and autocast both enabled.

My op's forward and backward are decorated with the torch.cuda.amp decorators custom_fwd and custom_bwd. Running the example below produces the following error:

torch._dynamo.exc.InternalTorchDynamoError: dtype must be a torch.dtype (got type)

from user code:

File "<eval_with_key.8", line 16, in forward
call_backward = torch__dynamo_external_utils_call_backward(getitem_5, (getitem_3,), mm);  getitem_5 = getitem_3 = mm = None
File "/home/tdomagala/.venvs/foo/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 77, in call_backward grads = backward_fn(FakeContext(saved_tensors), *args)
File "/home/tdomagala/.venvs/foo/lib/python3.10/site/packages/torch/cuda/amp/autocast_mode.py", line 141, in decorate_bwd
  with autocast(enabled=args[0]._fwd_used_autocast, dtype=args[0]._dtype):
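
The failing line reads args[0]._dtype, i.e. ctx._dtype. For reference, here is roughly what custom_fwd stores on the real ctx during forward (a paraphrase of torch/cuda/amp/autocast_mode.py, not the exact source):

def custom_fwd(fwd):
	def decorate_fwd(*args, **kwargs):
		# args[0] is ctx; the decorator stashes the autocast state on it
		args[0]._dtype = torch.get_autocast_gpu_dtype()
		args[0]._fwd_used_autocast = torch.is_autocast_enabled()
		return fwd(*args, **kwargs)
	return decorate_fwd

custom_bwd then reads these two attributes back in decorate_bwd, which is exactly where the error is raised.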

Example code:

import torch
from torch._dynamo.utils import maybe_enable_compiled_autograd
from torch.cuda.amp import custom_bwd, custom_fwd


class CustomOp(torch.autograd.Function):
	@staticmethod
	@custom_fwd
	def forward(ctx, i):
		result = i @ i
		print(f'{i.dtype=}, {result.dtype=}')
		ctx.label_smoothing = 0.5
		ctx.save_for_backward(result)
		return result

	@staticmethod
	@custom_bwd
	def backward(ctx, grad_output):
		(result,) = ctx.saved_tensors
		label_smoothing = ctx.label_smoothing
		if label_smoothing > 0:
			result *= 2.0
		else:
			result *= 3.0
		return grad_output * result


def custom_function(x):
	x = CustomOp.apply(x)
	return x



def fn(x):
	x = x * x
	y = custom_function(x)
	res = x @ y

	return res


x = torch.randn((1000, 1000), dtype=torch.float32, device="cpu").requires_grad_(True)

with torch.autocast(dtype=torch.bfloat16, device_type="cpu", enabled=True):
	with maybe_enable_compiled_autograd(True):
		fn = torch.compile(fn)
		r = fn(x)
		print("r.dtype", r.dtype)

		print(r)
		l = torch.sum(r)
		l.backward()

With compiled autograd enabled, the problem seems to be with dynamic attributes on the FakeContext object: attributes set on ctx in forward (label_smoothing here, plus the _fwd_used_autocast and _dtype attributes set by custom_fwd) are not available in backward. I don't observe this problem without compiled autograd.
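
A minimal sketch of why, based on the call_backward frame in the traceback (paraphrased, not the exact torch source): compiled autograd invokes backward with a FakeContext built from saved_tensors alone, so every other ctx attribute is simply missing.

class FakeContext:
	def __init__(self, saved_tensors):
		self.saved_tensors = saved_tensors

def call_backward(backward_fn, saved_tensors, *args):
	# Only saved_tensors survive; ctx.label_smoothing, ctx._dtype and
	# ctx._fwd_used_autocast from forward() are gone at this point.
	return backward_fn(FakeContext(saved_tensors), *args)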

Issue submitted: [dynamo] torch._dynamo.exc.InternalTorchDynamoError: dtype must be a torch.dtype (got type) (pytorch/pytorch#131154)
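
As an untested workaround sketch (mine, not from the issue): route all state that backward needs through save_for_backward, the only ctx state the fake context preserves, and drop the AMP decorators so decorate_bwd never touches ctx._dtype. Note this changes the op's autocast handling, and the names below are hypothetical:

class CustomOpWorkaround(torch.autograd.Function):
	@staticmethod
	def forward(ctx, i):
		result = i @ i
		# Carry the scalar as a tensor via saved_tensors instead of as a
		# dynamic ctx attribute.
		ctx.save_for_backward(result, torch.tensor(0.5))
		return result

	@staticmethod
	def backward(ctx, grad_output):
		result, label_smoothing = ctx.saved_tensors
		# .item() may force a graph break under torch.compile, but it avoids
		# branching on a dynamic ctx attribute. This also avoids the
		# in-place result *= 2.0 on a saved tensor.
		factor = 2.0 if label_smoothing.item() > 0 else 3.0
		return grad_output * (result * factor)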