Recently, I’ve been working on resolving some numerical precision alignment issues. Take the division operator as an example: the computation yields different results on CPU and CUDA, or when the same operation is expressed using different syntax, as seen in the attached screenshot.
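Since the screenshot isn’t included, here is a minimal sketch of the kind of rounding mismatch involved. It uses numpy as a stand-in rather than PyTorch itself (an assumption for illustration): two implementations of “the same” division that round at different precisions need not agree bit-for-bit.

```python
# Minimal illustration (numpy as a stand-in, NOT PyTorch's kernels) of why
# "the same" division can yield different results: each precision or
# implementation rounds differently, so results need not match exactly.
import numpy as np

a32 = np.float32(1.0) / np.float32(3.0)   # division carried out in float32
a64 = np.float64(1.0) / np.float64(3.0)   # division carried out in float64

# Widened to double, the float32 quotient is not the float64 quotient:
print(float(a32) == a64)        # False: the two roundings disagree
print(abs(float(a32) - a64))    # the rounding gap, on the order of 1e-8
```

CPU and CUDA kernels can disagree in the same way even at equal precision, because they may order or fuse the underlying floating-point operations differently.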
I’ve been trying to uncover the underlying reasons in various ways, and the first thing that comes to mind is to review the C++ or CUDA source code. However, I’ve found it hard to follow the intricacies of PyTorch’s C++ code and to locate the corresponding source. Can anyone help me understand how to learn PyTorch’s C++ source code, and in particular how to find the implementation details of C++ operators?
In most cases, I can find code like this:
but I failed to find the source code of
@shuokay The main reason it’s a pain to find the kernel you’re looking for from there is that
op.call(self, other) goes through the PyTorch dispatcher, which dynamically dispatches to the right kernel (more on the dispatcher here: Let’s talk about the PyTorch dispatcher : ezyang’s blog)
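The core idea can be sketched as a per-operator table from backend key to kernel. This is a toy Python illustration only, not PyTorch’s actual C++ machinery (which adds dispatch-key sets, boxing, fallbacks, and much more); the names here are invented for the sketch.

```python
# Toy sketch of dispatch-by-key. The real PyTorch dispatcher is C++ and far
# more involved, but the essence is: one table per operator, keyed by the
# backend derived from the inputs, mapping to the concrete kernel.
def div_cpu(a, b):
    return a / b          # stand-in for the CPU kernel

def div_cuda(a, b):
    return a / b          # stand-in for the CUDA kernel

kernel_table = {          # invented name; conceptually one table per operator
    "CPU": div_cpu,
    "CUDA": div_cuda,
}

def call(op_table, device, *args):
    # "op.call(...)" resolves the kernel from the inputs' dispatch keys;
    # here we fake that resolution with an explicit device string.
    return op_table[device](*args)

print(call(kernel_table, "CPU", 6.0, 3.0))   # 2.0
```

This is why grepping for a Python-level name often dead-ends at the dispatcher boundary: the actual kernel is looked up at runtime, not referenced statically.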
The source of truth (well, 99% of the time) for every ATen operator, and for the names of its CPU/CUDA kernels in the codebase, is native_functions.yaml: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
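To make that concrete, here is roughly what the div.Tensor entry looks like and how you might scan for it. The excerpt below is representative and abbreviated, possibly out of date; check native_functions.yaml in your own checkout for the authoritative entry.

```python
# Representative, abbreviated excerpt of a native_functions.yaml entry
# (NOT the authoritative version; consult your pytorch checkout).
excerpt = """\
- func: div.Tensor(Tensor self, Tensor other) -> Tensor
  structured_delegate: div.out
  variants: function, method
"""

# In a real checkout you'd grep aten/src/ATen/native/native_functions.yaml;
# here we scan the excerpt the same way:
for lineno, line in enumerate(excerpt.splitlines(), 1):
    if "div.Tensor(" in line or "structured_delegate" in line:
        print(lineno, line.strip())
```

The keys on each entry (dispatch, structured_delegate, and so on) tell you which kernel names to grep for next.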
div.Tensor: I happen to know that it’s implemented as a structured kernel (more on those here: Codegen and Structured Kernels · pytorch/pytorch Wiki · GitHub).
I can tell because of the structured_delegate key in native_functions.yaml here: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml#L2058C3-L2058C22
As a structured kernel, the op aten::div.Tensor has a “meta” function that performs shape error checks and computes the output shape, defined here: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/BinaryOps.cpp#L174
And it has an implementation that uses TensorIterator, defined here: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/BinaryOps.cpp#L448 (from there you should be able to grep for each function name to find it).
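The meta/impl split described above can be sketched in Python. This is a toy with invented names, purely to show the division of labor; the real code is C++ in BinaryOps.cpp plus the TensorIterator machinery.

```python
# Toy sketch of a structured kernel's two halves: a "meta" function that
# validates inputs and computes the output shape (no arithmetic), and an
# "impl" that runs the elementwise loop (the role TensorIterator plays
# in C++). All names here are invented for illustration.
def div_meta(a, b):
    # shape checks + output-shape computation only
    if len(a) != len(b):
        raise ValueError("shape mismatch")
    return len(a)

def div_impl(a, b, out):
    # elementwise loop over the already-allocated output
    for i in range(len(out)):
        out[i] = a[i] / b[i]

def div(a, b):
    out = [0.0] * div_meta(a, b)   # meta decides the allocation
    div_impl(a, b, out)            # impl fills it in
    return out

print(div([1.0, 4.0], [2.0, 2.0]))   # [0.5, 2.0]
```

Separating the two means the same impl can serve the functional, in-place, and out= variants, which is one motivation for structured kernels.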
@bdhirsh Thank you very much for your reply, it was really helpful. Based on your explanation, I have “debugged” the PyTorch code step by step, and I think I have a deeper understanding of the implementation and dispatch of PyTorch operators now.