Lazy Tensor Core

Lazy tensors in PyTorch are an active area of exploration, and this is a call for community involvement to discuss the requirements, implementation, goals, etc.

We are looking for ways to bring compiler optimizations to a wider range of PyTorch programs than can be easily compiled via TorchScript, and to provide a better self-service path for accelerator vendors (especially for training) to integrate their own graph-based stacks. We take inspiration from pytorch/xla (on GitHub), which uses lazy tensors to target XLA, and, through the work of Alex Suhan and team, we are prototyping ways to generalize this and see if it can be a useful part of core PyTorch.

Some open areas of discussion right now that we'd be interested to hear your perspective on:

  • What is the right user facing API for a core lazy tensor?


Relevant links:
  • see RFC #18 for more details and discussion
  • see Alex Suhan's Lazy Tensor prototype, which also links to the XLA LTC backend


Any updates?

Since this thread has been surprisingly quiet, let me see if I can stir things up a bit:

What is the right user facing API for a core lazy tensor?

Ideally, this could be completely transparent to the end user and would enable a seamless JITing story.

Thanks for chiming in!

Seamless is definitely the 'gold standard' for the lazy tensor approach. There are a couple of areas where we're still working out how to achieve that. Here are some examples:

  • signaling the end of a training step or a point of dynamic control flow (and thus a good time to break the graph) vs. automatic logic to infer this: it seems possible to automate, but it is also quite simple, and maybe more reliable, to annotate manually (a sketch of the manual-annotation style follows this list)
  • dynamic shape issues: some of the existing PyTorch frontend APIs use vectors of ints for shapes. We're working on ways to capture the 'shape computations' users might do outside of these APIs, to avoid having static shape constants burned into traces through them
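For a concrete picture of the manual-annotation style, here is a rough sketch borrowing pytorch/xla's existing convention (the eventual core LTC API may look different):

```python
import torch
import torch_xla.core.xla_model as xm

# Borrowing pytorch/xla's convention: the user explicitly marks the end of a
# training step, telling the lazy backend this is a good point to cut, compile,
# and execute the graph accumulated so far.
device = xm.xla_device()
model = torch.nn.Linear(10, 10).to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(3):
    opt.zero_grad()
    loss = model(torch.randn(8, 10, device=device)).sum()
    loss.backward()
    opt.step()
    xm.mark_step()  # explicit step boundary; automatic logic would try to infer this point
```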

Yes, I can see a number of inconvenient details getting in the way of "completely transparent". Just to make sure I understand the first point - you're thinking of artificial checkpoints in between the concrete materialization points? What is the motivation? (Bounding the size of the expressions that must be compiled, or increasing trace reusability?)

Regarding the second point, it almost seems that "borrowing" from the existing torch.jit.trace would help. On this note, has anyone written down any notes comparing/contrasting lazy tensors with torch.jit.trace?

One limitation I see with lazy/deferred execution is that it delays the start of the work, potentially hurting performance - especially on hybrid configurations where an accelerator (e.g. a GPU) can run things in parallel with the host CPU. Any thoughts on this?

I have yet to find the time to explore this, but one idea I've been entertaining for a while is tracing JIT compilation for PyTorch - along the lines of traditional tracing JIT compilers (e.g. SpiderMonkey, LuaJIT, Dynamo, …). Lazy tensors could be a good building block; the additional piece would be the ability to detect trace contexts early and launch execution as soon as inputs are ready. This could be done with a mix of tracing JIT guards and speculative execution. Is anyone interested, or already working on something like this?

"We're working on ways to capture the 'shape computations' users might do outside of these APIs, to avoid having static shape constants burned into traces through these APIs"

In the lazy tensor target IR (I assume it is TorchScript at the moment?), will you have a mechanism to make shape computation explicit in TorchScript?

This is a good point. The tradeoff depends on the scope of the graph that is captured. In some programs (or with some accelerators) it may be preferable to capture small graphs, such as the operations that make up a layer, or even just groups of element-wise + reduction operations. In these cases, you can start work earlier but can't do as much optimization inside the compiler. Others may want to capture whole program graphs, in which case it'll be important to overlap capture (N+1) with execution (N) in the training loop. We can accommodate both of these modes in principle, but haven't committed to a particular user API for exposing the control just yet.

Hmm, what parts of torch.jit.trace do you have in mind to borrow? We have thought about exposing 'ahead of time' tracing for cases where the model is static and the tracing overhead is significant. There is another body of work where @Chillee is developing new 'eager compilation' tools that are more geared towards training (fwd+bwd graphs) than torch.jit.trace was. We're working together to see how to best align this with lazy tracing, and possibly share some of the internals.

What I was referring to here is being able to capture pure-Python manipulations of tensor shapes. Imagine someone does:

```python
shape = my_tensor.size()        # returned shape is not a lazy tensor
partial_shape = shape[:2]       # trivial math, but not captured by the lazy trace
y = torch.empty(partial_shape)  # lazy trace of new tensor 'y' treats partial_shape as a 'burned in' constant
```

We'd like to have a way to trace shape math in user-land too. So, we're thinking about ways to do this - possibly by making the object returned by lazytensor.size() into a 'lazy shape' object, or something like that.
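To make the 'lazy shape' idea a bit more concrete, here is a purely hypothetical sketch (none of these names exist in PyTorch); the point is only that slicing a shape would record a node in the trace rather than produce plain Python ints:

```python
# Hypothetical illustration only: a shape object that records operations
# instead of computing concrete ints, so shape math becomes part of the trace.
class LazyShape:
    def __init__(self, node):
        self.node = node  # IR node that will eventually produce the sizes

    def __getitem__(self, key):
        # Record the slice as another node instead of returning Python ints,
        # so something like `shape[:2]` stays symbolic in the lazy trace.
        return LazyShape(("slice", self.node, key))
```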

[quote="vinodg, post:5, topic:232, full:true"]
In the lazy tensor target IR (I assume it is TorchScript at the moment?), will you have a mechanism to make shape computation explicit in TorchScript?
[/quote]

Just to be clear on terminology, there are two IRs in play: the Lazy Tensor IR, which is really just ATen operations wrapped in node classes, and TorchScript IR, which is the existing IR used by the JIT compiler. We trace in Lazy IR, which is fast to construct and hash, then lower either to XLA HLO (for torch XLA) or convert to TorchScript IR for other backends. The first step to making dynamic shapes work across these backends is to capture the shape computations in the Lazy IR, so we can lower them accordingly.

[quote]
This is a good point. The tradeoff depends on the scope of the graph that is captured. In some programs (or with some accelerators) it may be preferable to capture small graphs, such as the operations that make up a layer, or even just groups of element-wise + reduction operations. In these cases, you can start work earlier but can't do as much optimization inside the compiler. Others may want to capture whole program graphs, in which case it'll be important to overlap capture (N+1) with execution (N) in the training loop. We can accommodate both of these modes in principle, but haven't committed to a particular user API for exposing the control just yet.
[/quote]

I agree, breaking down the traces to put a bound on JIT time and the deferral latency makes sense. The next step could be detecting early that we're about to execute a path we already traced/jitted and starting the execution early - this is what I meant by traditional tracing JIT compilation. What do you think?

[quote]
Hmm, what parts of torch.jit.trace do you have in mind to borrow? We have thought about exposing 'ahead of time' tracing for cases where the model is static and the tracing overhead is significant. There is another body of work where @Chillee is developing new 'eager compilation' tools that are more geared towards training (fwd+bwd graphs) than torch.jit.trace was. We're working together to see how to best align this with lazy tracing, and possibly share some of the internals.
[/quote]

The way I'm looking at this, at a very high level, is that both torch.jit.trace and lazy tensors do the same thing, in a slightly different way (capture and model a piece of computation). The difference I see today is that lazy tensors trace and execute at the same time, while torch.jit.trace separates tracing from executing. At the same time, torch.jit.trace (and TorchScript) can represent more generic constructs compared to lazy tensors, including shape computation in this case. Am I mistaken?

One area of convergence I see would be the ability to do a transparent torch.jit.trace - tracing while executing.

Actually, I don't have a good idea of how to build this. Hand-waving a bit, I could imagine building something more tightly integrated into CPython, but I have not been actively looking into this. I'd more generally say that I think it is a logical eventual direction, but probably not one we'll get to in one step.

So, torch.jit.script is the one that can capture control flow constructs, but the way it accomplishes this is by using a frontend that only supports a subset of Python/PyTorch programs, and that delta has been a huge burden for the JIT team. (Too hard to support it all, and too onerous to use for many users otherwise.)

torch.jit.trace, on the other hand, is totally free of syntactical burdens on the user since it traces execution instead of reading the AST of the program. But it burns in shape assumptions, and the resulting program you capture with it can't be (safely) used with input data of differing shapes from the data used to trace. Worryingly, it doesn't even know it is unsafe, and may return a wrong value if used in this way.
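A small example of that failure mode (shape-dependent Python control flow, which the tracer silently specializes away):

```python
import torch

def f(x):
    # Python-level branch on the shape: the tracer only records the path
    # taken for the example input; the shape test itself is not captured.
    if x.size(0) > 2:
        return x * 2
    return x - 1

traced = torch.jit.trace(f, torch.ones(4))  # records the `x * 2` branch
print(traced(torch.ones(2)))  # still runs x * 2; eager f() would return x - 1
# No error is raised: the traced program has no idea it is now unsafe.
```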


Doesn't torch.fx also have the same issues as torch.jit.script, in that it needs to be able to handle all of the Python AST that could appear in the forward method of nn modules?

How would Lazy IR handle control flow - isn't that also a tracing compiler?


That's where I am too - I don't have an answer which doesn't involve deep integration with the Python interpreter. And I don't think anyone is eager to reinvent PyPy and RPython either :slight_smile:

My plan is to keep an eye on Lazy Tensors and, if I ever get a chance to explore this further, maybe prototype something. If I get to that point, likely the first building block I'd look into is some form of "auto-tracing", a combination of torch.jit.trace and lazy tensors.


How do you plan to address the re-compilation challenge in torch-xla? The tensor shapes are encoded into vectors of integers and made available in Python. For ops with dynamically shaped tensor output, we have no guarantee the users won't take these Python integers and decide what to do next. For soundness' sake, we have to truncate and force execution of the LazyTensor IR graph. See https://github.com/pytorch/pytorch/issues/62320 for an example we see when capturing detection models, where tensor indexing and nonzero/masked_select are heavily used.
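For concreteness, the pattern is roughly this (plain eager code shown; under a LazyTensor backend the size() call is the point where the pending graph must be compiled and executed):

```python
import torch

def select_positive(x):
    idx = torch.nonzero(x > 0)   # output shape depends on the *values* in x
    # As soon as Python needs a concrete integer (not a lazy tensor), the
    # backend has to truncate the trace and run the graph to learn the shape.
    n = idx.size(0)
    print("matches:", n)
    return x[x > 0]
```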

And for control flow: see https://github.com/pytorch/xla/issues/3043#issuecomment-890285855 for a realistic BERT AMP training scenario. The dynamic loss scaling optimizer needs to check for nan/inf in the fp16 gradients in Python, stalling the pipeline between tracing step N+1 and executing step N.
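Roughly, the problematic step looks like this sketch (illustrative names, not the actual AMP scaler implementation): the .item() calls pull gradient values back into Python, so step N must finish executing before step N+1 can keep tracing.

```python
import torch

def step_with_dynamic_loss_scaling(optimizer, params, scale):
    # Illustrative sketch only: the Python-level inf/nan check forces the
    # lazy graph for this step to compile and execute before continuing.
    found_inf = any(
        (torch.isinf(p.grad).any() | torch.isnan(p.grad).any()).item()
        for p in params if p.grad is not None
    )
    if found_inf:
        optimizer.zero_grad()
        return scale * 0.5  # skip the update and shrink the loss scale
    optimizer.step()
    return scale
```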

I donā€™t want to speak for Will here, so just some personal thoughts.

The "recompilation challenge" is significant, and is arguably the most fundamental challenge in using any kind of trace-based solution (including LazyTensor) in practice. If you don't recompile, then you error out when encountering behavior that can't be traced (a la FX or Jax).

So, addressing recompilation is not easy :slight_smile: . But what can we do to improve the experience? To talk about that, it's useful to think about what causes recompilations, in order from easiest to hardest. The linked doc from Ailing also talks about these challenges.

Of those challenges, I think 1/2.1 (i.e. dynamic shapes without control flow/materialization of sizes) are the easiest to fix, and the ones that I think LTC already fixes (simply by virtue of not compiling to XLA lol).

In my opinion, 2.2 is hard to fix, but probably doable with a lot of effort :stuck_out_tongue: . There are two big problems that need to be solved there. The first is actually capturing what happens, which can be done either by tracing through things in Python (like FX) or by returning tensors for sizes. But even when that's done, we still need to lower this shape computation into a compiler, which is… non-trivial.
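As a small illustration of the FX flavor of "tracing through things in Python" (not LTC itself):

```python
import torch
import torch.fx

def f(x):
    b = x.size(0)            # under fx, this is a Proxy node, not a Python int
    return x.reshape(b, -1)

gm = torch.fx.symbolic_trace(f)
print(gm.graph)                      # size() and reshape() appear as graph nodes
print(gm(torch.randn(8, 3)).shape)   # generalizes: no batch-size constant baked in
```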

Problem 3 (control flow) is the most … difficult to address (or the easiest, depending on how you think about it). I have a lot of ideas for how to address this, but it seems difficult to balance UX decisions with actually tracing it in a manner conducive to optimizations.

Regardless, resolving any of these is likely to significantly address the recompilation problem. The other option, to be honest, is simply to rewrite your code in a static manner (which we could also provide the facilities to do). In the two examples you provided, I think it makes sense to (1) allow nonzero to return static shapes (also useful for vmap!), and (2) provide control flow ops (also useful for vmap!).
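As an example of the "rewrite your code in a static manner" option, a nonzero that always returns a fixed-size, padded result could look roughly like this (a hypothetical helper, not an existing PyTorch API):

```python
import torch

def nonzero_static(x, max_count, fill=-1):
    # Hypothetical helper (assumes a 1-D input for simplicity): indices are
    # padded/truncated to a fixed length so the output shape never depends on
    # the data; `fill` marks unused slots.
    idx = torch.nonzero(x).flatten()
    out = torch.full((max_count,), fill, dtype=idx.dtype, device=x.device)
    n = min(idx.numel(), max_count)
    out[:n] = idx[:n]
    return out
```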

TL;DR: It's a hard problem, we're thinking about it, and there are a lot of possible avenues. We welcome any suggestions/opinions :slight_smile:


Lazy execution has been a dream of mine for several years. Back then I had written a rather efficient Caffe-style framework in Java, and I dreamed about a Torch-style dynamic framework with the same performance. (Also I dreamed that it would be in Java, rather than the much more unwieldy C++.)

I've thought about the first question in this thread, of when to submit the AST for execution. I think it should mostly be at the end of a training loop. The reason is that there are a lot of whole-program optimization opportunities. For example, the model might compute a loss that is then not included in a training run; the entire subgraph could be automatically dropped.

But there are a few other reasons.

One is that I imagined the backward pass would be written differently than it is now. Now, it is static: it is a built-in process that the framework performs. But in fact, it can be dynamic and framework-agnostic. The optimizer (or optimizers) can request the gradient (or any other function) of whatever parameters it cares about and use it. The graph compiler would then be required to deduplicate the graph and get back to an efficient implementation.
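A minimal sketch of what "the optimizer requests the gradient of whatever parameters it cares about" could look like with today's torch.autograd.grad, so that the backward pass is just more nodes in the same graph:

```python
import torch

# Sketch: request gradients explicitly instead of relying on loss.backward(),
# so the backward computation is expressed as part of the same (lazy) graph.
model = torch.nn.Linear(10, 1)
x = torch.randn(32, 10)
loss = model(x).pow(2).mean()

params = list(model.parameters())
grads = torch.autograd.grad(loss, params)   # "give me d(loss)/d(params)"
with torch.no_grad():
    for p, g in zip(params, grads):
        p -= 0.1 * g                        # a hand-rolled SGD step
```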

A more fundamental reason is memory. A lot of memory needs to be allocated in the forward pass for use in the backward pass. To make the best planning decisions, the backward pass has to be in the graph. Moreover, the graph compiler could do automatic checkpointing (i.e., choose to recompute parts of the forward pass instead of spending memory).

To be honest, I didn't get around to studying XLA or TorchScript IR. Parts of this discussion seem to assume that the work can be offloaded there and not thought about further; I'm not yet sure how that would work. This is the case with burned-in vs. symbolic tensor shapes - I'm not sure how that affects things. If tensor shapes change, then wouldn't memory allocation have to be replanned no matter what? That is somewhat expensive (or, I suppose, arbitrarily expensive). The best thing to do is to track the maximal shapes so that training iterations can execute with a static memory layout. This also applies to computation with some of the more exotic chip architectures.

Hi, dumb question: Where can I find the current status/plan of lazy tensors?

We're releasing LTC as a 'prototype' feature in the 1.12 release of PyTorch. It includes the torchscript backend, mostly tested with cuda/nvfuser, and working with most torchbench models but with some failures/bugs. PyTorch/XLA and Cerebras/TorchMLIR are actively building backends on top of LTC.

Additionally, we are working on integration with torchdynamo to provide a lower-overhead capture for LTC backends. Dynamic shape tracing support is an ongoing area of work for both LTC and torchdynamo/AOTAutograd. Distributed training support is also ongoing: we have prototyped tracing of DDP comm ops with LTC, but haven't landed it and are still exploring the best design to support both LTC and dynamo with multiple distributed frontends.

Hope that helps, let me know if you have more specific questions.


Now that PyTorch 1.12 has been released, how do we go about using LazyTensors?

Can we, e.g., do a batch matrix multiply larger than GPU memory, so long as the reduction op results in an output smaller than GPU memory?

[quote]
Now that PyTorch 1.12 has been released, how do we go about using LazyTensors?
Can we, e.g., do a batch matrix multiply larger than GPU memory, so long as the reduction op results in an output smaller than GPU memory?
[/quote]

You can follow the tutorial to get familiar with it, but note that we are not investing much in the torchscript backend as we are targeting torch dynamo for this use case. LTC is being supported mainly for integration with external hw/compiler vendors such as XLA via [PyTorch XLA](https://github.com/pytorch/xla).
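For reference, the basic usage from the prototype tutorial looks roughly like this (1.12 prototype APIs, subject to change):

```python
import torch
import torch._lazy
import torch._lazy.ts_backend

torch._lazy.ts_backend.init()        # enable the TorchScript backend

x = torch.randn(4, 4).to('lazy')     # ops on the 'lazy' device are traced, not run
y = (x @ x).relu()
torch._lazy.mark_step()              # cut the trace: compile and execute the graph
print(y.cpu())                       # materialized result
```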

Re: the bmm question specifically, whether this optimization is realized depends on the backend you use, but the lazy tensor tracer itself does not allocate intermediate GPU memory until the compiled program is materialized, so it comes down to the choices made by the compiler.

Thanks @wconstab.

Do you know of a way to achieve this bmm capability presently, with any stack besides KeOps? Especially interested in a way to realize it with a Torch frontend.