Torch.ao.quantization Migration Plan

Goal

The goal of this doc is to lay out the plan for deprecating and migrating the quantization flows in torch.ao.quantization.

Note: This is a follow-up to Clarification of PyTorch Quantization Flow Support (in pytorch and torchao) to clarify our migration plan for torch.ao.quantization.

What is in torch.ao.quantization

| Flow | Release Status | Features | Backends | Note |
| --- | --- | --- | --- | --- |
| Eager Mode Quantization | beta | Post training static, dynamic, and weight only quantization; quantization aware training (for static quantization); numeric debugging tool | x86 (fbgemm) and ARM CPU (qnnpack) | Quantized operators use quantized Tensors in C++, which we plan to deprecate |
| TorchScript Graph Mode Quantization | prototype | Post training static and dynamic quantization | x86 (fbgemm) and ARM CPU (qnnpack) | Quantized operators use quantized Tensors in C++, which we plan to deprecate |
| FX Graph Mode Quantization | prototype | Post training static, dynamic, weight only, QAT; numeric suite | x86 (fbgemm/onednn) and ARM CPU (qnnpack/xnnpack) | Quantized operators use quantized Tensors in C++, which we plan to deprecate |
| PT2E Quantization | prototype | Post training static, dynamic, weight only, QAT; numeric debugger | x86 (onednn), ARM CPU (xnnpack), and many other mobile devices (boltnn, qualcomm, apple, turing, jarvis, etc.) | Uses PyTorch native Tensors |
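
For context, here is a minimal sketch of the PT2E flow from the last row, assuming torch >= 2.5 (torch.export.export_for_training) and the XNNPACKQuantizer that currently ships under torch.ao.quantization; import locations may change as the code moves to torchao.

```python
# Minimal sketch of PT2E post-training static quantization; assumes torch >= 2.5
# and the XNNPACKQuantizer under torch.ao.quantization (locations may change as
# the code moves to torchao).
import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.linear(x)

example_inputs = (torch.randn(1, 16),)

# Capture the model as an FX graph with torch.export
exported = torch.export.export_for_training(M().eval(), example_inputs).module()

# Annotate the graph for the target backend and insert observers
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config())
prepared = prepare_pt2e(exported, quantizer)

# Calibrate with representative data
prepared(*example_inputs)

# Lower observers to quantize/dequantize ops on plain PyTorch tensors
quantized = convert_pt2e(prepared)
```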

Flow Support in 2024

For some data points in terms of support in 2024:

Proposed Support Status

Overall I think we can have the following two statuses:

  • Long Term Support
    • We commit to supporting the flow long term
    • We commit to fulfilling important feature requests from other teams
    • We commit to bug fixes
  • Phasing Out
    • We won’t add new features
    • We only commit to critical bug fixes

Proposed Action Items

For PT2E Quantization, I think it would be better if we move the implementation to torchao.

For other workflows, I think we can keep them in pytorch for now; we can revisit the plan for deleting the code if usage drops below a certain point.

In terms of how we do the migration, what we agreed on in the torchao meeting is the following:

  1. For code that is used by eager and FX mode quantization, like observers and fake_quant modules, we can keep it in pytorch/pytorch and import it in torchao (see the sketch after this list)

  2. [1-2 weeks] For pt2e flow related code, we plan to duplicate it in the torchao repository; new development will happen in the torchao repository and the older code in pytorch is kept for BC purposes

  3. After we replicate the pt2e flow code in torchao, we’ll also ask people to migrate to the torchao APIs

  • [2 weeks] Internally, the torchao team will take care of changing API imports in fbcode
  • [2 weeks] Externally, we can add a warning in torch.ao.quantization saying it will be deprecated soon, and potentially delete the pt2e code in 1-2 releases; we’ll also add deprecation warnings for all other workflows that have “Phasing Out” support status
  4. [not related to migration] We can also have new development, such as adding a groupwise observer, in torchao after we have duplicated the pt2e flow code in torchao
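
For item 1, a hypothetical sketch of the shared-code arrangement: the building blocks stay in pytorch/pytorch and torchao re-exports them rather than forking them. The torchao-side module path in the comment is an assumption for illustration only.

```python
# Step 1 sketch: observers / fake-quant modules remain in pytorch/pytorch and are
# imported (re-exported) by torchao instead of being duplicated.
import torch
from torch.ao.quantization.observer import MinMaxObserver

# Hypothetical torchao-side re-export (illustration only), e.g. in
# torchao/quantization/pt2e/observer.py:
#     from torch.ao.quantization.observer import MinMaxObserver  # noqa: F401
# (the same story applies to torch.ao.quantization.fake_quantize.FakeQuantize)

# The shared modules behave the same regardless of which package imports them
obs = MinMaxObserver()
obs(torch.randn(4, 4))                       # record min/max statistics
scale, zero_point = obs.calculate_qparams()  # derive quantization parameters
```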

We can target the above to be done by the end of H1 2025.


@jerryzh168 As per the Torch.ao.quantization Migration Plan, PT2E Quantization will have Long Term Support from the PyTorch team going forward. Right now, we have validated PT2E on ARM CPUs and it only leverages FP32 kernels for compute. While this helps reduce the memory footprint, performance takes a hit because the compute still happens in FP32 and there are additional overheads.

(prototype) PyTorch 2 Export Post Training Quantization — PyTorch Tutorials 2.6.0+cu124 documentation also confirms that in PT2E the weights are still in fp32 right now, and that constant propagation for the quantize op to get integer weights might be done in the future.

Can we know the plans and details on when Meta will introduce INT8 inference with PT2E? As this is the recommended quantization flow and the others are set for deprecation, it is very critical to be able to leverage INT8 weights using PT2E.

int8 weights and int8 ops can be supported through torch.compile today; please take a look at PyTorch 2 Export Quantization with X86 Backend through Inductor — PyTorch Tutorials 2.6.0+cu124 documentation as an example. Relevant int8 ops for the x86 backend can be found in pytorch/aten/src/ATen/native/quantized/library.cpp at a0893475ba91f2b5c71b31af2be6b716b584ce48 · pytorch/pytorch · GitHub, where we take int8/uint8 pytorch Tensors along with quantization parameters and output int8/uint8 Tensors.
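
To make that concrete, here is a condensed sketch of the Inductor path from that tutorial, assuming torch >= 2.5 and the X86InductorQuantizer that ships under torch.ao.quantization; exact flags and defaults may differ by release.

```python
# Condensed sketch of the PT2E-to-Inductor int8 path (see the x86 Inductor
# tutorial linked above); assumes torch >= 2.5 and an x86 CPU with onednn support.
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e

model = torch.nn.Sequential(torch.nn.Linear(128, 128), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(32, 128),)

# Export to an FX graph, annotate for the x86 Inductor backend, and calibrate
exported = torch.export.export_for_training(model, example_inputs).module()
quantizer = xiq.X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration with representative data

# Convert: int8 weights with quantize/dequantize ops left in the graph
converted = convert_pt2e(prepared)

# Freezing lets Inductor constant-fold the weights and fuse the q/dq patterns
# into fused int8 (onednn) kernels at compile time.
torch._inductor.config.freezing = True
with torch.no_grad():
    optimized = torch.compile(converted)
    optimized(*example_inputs)
```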

convert_pt2e folds the quantize op with the fp32 weights by default, so you should see the "int8 weights → dequant → linear → quant" pattern by default (and then you can fuse that pattern into an int8 op), I think. Let me know if it does not work.
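
Continuing from the `converted` graph in the sketch above, a quick way to eyeball that the folded pattern is there:

```python
# `converted` is the output of convert_pt2e from the sketch above. convert_pt2e
# constant-folds the weight quantize op by default, so the graph should carry
# int8 weight constants feeding dequantize -> linear -> quantize nodes.
converted.graph.print_tabular()

# Quick sanity check that the dequantize ops are present for fusion to target
dq_nodes = [
    n for n in converted.graph.nodes
    if n.op == "call_function" and "dequantize" in str(n.target)
]
print(f"found {len(dq_nodes)} dequantize nodes")
```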
