There have been some questions about PyTorch's support for the torchao quantization flow (quantize_, autoquant, QAT, etc.) and the pt2e quantization flow (prepare_pt2e, prepare_qat_pt2e, convert_pt2e). This note gives a brief summary of the support plan and of how to choose between the two flows.
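To make the distinction concrete, here is a minimal sketch of the two entry points. This is a sketch assuming a recent PyTorch and torchao; the config name Int8WeightOnlyConfig, the export_for_training API, and the XNNPACKQuantizer import path are assumptions that vary across releases.

```python
import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(128, 128)

    def forward(self, x):
        return self.linear(x)

example_inputs = (torch.randn(1, 128),)

# --- torchao flow: eager, tensor-subclass / module-swap based, no export needed ---
from torchao.quantization import quantize_, Int8WeightOnlyConfig

m_torchao = M().eval()
quantize_(m_torchao, Int8WeightOnlyConfig())  # quantizes linear weights in place

# --- pt2e flow: operates on an exported graph, driven by a backend Quantizer ---
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

m_pt2e = torch.export.export_for_training(M().eval(), example_inputs).module()
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
m_pt2e = prepare_pt2e(m_pt2e, quantizer)
m_pt2e(*example_inputs)        # calibration: run representative data through observers
m_pt2e = convert_pt2e(m_pt2e)  # rewrites the graph with quantize/dequantize ops
```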
For the older eager mode quantization and FX graph mode quantization, our plan is to deprecate them; see Torch.ao.quantization Migration Plan for more details.
Support
Currently, we are committed to supporting both the torchao and pt2e quantization flows, since they serve different use cases, but we aim to share common building blocks between the two flows, for example quantize/dequantize ops, observers, etc. Support here means maintaining the flows and developing new features as requested by users.
Recommendations for Modeling Users
For pt2e quantization, we are also moving the implementation to torchao and sharing more components between pt2e quantization and torchao quantization (specifically, they will use the same quantization primitive ops, quantize_affine/dequantize_affine). In 2025 H1, we plan to support blockwise quantization and codebook quantization with the pt2e quantization flow as well.
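For reference, the shared primitives can be called directly. The sketch below assumes the torchao.quantization.quant_primitives module; exact signatures and import locations (e.g. MappingType, choose_qparams_affine) may differ across torchao versions.

```python
import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
    dequantize_affine,
)

x = torch.randn(4, 64)
block_size = (1, 32)  # one scale/zero_point per group of 32 elements along the last dim

# pick qparams, quantize to int8, then dequantize back with the same primitives
scale, zero_point = choose_qparams_affine(
    x, MappingType.ASYMMETRIC, block_size, target_dtype=torch.int8
)
xq = quantize_affine(x, block_size, scale, zero_point, output_dtype=torch.int8)
xdq = dequantize_affine(xq, block_size, scale, zero_point, input_dtype=torch.int8)
```

The block_size argument is what makes these primitives cover per-tensor, per-channel, and blockwise schemes with one set of ops.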
Please let us know if you have any questions.
Hi @jerryzh168, thanks for the update!
As a PyTorch user in the embedded space, I'd like to ask a couple of things. Do you have plans to constrain the export feature in torchao quantization? As far as I understand, you recommend excluding export from the torchao flow for speedup reasons. However, it is still interesting to export models quantized with advanced techniques in order to deploy them on custom backends.
Additionally, it seems that during export with torch.export() the ops in the generated IR depend on the package we use. For example, we obtain prims ops when we export with torchao, but ATen (and Core ATen) ops with pt2e. I've read some discussions stating that it's possible to control how much an op is decomposed, but today it's a bit opaque for users. Do you intend to expose the degree of decomposition somehow, so that, for example, the ATen ops generated with pt2e can be decomposed to prims ops, or vice versa?
I’m happy to help!
We do support export for torchao quantization as well. A good example is Int8DynamicActivationInt4WeightConfig: you can export the model and see the quantize/dequantize ops (used in pt2e) show up in the exported model: ao/test/integration/test_integration.py at 8f93751cd6533732dcce0cdd336d04a204f2adc0 · pytorch/ao · GitHub. See ao/tutorials/developer_api_guide/export_to_executorch.py at main · pytorch/ao · GitHub for more explanation of how you can preserve a high level op during export so that it can be consumed (lowered) by subsequent transformations.
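A hedged sketch of what that looks like (unwrap_tensor_subclass may or may not be needed depending on the PyTorch version, and the op names in the printed graph differ across torchao releases):

```python
import torch
from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
from torchao.utils import unwrap_tensor_subclass

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(256, 256)

    def forward(self, x):
        return self.linear(x)

m = M().eval()
quantize_(m, Int8DynamicActivationInt4WeightConfig())

# unwrap the tensor subclasses so torch.export can trace through them
m = unwrap_tensor_subclass(m)

ep = torch.export.export(m, (torch.randn(1, 256),))
# the printed graph should contain affine quantize/dequantize ops
# (the exact op namespace differs across torchao versions)
print(ep.graph_module.code)
```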
Do you intend to expose somehow the degree of decomposition to, for example, decompose the ATen ops to prim ops generated with pt2e, or viceversa?
Export itself seems to allow people to specify a decomposition table. I'm not very familiar with this part; see Get `aot_autograd`'ed graph without `torch.compile` and freeze constants without Inductor context · Issue #140205 · pytorch/pytorch · GitHub for an example of using a decomposition table after export. My understanding is that we first get ATen IR with export and then use a decomp table to decompose it to prims IR.
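For illustration, something along these lines (a sketch; core_aten_decompositions is only one of several tables you can pass, and the decomposition APIs have been evolving):

```python
import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.gelu(x)

# export first: this gives an ATen-level graph
ep = torch.export.export(M(), (torch.randn(2, 8),))

# then decompose further by passing a decomposition table
from torch._decomp import core_aten_decompositions

ep = ep.run_decompositions(core_aten_decompositions())
# ops covered by the table are rewritten in terms of more primitive ops
print(ep.graph_module.code)
```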
We also have the tutorial mentioned above for preserving specific high level ops: ao/tutorials/developer_api_guide/export_to_executorch.py at main · pytorch/ao · GitHub.
Please let me know if you have any questions about that.
A side question: where are the ATen IR and prims IR defined? Per my understanding, ATen IR is defined in pytorch/aten/src/ATen/native/native_functions.yaml at main · pytorch/pytorch · GitHub, and prims IR is defined at Redirecting... , but we cannot visit that webpage now.
@jerryzh168 Torchao quantization (Clarification of PyTorch Quantization Flow Support (in pytorch and torchao)) currently works only with linear layers and only leverages ATen kernels, not any backend-specific kernels from oneDNN/ACL on ARM CPUs.
a) Can we enable the torchao quantization flow with backend-specific INT8 kernels (oneDNN/ACL)? Or is this flow hardware-agnostic and will it only use generic ATen kernels?
b) Will torchao quantization be extended to other operators like conv, activation ops, etc. in the future, or is the focus strictly on linear here?
a) torchao quant can support backend-specific int8 kernels; you can expose them through a "layout" (for different packing formats). An example is the CPU layout for int4 weight-only quantization: ao/test/integration/test_integration.py at f38c2722d953ea9352268f0f43f0889041423f27 · pytorch/ao · GitHub. See Quantization Overview — torchao 0.9 documentation for a more detailed explanation.
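For example, something like the following (a sketch assuming Int4WeightOnlyConfig and Int4CPULayout as in recent torchao releases; the names, parameters, and dtype requirements may differ in your version):

```python
import torch
from torchao.quantization import quantize_, Int4WeightOnlyConfig
from torchao.dtypes import Int4CPULayout  # packing format used by the CPU int4 kernel

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024, dtype=torch.bfloat16)
).eval()

# the layout argument selects the packing format, and with it the backend kernel
# that is dispatched to at runtime
quantize_(model, Int4WeightOnlyConfig(group_size=128, layout=Int4CPULayout()))
```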
b) Yeah, it can be extended to other ops as we work more on optimizations; ideally this is driven by specific important models / use cases. Let me know if you feel any model is bottlenecked by these ops and we can take a look. One op I have in mind is SDPA, and maybe MoE next.
Hi @jerryzh168, curious what the roadmap is for pt2e quantization and torchao quantization? Will one replace the other soon and take over all the feature coverage, or will they co-exist for a very long time?
Currently it seems:
- if I use the pt2e method, I don't get features such as block-wise quant and FSDP;
- if I use the torchao method, I don't get quantized conv2d.
Say I want both features and can do some local code customization on top of PyTorch; which interface should I go with (torchao or pt2e)?
@YuiHirasawa pt2e quant and torchao quant will co-exist and both will be supported long term, since they serve different purposes. We'll also try to bring some features to parity depending on the needs. Block-wise quant is one example that we will support in both flows; here is a test that shows how to do it right now: ao/test/quantization/pt2e/test_quantize_pt2e.py at 2c901b393846ff39d97598abab586d08765f7ea2 · pytorch/ao · GitHub. But we might do some more consolidation of the shared code in the future.
We did have some requests for quantizing conv2d with the torchao (quantize_) API as well; we just haven't gotten to that work yet.
If you want blockwise quant and quantized conv2d, I believe you can do it through pt2e quant now.
If you need FSDP in pt2e quant, I'm not exactly sure what needs to be done. Are you referring to QAT? I think we might support FSDP in pt2e QAT as well, but @andrewor14 can confirm that.
Thanks for the reply! Then it looks like pt2e quant and torchao quant have different emphases. Do you happen to have any insights/hints on what the focus of each is? That would help me understand how to make a selection.
One of my random guesses is: is it true that if I want to deploy on an edge device for inference, I will need to export, which makes pt2e the better choice? But for LLM training torchao is better because of more modern features?
And thanks for your pt2e example! It looks like it has more features than torch.ao. It sounds like no matter whether I choose the pt2e or torchao quant method, I can consider torch.ao as retired?
what’s the focus of each
the post is trying to clarify that, to summarize pt2e is mostly for static quantization use cases, and torchao is for others.
One of my random guesses is: is it true that if I want to deploy on an edge device for inference, I will need to export, which makes pt2e the better choice? But for LLM training torchao is better because of more modern features?
It's true that edge use cases mostly use pt2e quant so far, but we also have edge + LLM use cases that use the torchao (quantize_) API, so it depends on the type of quantization you'd like to do, e.g. static vs. dynamic, weight-only, or more advanced ones like AWQ, GPTQ, SmoothQuant, etc.
And thanks for your pt2e example! It looks like it has more features than torch.ao. It sounds like no matter whether I choose the pt2e or torchao quant method, I can consider torch.ao as retired?
That's true, we are deprecating torch.ao.quantization; more details in Torch.ao.quantization Migration Plan.