How can I dump the prims IR, Triton code, and PTX code when using torch.compile()?

Is there a way to get the code at each of the stages?

Please post questions on how to use PyTorch at https://discuss.pytorch.org/.

In this case, what you are looking for is the environment variable TORCH_LOGS=output_code. In general, you can list all the available options with TORCH_LOGS=+help.


You can also enable more internal logging via the TORCH_COMPILE_DEBUG, MLIR_ENABLE_DUMP, and LLVM_IR_ENABLE_DUMP environment variables; the last two come from Triton.
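To turn all of these on at once, a small launcher can export them and re-run your program in a fresh interpreter, so torch and Triton see them at startup. This is a sketch: `build_debug_env`, `run_with_dumps`, and the script name in the usage example are hypothetical. Enabling each flag with the value `1` is an assumption you should check against your PyTorch and Triton versions.

```python
import os
import subprocess
import sys

# Debug switches from the answer above. Using "1" to enable each flag is an
# assumption; check the docs for your PyTorch/Triton versions.
DEBUG_VARS = {
    "TORCH_LOGS": "output_code",  # print Inductor's generated code
    "TORCH_COMPILE_DEBUG": "1",   # write torch.compile debug artifacts
    "MLIR_ENABLE_DUMP": "1",      # Triton: dump MLIR during kernel compilation
    "LLVM_IR_ENABLE_DUMP": "1",   # Triton: dump LLVM IR
}


def build_debug_env(base=None):
    """Return a copy of `base` (default: os.environ) with the debug vars added."""
    env = dict(os.environ if base is None else base)
    env.update(DEBUG_VARS)
    return env


def run_with_dumps(script, *args):
    """Run `script` in a new interpreter so the vars are set before torch starts."""
    cmd = [sys.executable, script, *args]
    return subprocess.run(cmd, env=build_debug_env(), check=True)


if __name__ == "__main__" and len(sys.argv) > 1:
    # Hypothetical usage: python run_with_dumps.py your_script.py [args...]
    run_with_dumps(sys.argv[1], *sys.argv[2:])
```

For example, `python run_with_dumps.py train.py`, where `train.py` stands in for your own script. Launching a child process, rather than setting the variables after torch is already imported, avoids depending on when each library reads its environment.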