Debugging story: The case of the garbage text generation

To begin this story, let me give some context: part of my job at IBM Research is investigating whether IBM's internal models play nicely with torch.compile(). One of our models is an encoder-decoder transformer, and I was testing whether auto-regressive generation (aka generating text token by token) works correctly under compilation.

Given this is a debugging story, you already know the beginning: our internal model, let's call it S from now on, gave identical outputs under eager mode and static-shape compilation, but spat out garbage when turning on dynamic_shapes=True. Since the example I was running was summarization, here's a sample input run on both eager and compiled with dynamic shapes:


input: "summarize: A recent study has found a substantial presence in the ultrafine particulate matter PM1, a dangerous pollutant, in Central Delhi.The PM 1 particle is 70 times finer than the thickness of a human hair, goes directly into the bloodstream and is potentially more dangerous than the more well known PM2.5 and PM10 pollutants. And no one knows how much of it is in the air we breathe.The Lodhi Road monitoring station of the System of Air Quality and Weather Forecasting And Research (SAFAR) under the Ministry of Earth Sciences has recorded the average volume of PM1 during summer, winter and the monsoon at about 46, 49 and 20 micrograms per cubic metres, respectively. The data from the year 2016 is collected in the only station that has the technology, says Gufran Beig, project director of SAFAR.The safe standards of PM1 have not been defined yet, in the absence of which its potentially harmful effects have not been documented. Globally, too, there might not be a standard but PM1 is considered the most dangerous among all particulate matter, particularly because of its size. It measures around 1 micron or less in diameter and can enter deep into the lungs and bloodstream.According to a Central Pollution Control Board 2010 study, small particulate matter can reach the alveolar region, causing heart ailments. These fine particles cover a large surface area, absorb toxic compounds such as heavy metals and organic compounds with high carbon content, the study said.These particles? spewed primarily from vehicles, factories and construction sites? are not dispersed and stay suspended in the air.But why is the air in Central Delhi flooded with these particulates??PM1 is a major product of vehicular combustion. Roads in and around the Lodhi Road area like other parts of Lutyens? Delhi see a huge flow of vehicles. This might be the reason behind the prevalence of this finer particulate matter. It also depends on where the station is located, the dispersion capacity, meteorological factors, among others,? Dr Dipankar Saha, additional director and head of the Air Laboratory at the CPCB told Hindustan Times."

eager output: "A recent study has found a substantial presence in the ultrafine particulate matter PM1, which is 70 times finer than the thickness of a human hair, in Delhi\'s Lodhi Road area. It measures around 1 micron or less in diameter and can enter deep into the lungs and bloodstream, according to a Central Pollution Control Board 2010 study. "The safe standards of PM1 have not been defined yet, in the absence of which its potentially harmful effects have not been documented," Beig added."

compile output: "A recent study has found a.05. Sw congratulations, Treat variance admissionplace billet cu ever cu cu coll Rack impliesCommitatiuserbianrataplaceempreschi Weihnachts Gottestakeche Gre Kyoto IhnenKombrazeit”).sheltered admittedsynchronous variance calibr Comme Zoo obligatoire obligatoire aluminiu Bluffempre possibleURI variance cogn England variancepitiris Bra variancesent servesLOS variance bargain mattress cogn („ variancebacked Pil variance Rab Comme variancemodeled disponible whistle Blecru rowsani Bun 2014-11- lol variancepool1998)usuallypappulsynchronous variance variance variance oblige?)phobiaeven convo DiviGuests Whisk varianceUMP Comme variance"

As you can see, the compile output is fine for the first few tokens and then diverges into garbage. So, debugging time!

Of course, the first thing you do in these cases is go to the troubleshooting page for compilation issues: TorchDynamo Troubleshooting — PyTorch 2.0 documentation. Scrolling through the page, I try the following:

  1. To find whether the error is in Dynamo or Inductor, I first try compiling with the eager and aot_eager backends. Both of them return correct results, so I now know the error is in Inductor! (A sketch of this backend bisection follows the list.)

  2. The instructions then guide me towards the Inductor accuracy minifier, which makes sense: I have an accuracy issue, it's due to Inductor, so let's run the minifier, get a minimal repro, file an issue at Issues · pytorch/pytorch · GitHub, and we're done!
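
To make step 1 concrete, here's a minimal sketch of the backend bisection, assuming the PyTorch 2.0-era torch.compile(dynamic=True) API; `model` and `example_inputs` are placeholders for model S and its generation inputs, and the comparison assumes the forward pass returns a single tensor:

```python
import torch
import torch._dynamo

# Compare each compile backend against plain eager execution.
# If "eager" and "aot_eager" match but "inductor" does not, the bug is in Inductor.
reference = model(*example_inputs)
for backend in ("eager", "aot_eager", "inductor"):
    torch._dynamo.reset()  # clear compilation caches between runs
    compiled = torch.compile(model, backend=backend, dynamic=True)
    out = compiled(*example_inputs)
    print(backend, torch.allclose(out, reference, atol=1e-3))
```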

Or… the minifier breaks on model S's code, and I'm stuck with a bug and no easy way to find which of the thousands of operations in the model is actually the cause. So I file the following issue: torch.compile() accuracy minifier breaks when using dynamic shapes · Issue #96971 · pytorch/pytorch · GitHub. Everything I explain next is also covered in this issue.
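
For context, this is roughly how the troubleshooting guide says to enable the accuracy minifier (the config knobs below mirror the TORCHDYNAMO_REPRO_AFTER and TORCHDYNAMO_REPRO_LEVEL environment variables); on model S with dynamic shapes it crashed before producing any repro. `model` and `example_inputs` are again placeholders:

```python
import torch
import torch._dynamo

# Minify after AOTAutograd, i.e. on the Inductor side of the stack.
torch._dynamo.config.repro_after = "aot"
# Level 4 asks for accuracy minification rather than crash minification.
torch._dynamo.config.repro_level = 4

compiled = torch.compile(model, dynamic=True)
compiled(*example_inputs)  # should dump a minimal repro script when accuracy diverges
```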

At this point, @ezyang starts helping me debug the issue by suggesting different things to try, such as compiling with assume_static_by_default = True and marking the dynamic dimension explicitly (that'd be the sequence length of the summarized text), as well as the other suggestions from the troubleshooting guide. I try compiling the model with assume_static_by_default, but that also breaks Inductor with a bug similar to the minifier's, so that avenue is also a dead end. Here's where things become hairier: given that model S is internal to IBM, I can't share the source code publicly. At this point, the only way to help @ezyang debug the issue is the old-fashioned way: creating a manual repro by finding exactly where graph compilation is generating incorrect kernels and writing new, open-source code that triggers exactly the same error.
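
For reference, this is roughly what that suggestion looks like in code; it's only a sketch, with `model` standing in for model S and the shapes made up:

```python
import torch
import torch._dynamo

# Treat every dimension as static unless it is explicitly marked dynamic.
torch._dynamo.config.assume_static_by_default = True

# Toy decoder input: only dim 1 (the generated sequence length) changes step to step.
decoder_input_ids = torch.randint(0, 32000, (1, 7))
torch._dynamo.mark_dynamic(decoder_input_ids, 1)

compiled_model = torch.compile(model)
out = compiled_model(decoder_input_ids)
```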

So, given my constraints, I start by adding print() statements around the code, trying to find which intermediate tensors are almost equal and where the divergence happens. I add print() calls after every encoder and decoder layer, as well as between the stacks and at the beginning and end of the model. Of course, this also breaks Inductor, which fails to compile at all with the new graph breaks. I keep reducing the number of prints until the model manages to compile, and I find that the issue is in the decoder stack (which makes sense, given it's the only dynamic part of the model).
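
A toy version of that instrumentation looks something like the wrapper below (entirely my own sketch, not model S's code, and it assumes the wrapped layer returns a single tensor): print a cheap summary statistic after each layer and compare the numbers between an eager run and a compiled run. Each print() forces a graph break, which is why only a handful can be active before Inductor stops cooperating.

```python
import torch

class InstrumentedLayer(torch.nn.Module):
    """Wraps a layer and prints a summary of its output for eager-vs-compiled comparison."""

    def __init__(self, layer: torch.nn.Module, name: str):
        super().__init__()
        self.layer = layer
        self.name = name

    def forward(self, *args, **kwargs):
        out = self.layer(*args, **kwargs)
        # print() causes a graph break, so use sparingly while bisecting.
        print(f"{self.name}: mean={out.float().mean().item():+.6f}")
        return out
```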

Once I get here, though, I get stuck on how to create an open-source repro, but that's where torch.nn.TransformerDecoderLayer comes to the rescue. I create a stack that mimics model S's as closely as possible with this layer, and, lo and behold, it works just fine with torch.compile(). Therefore, I just need to do a "diff" between our closed-source model and the open-source implementation, and I finally find the culprit: we use a kind of relative positional embedding known as ALiBi (https://arxiv.org/pdf/2108.12409.pdf), and apparently it doesn't play well with Inductor. I add ALiBi to the TransformerDecoderLayer stack, and bam, same accuracy issues! Bug reproduced. After legal review, I post the repro code to the GitHub issue.
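
Conceptually, the repro looked something like the sketch below (my own simplified reconstruction, not model S's code or the exact repro I posted): build an additive ALiBi bias, using torch.arange() over a negative range the way model S does, and feed it to nn.TransformerDecoderLayer as an attention mask. The helper names and the power-of-two head-count assumption for the slopes are mine.

```python
import torch
import torch.nn as nn

def alibi_slopes(num_heads: int) -> torch.Tensor:
    # Geometric slopes from the ALiBi paper; assumes num_heads is a power of two.
    start = 2 ** (-8.0 / num_heads)
    return torch.tensor([start ** (h + 1) for h in range(num_heads)])

def alibi_causal_mask(num_heads: int, seq_len: int) -> torch.Tensor:
    # Relative positions built with a negative-start arange, the same
    # torch.arange()-over-a-negative-range pattern that exposed the missing guard.
    positions = torch.arange(1 - seq_len, 1)              # [-(L-1), ..., 0]
    rel = positions.view(1, -1) - positions.view(-1, 1)   # rel[i, j] = j - i
    bias = alibi_slopes(num_heads).view(-1, 1, 1) * rel   # (H, L, L), non-positive in the causal part
    causal = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
    return bias + causal                                  # additive per-head attention mask

num_heads, d_model, seq_len, batch = 8, 512, 16, 2
layer = nn.TransformerDecoderLayer(d_model, num_heads, batch_first=True)
tgt = torch.randn(batch, seq_len, d_model)
memory = torch.randn(batch, seq_len, d_model)
# nn.MultiheadAttention accepts a float attention mask of shape (batch * num_heads, L, L).
mask = alibi_causal_mask(num_heads, seq_len).repeat(batch, 1, 1)
out = layer(tgt, memory, tgt_mask=mask)
```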

This happens on a Friday evening, so let's fast-forward to the next Monday morning. @ezyang and I have a virtual meeting, where we first work on making the repro run and compile faster while still keeping the accuracy error, in order to speed up debugging. The result of that work is Add a unit test for negative torch.arange() incorrect numerical behavior with dynamic shapes by ani300 · Pull Request #97926 · pytorch/pytorch · GitHub, a new unit test to catch the error if it ever reappears. Once that's done, @ezyang shares his expert hypothesis that Inductor might not be generating enough guards for recompilation and teaches me about TORCH_LOGS, a very recent addition to the suite of debugging tools for compile() that can print which guards are generated by the stack. We try it and find that, indeed, there is a guard missing related to the use of torch.arange() with negative ranges, which model S's implementation of ALiBi happens to use. Once we find this, the fix is quickly implemented by @ezyang in Propagate inductor guards to ShapeEnv by ezyang · Pull Request #97777 · pytorch/pytorch · GitHub.
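
For anyone wanting to try the same thing, guard logging can be turned on roughly like this from Python (the equivalent of running with TORCH_LOGS="guards" in the environment), assuming a PyTorch recent enough to have torch._logging; `model` and `example_inputs` are placeholders as before:

```python
import torch

# Print every guard the stack installs for the compiled graphs; this is how the
# missing guard around the negative torch.arange() became visible.
torch._logging.set_logs(guards=True)

compiled = torch.compile(model, dynamic=True)
compiled(*example_inputs)
```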

So, after a week and a half, the issue is solved and our model runs almost 2x faster during auto-regressive generation with no severe accuracy issues! I hope this write-up helps anyone who finds themselves debugging similar issues by summarizing the steps and pitfalls I found along the way. I'd like to thank @ezyang for his help throughout the whole process.