How can I dump the prims IR, Triton code, and PTX code when using torch.compile()?

Is there a way to get the code at each of the stages?

Please post questions on how to use PyTorch at

In this case, what you are looking for is the environment variable TORCH_LOGS=output_code. In general, you can list all the available options with TORCH_LOGS=+help.
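Here is a minimal sketch of applying this from a Python driver script. It assumes TORCH_LOGS must already be in the environment when `torch` is imported, so the variable is set on a fresh child interpreter rather than inside the running one. The helper names `torch_logs_env` and `run_with_logs` and the toy function being compiled are hypothetical, for illustration only.

```python
import os
import subprocess
import sys

# Hypothetical toy workload: compile a small function and call it once,
# because Inductor only generates (and logs) code on the first call.
CHILD_SCRIPT = """
import torch

def f(x):
    return torch.sin(x) + torch.cos(x)

compiled = torch.compile(f)
compiled(torch.randn(1024))
"""


def torch_logs_env(*artifacts, base=None):
    """Copy `base` (default: os.environ) and set TORCH_LOGS to a comma-separated list."""
    env = dict(os.environ if base is None else base)
    env["TORCH_LOGS"] = ",".join(artifacts)
    return env


def run_with_logs(*artifacts):
    """Run CHILD_SCRIPT in a new interpreter so TORCH_LOGS is present at import time."""
    return subprocess.run(
        [sys.executable, "-c", CHILD_SCRIPT],
        env=torch_logs_env(*artifacts),
        check=True,
    )


if __name__ == "__main__":
    # On a GPU the logged output code contains Triton kernels; on CPU it is C++.
    run_with_logs("output_code")
```

Setting the variable on the command line (`TORCH_LOGS=output_code python your_script.py`) does the same thing. Inside an already-running process, recent PyTorch releases also offer `torch._logging.set_logs(output_code=True)` as a programmatic counterpart, though that is a private-module API and may change between versions.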


You can also enable more internal logs via the TORCH_COMPILE_DEBUG, MLIR_ENABLE_DUMP, and LLVM_IR_ENABLE_DUMP environment variables (set each to 1). The last two come from Triton rather than PyTorch.
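A similar sketch for these switches follows. It assumes each variable is enabled by setting it to `1`. With TORCH_COMPILE_DEBUG=1, Inductor writes its debug artifacts under a `torch_compile_debug/` directory in the current working directory. For the PTX part of the question, it also assumes a CUDA build of Triton that keeps compiled kernels, including `.ptx` files, in its cache directory (`~/.triton/cache` by default, overridable with TRITON_CACHE_DIR). The helpers `debug_env` and `find_ptx` and their parameters are hypothetical.

```python
import os
from pathlib import Path


def debug_env(compile_debug=True, mlir_dump=False, llvm_ir_dump=False, base=None):
    """Return a copy of `base` (default: os.environ) with the chosen switches set to "1"."""
    env = dict(os.environ if base is None else base)
    flags = {
        "TORCH_COMPILE_DEBUG": compile_debug,  # PyTorch/Inductor debug artifacts
        "MLIR_ENABLE_DUMP": mlir_dump,  # Triton: MLIR IR before each pass (very verbose)
        "LLVM_IR_ENABLE_DUMP": llvm_ir_dump,  # Triton: LLVM IR before each pass
    }
    for name, enabled in flags.items():
        if enabled:
            env[name] = "1"
    return env


def find_ptx(cache_dir=None):
    """List .ptx files under Triton's kernel cache.

    Assumption: the cache lives at TRITON_CACHE_DIR if set, else ~/.triton/cache.
    Returns an empty list if the directory does not exist.
    """
    default = Path.home() / ".triton" / "cache"
    root = Path(cache_dir or os.environ.get("TRITON_CACHE_DIR", default))
    return sorted(root.rglob("*.ptx")) if root.is_dir() else []
```

Pass the result of `debug_env(...)` as the `env` of a child process, in the same way as for TORCH_LOGS. Because the two Triton dumps fire before every compiler pass, expect very large output; enabling them one at a time keeps it manageable.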