A small MPS debugging story

My PR "Don't detach when making views; force caller to detach" (pytorch/pytorch#84893) got reverted because it broke MPS, and only MPS (see the PyTorch CI HUD). The failures were a number of tests producing numerically wrong results. When I originally got the bug report, I had no idea how this change could have caused the problem. Here is how I investigated it:

  1. I wanted to reproduce the problem locally, so I had to build PyTorch with MPS. Fortunately, I have an M1 MacBook, but my preexisting build of PyTorch on my laptop did not actually have MPS support, and there are no instructions on how to build with MPS. I Googled “pytorch mps build instructions” and found a Medium article, “Installing and running pytorch on M1 GPUs (Apple metal/MPS)” by Chris Dare, with the command MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ USE_MPS=1 USE_PYTORCH_METAL=1 python setup.py install, but when I ran it, MPS was not actually enabled. However, this clued me in about USE_MPS, so I read CMakeLists.txt to see how cmake detected MPS support. This showed that xcrun --sdk macosx --show-sdk-path was used to check the version of my SDK, which had to be sufficiently new. Running this command revealed my SDK was not new enough. It wasn’t entirely clear to me how to update the SDK; I Googled around and eventually it became clear that the SDK was tied to my Xcode version. I was tempted to install a new Xcode, but then I remembered that you can use xcode-select to change your Xcode version. I also remembered that FB installs a few different versions of Xcode on your machine, so I picked the latest version I had installed (Xcode_13.3.1_fb.app), xcode-select’ed it, and reran the xcrun command to confirm the version was recent enough. cmake then successfully configured MPS, so I let the build start on my laptop.

  2. While I was waiting for the build to finish, I started reading the MPS code to try to understand what was going on. I had preexisting knowledge that MPS has to do some work to support views. The relevant code is in View.mm in pytorch/pytorch on GitHub; MPS kernels don’t natively support views, so instead they lazily gather the data implied by the view right before they actually run any kernel on it. My PR changed how view tensors get constructed (previously, if a view tensor had other references, we would shallow copy the tensor and put the view metadata on the copy), but the internal asserts that I expected to fire in the event of there being references never triggered. So I thought there was still some funny business going on here, but it was definitely very strange that my assert hadn’t fired. The standard failure mode that I had previously fixed in the PR was view operations returning the original tensor rather than a new tensor, so I looked for any MPS sites that might have had this problem. But there is only one view operation explicitly bound by MPS, as_strided, and it did the right thing. At this point I was at an impasse.

  3. My build finished. I ran the failing tests from CI and confirmed they were still failing. I printed out the mismatching tensors and noted the first element was correct, but the other elements were zero (suggesting no gather actually took place). I wanted to print out the MPSGraph and see whether it was correct or not, but Googling “print mpsgraph” gave no useful results, so I asked about it and decided to go to sleep.

  4. In the morning, I discovered that Nikita Shulga knew about an SPI, [graph dump], that could be used to print the contents of an MPSGraph*. I started playing around with printing out the contents of the MPSGraph (sometimes my printed graphs were empty, because I was printing them before the logic had actually added the relevant operations to the graph), as well as minimizing the test case. The simplest test case involved a [:,:1] indexing operation; I used TORCH_SHOW_DISPATCH_TRACE=1 to find out what underlying ATen op was triggered by this case (slice). I observed that not all view operations caused problems, but torch.ops.aten.slice(x, 1, 0, 1) was enough to cause one. Printing out the MPS tensor would make the problem go away, but transferring it to CPU first and then printing it was fine. The printed MPSGraphs looked fine, with nothing obviously wrong in them. At this point, it became clear that the problem could not be related to directly returning the input tensor from view ops, as slice always has to return a fresh tensor.

  5. I decided that I wanted to compare the MPSGraph from the working and non-working cases. To do this, I would need to back out the change from my diff, so I opened up the diff to look for what part I would need to change. But while reading over the diff again, I noticed that inside slice I had changed a self.as_strided(…) call to call as_strided_tensorimpl(self, …) directly, bypassing dispatch (I had totally forgotten about this). This looked like a smoking gun. Restoring it as a dispatched op fixed the problem. Problem solved!
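
To make the “lazy gather” from step 2 concrete: materializing a view means walking its index space and pulling elements out of the base storage using the view’s sizes, strides, and storage offset. Here is a minimal pure-Python sketch of that idea (my own illustration, not the actual View.mm code):

```python
from itertools import product

def gather_view(storage, sizes, strides, offset=0):
    """Materialize a strided view into a flat, contiguous list.

    This mimics what a backend without native view support must do
    before running a kernel: gather the elements the view refers to.
    """
    out = []
    for idx in product(*(range(s) for s in sizes)):
        flat = offset + sum(i * st for i, st in zip(idx, strides))
        out.append(storage[flat])
    return out

# A 2x3 row-major tensor stored as a flat buffer:
storage = [0, 1, 2, 3, 4, 5]
# The view produced by x[:, :1] has sizes (2, 1) and strides (3, 1):
print(gather_view(storage, sizes=(2, 1), strides=(3, 1)))  # -> [0, 3]
```

If this gather never runs, the kernel operates over whatever happens to be in the destination buffer, which is consistent with the mostly-zero outputs observed in step 3.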

Root cause, in the end: calling as_strided_tensorimpl directly doesn’t cause MPS to crash, but on MPS it skips the initialization of important state (remember, as_strided is overridden by MPS), and that skipped initialization caused the problem. This means that, in general, view operations must dispatch to as_strided. I plan to add a lint for this.
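
The shape of this bug is easy to reproduce in miniature. In the following hypothetical Python analogy (the names are made up, not real PyTorch internals), a backend overrides an op to set up extra state; calling the base implementation directly skips the override, so the state never gets initialized:

```python
class BaseTensor:
    def as_strided_impl(self, sizes):
        # The "tensorimpl" implementation: records view metadata only.
        return {"sizes": sizes, "backend_state": None}

    def as_strided(self, sizes):
        # The dispatched entry point; backends may override this.
        return self.as_strided_impl(sizes)


class MPSLikeTensor(BaseTensor):
    def as_strided(self, sizes):
        view = self.as_strided_impl(sizes)
        # The override also initializes state its kernels later rely on.
        view["backend_state"] = "gather-before-kernel info"
        return view


t = MPSLikeTensor()
good = t.as_strided((2, 1))                   # goes through "dispatch"
bad = BaseTensor.as_strided_impl(t, (2, 1))   # bypasses the override

print(good["backend_state"])  # gather-before-kernel info
print(bad["backend_state"])   # None -- the bug in miniature
```

Nothing crashes in the bypassed case; the result is just silently missing the backend’s bookkeeping, which is exactly why the failure showed up as wrong numbers rather than an error.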

The debugging process took a long time because I had an incorrect conception of what the bug was at the start. I had to get enough evidence to change this incorrect prior. There are a number of ways I could have found the bug more quickly:

  • I could have reread my diff at the beginning, and noticed that slice failed in CI, and I had also modified slice in my diff. The slice change was remarked upon in code review; I could have separated it out into a separate diff.

  • I could have reverted only the substantive change and noticed that this did not actually fix the problem. In the morning, I was considering hacking up the change to do the old behavior if the tensor was MPS, just to get on with my life; this would also have clued me in that the change I suspected wasn’t the culprit.

On the plus side, I now know how to build MPS locally on my Mac (TBH, this was the most time-consuming step!), have a much better understanding of how MPS views work, and I know how to debug MPSGraphs! So it is hard to call the time spent poking around wasted.
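
For anyone repeating step 1: once the build finishes, a quick way to sanity-check that MPS actually got compiled in is the torch.backends.mps flags (standard APIs in recent PyTorch):

```python
import torch

# is_built(): was this PyTorch binary compiled with MPS support?
# is_available(): additionally, is there a usable Metal device right now?
print("MPS built:    ", torch.backends.mps.is_built())
print("MPS available:", torch.backends.mps.is_available())

if torch.backends.mps.is_available():
    # A tiny smoke test on the MPS device.
    x = torch.ones(3, device="mps")
    print((x * 2).cpu())
```

If is_built() is False, the cmake configure step never picked up MPS (check the SDK version via xcrun as described above).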


Hi, do you mind sharing your build process for PyTorch 2.0 on a MacBook?

It is step one of the post!

Sorry for missing this info. Thanks!