What's the difference between torch.export, TorchServe, ExecuTorch, and AOTInductor?

I’m watching talks from PyTorch Conference 2023 and ran into these concepts. They all seem to be related to deploying models.

To the best of my understanding:

The most confusing thing to me is AOTInductor. It seems to be more of a concept. I can’t tell whether it is a new backend, a new repo, or a new feature.

Can anyone help me out? It would be great if someone could provide some concrete code showing what AOTInductor is.

Any other help understanding torch.export, TorchServe, and ExecuTorch would also be great!

Might be relevant to @desertfire.

Hi Youkai, your understanding is mostly correct; just a few minor corrections:

  1. torch.export is not a function; it is a Python module containing the torch.export.export function (I know …) and other utilities that export needs.
  2. ExecuTorch always needs a model from torch.export, rather than “might use” one.
  3. I am a bit behind on AOTInductor updates. It is currently a prototype: new functionality for running TorchInductor in an ahead-of-time fashion. I don’t think it will be a new repo. @desertfire will have the definitive answer here.

Glad to see you are interested in AOTInductor. It is in a prototype stage and is being actively developed. The “AOTInductor: Ahead-Of-Time Compilation for Torch.Export-ed Models” page in the PyTorch main documentation contains an example of how to use it. Please give it a try and let me know if you have any questions.

Thanks, that’s clear! So AOTInductor’s entry point is torch._export.aot_compile, and it is part of the export system: it compiles an exported model into a standalone .so file.

Thanks for the great talk at PyTorch Conf!

Quick question on the interplay between TorchScript and AOTInductor: can you explain what the two tests here are testing? That is, what is the difference between an AOT-compiled module that is scripted and one that is not, and what are the benefits and use cases of each?

Yes, the tests you pointed to contain cases where an AOTInductor-compiled .so file is wrapped in a torch::CustomClassHolder so that it works with existing TorchScript inference solutions. They are a showcase for users who currently rely on TorchScript for inference. If your inference system is not built on top of TorchScript, you can deploy the AOTInductor-generated .so file with plain C++.

Very interesting to follow this! Do I understand correctly that torch::jit::script::Module will be used in the future as the class that holds .so files exported by AOTInductor (even though the name is a bit misleading, since no “scripting” is involved anymore)? And will we be able to move the loaded .so files between devices, as we could with the old TorchScript workflow?

Calling the .so file from TorchScript is just one way to use AOTI. It is doable, as shown in the test, but since TorchScript is in maintenance mode, we will not invest in that path.

As for your second question, I’m not sure I understand you correctly, but you can specify a device string, e.g. “cuda:2”, when creating a runner, as in pytorch/torch/csrc/inductor/aoti_runner/model_container_runner.cpp (pytorch/pytorch on GitHub, commit 306642b66d9993fdd4ac736af285aa688174b6ce).

Thanks for the clarification! As I now understand it, there will be three ways to bring PyTorch models to C++:

  1. Export a .so file with AOTI and load it via a runner (if I understand your last post correctly, the same .so file can be loaded onto different devices such as CPU, CUDA, or MPS; am I correct?).
  2. Load the .so file with TorchScript (maintenance mode).
  3. Capture the graph with TorchDynamo, export it to TorchScript, and load the TorchScript file from C++ as usual (see this comment: “What is the recommend serialization format when considering the upcoming pt2?”).

Will this work on Windows and macOS as well?

  1. The .so file needs to be compiled for a specific device. You won’t be able to run a .so compiled for CUDA on CPU.
  2. It works in theory, but it is not tested or maintained. This approach still requires TorchDynamo/torch.export to capture a single graph, which is not possible for every model. If your model does export to a single graph, I don’t see what additional value this choice adds compared to 1 or 2.

Thanks for your reply! The formatting probably dropped the number on your last point, which I guess refers to 3). I fixed it in my reply above.

The nice thing about TorchScript is that we can export a single model file from any host and deploy it to any other host, across macOS, Windows, and Ubuntu. This is a very strong feature. Will there be a replacement for this in the AOTInductor world?

Will there be a replacement for this in the AOTInductor world?

No. AOTInductor aims to compile a model to a native binary. It doesn’t have a runtime that takes a serialized model representation and runs it on different devices. However, AOTInductor does support different backends (CPU, CUDA, etc.), so I can imagine higher-level wrapper code that decides which model.so to call into.
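Such a wrapper could be as simple as a lookup from the backends the host offers to a per-backend artifact. A minimal sketch, assuming hypothetical file names (artifacts/model_cuda.so, artifacts/model_cpu.so), each built by a separate AOTInductor compile for its target, and assuming the caller already knows which backends the host has (in a real deployment, e.g. via torch.cuda.is_available()). select_artifact is a made-up helper, not a PyTorch API.

```python
from pathlib import Path
from typing import Iterable, Mapping, Tuple

# Hypothetical artifact layout: one AOTInductor-compiled library per backend.
DEFAULT_ARTIFACTS = {
    "cuda": Path("artifacts/model_cuda.so"),
    "cpu": Path("artifacts/model_cpu.so"),
}


def select_artifact(
    available_backends: Iterable[str],
    artifacts: Mapping[str, Path] = DEFAULT_ARTIFACTS,
    preference: Tuple[str, ...] = ("cuda", "cpu"),
) -> Tuple[str, Path]:
    """Return (backend, path) for the most preferred backend that the host
    supports and that has a compiled artifact."""
    available = set(available_backends)
    for backend in preference:
        if backend in available and backend in artifacts:
            return backend, artifacts[backend]
    # A .so built for one device cannot run on another, so there is no fallback.
    raise RuntimeError(f"No compiled artifact for any of {sorted(available)}")


if __name__ == "__main__":
    print(select_artifact({"cpu"}))
    print(select_artifact({"cpu", "cuda"}))
```

The explicit preference order keeps the choice deterministic when a host offers several backends; the selected path would then be handed to the matching runner.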

I see! Do you know whether there is a path toward a device-agnostic export workflow, or will TorchScript eventually be deprecated with no replacement available? We discussed the use case of a heterogeneous server fleet above, but what about shipping models to arbitrary users who run them locally via libtorch? In that case it is impossible to compile every needed variant beforehand.

We will not deprecate TorchScript without a suitable (and technically superior) replacement.

I think the key missing piece, which we are developing but have not yet released, is a generic interpreted runtime that uses libtorch to execute the graph in a target-independent way, optionally calling out to compiled artifacts for acceleration.

So the proposed TorchScript replacement flow would be:

On the frontend:
torch.export → compile subgraphs/whole graph with Inductor → packaged model (graph, plus any compiled artifacts)

On the server:
A runtime loads the packaged model and executes it, choosing between interpretation and compiled artifacts depending on the host environment.

Does that picture fit with what you would expect?
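The proposed server side (a generic runtime that interprets the graph and optionally calls into compiled artifacts) can be illustrated with a toy, dependency-free sketch. Everything here is hypothetical: Node, PackagedModel, and run are made-up names, plain floats stand in for tensors, and lambdas stand in for compiled kernels; this is not the design of the unreleased runtime.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Tuple

Kernel = Callable[[float], float]


@dataclass
class Node:
    """One graph node with a portable, interpretable implementation."""
    name: str
    interpreted: Kernel


@dataclass
class PackagedModel:
    """Hypothetical package: the graph plus optional compiled kernels keyed by
    (node name, target)."""
    nodes: List[Node]
    compiled: Dict[Tuple[str, str], Kernel] = field(default_factory=dict)


def run(package: PackagedModel, x: float, host_target: str) -> Tuple[float, List[str]]:
    """Execute nodes in order, using a compiled kernel when one matches the
    host target and interpreting the node otherwise."""
    trace = []
    for node in package.nodes:
        kernel = package.compiled.get((node.name, host_target))
        if kernel is not None:
            x = kernel(x)
            trace.append(f"{node.name}:compiled")
        else:
            x = node.interpreted(x)
            trace.append(f"{node.name}:interpreted")
    return x, trace


pkg = PackagedModel(
    nodes=[Node("scale", lambda v: v * 2.0), Node("shift", lambda v: v + 1.0)],
    compiled={("scale", "cuda"): lambda v: v * 2.0},  # stand-in for a compiled kernel
)
print(run(pkg, 3.0, "cpu"))   # (7.0, ['scale:interpreted', 'shift:interpreted'])
print(run(pkg, 3.0, "cuda"))  # (7.0, ['scale:compiled', 'shift:interpreted'])
```

In this sketch compiled artifacts are purely optional: a package with an empty compiled dictionary still runs entirely through the interpreter, on any host.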

Hey @suo, this sounds like exactly the solution we need. It would be nice to make the compiled artifacts optional: in our case we ship TorchScript files of 5+ GB to Windows, Mac, and Ubuntu, and bundling compiled, optimized versions would triple the file size, which wouldn’t be worth the speedup gained.

What would be the best place to track ongoing development on the TorchScript successor and provide input to the developers, both bug reports and feature requests?

Awesome!

What would be the best place to track ongoing development on the TorchScript successor and provide input to the developers, both bug reports and feature requests?

We don’t have the runtime available at the moment, but you can give feedback on the frontend bits (torch.export, the ExportedProgram serialized format). I’ll update this thread when we have something more concrete to share on the runtime/backend side.
