Recently, I have been using PyTorch DTensor to train a module. However, when it comes to a custom op (a CUDA kernel registered into torch through TORCH_LIBRARY, like here), DTensor fails to execute it and throws an error:
RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.
When I dug deeper, it seems that the custom op does not have a Python DispatchKey, making it impossible to redispatch back to Python and enter the __torch_dispatch__ of the DTensor class. All tensor wrapper subclasses have this problem.
I’m sending this post to ask whether my understanding is correct, and if so, whether there is any way to register a Python DispatchKey for a custom op. Thanks a lot for any reply.
Are you sure that you’re not overriding the Python key for your custom op by any chance?
btw you can use torch._C._dispatch_dump("namespace::op_name") to know all the registrations that happened for a given op (and where they’re from) and torch._C._dispatch_dump_table("namespace::op_name") to see the computed dispatch table for that op (what will get called for each key). That should help you figure out the details of what is happening here.
Thanks a lot for your reply. It solved my question. I registered the custom op with the following code:
TORCH_LIBRARY(myops, m) {
  m.def("myadd", &myadd_cpu);
}
When I changed to using TORCH_LIBRARY_IMPL(myops, CPU, m), everything works fine.
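For reference, here is roughly what the working registration looks like on my side (the schema string and the myadd_cpu body below are simplified placeholders, not my real kernel):

#include <torch/library.h>
#include <ATen/ATen.h>

// Placeholder CPU kernel; the real op calls into a CUDA kernel.
at::Tensor myadd_cpu(const at::Tensor& a, const at::Tensor& b) {
  return a + b;
}

// Only declare the schema here; no kernel is attached to the definition.
TORCH_LIBRARY(myops, m) {
  m.def("myadd(Tensor a, Tensor b) -> Tensor");
}

// Attach the kernel to the CPU dispatch key only, leaving the Python backend
// fallback in place for tensor subclasses such as DTensor.
TORCH_LIBRARY_IMPL(myops, CPU, m) {
  m.impl("myadd", myadd_cpu);
}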
But there is another puzzle for me. I dumped the dispatch table for the custom op registered in the former way (using torch._C._dispatch_dump_table("namespace::op_name")), and it prints out:
Python: registered at /root/share/upstream/pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:154 [backend fallback]
It seems there is a default fallback implementation for the Python key. Why does the former way of registering the custom op fail to enter the __torch_dispatch__ of DTensor, while the latter succeeds?
I think that this is because the default key to which things get registered when not specified in TORCH_LIBRARY is CompositeImplicitAutograd (or a similar key), which happens before the Python key in the dispatcher. So your custom kernel gets called before the __torch_dispatch__ handler has a chance to trigger.
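To make that concrete: with your first form, the single m.def call both declares the schema and registers a catch-all kernel (the comment reflects my understanding of where that kernel lands; I haven't double-checked the exact key):

// The schema is inferred from the C++ signature and the kernel is registered
// as a catch-all, which sits at a key the dispatcher consults before the
// Python key, so DTensor's __torch_dispatch__ never gets a chance to run.
TORCH_LIBRARY(myops, m) {
  m.def("myadd", &myadd_cpu);
}

With the TORCH_LIBRARY_IMPL(myops, CPU, m) form, the kernel is only attached to the CPU key, so the Python backend fallback you saw in the dump stays reachable and wrapper subclasses can intercept the call first.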