Lazy Tensor Core

Lazy tensors in PyTorch are an active area of exploration, and this is a call for community involvement to discuss the requirements, implementation, goals, etc.

We are looking for ways to bring compiler optimizations to a wider range of PyTorch programs than can be easily compiled via TorchScript, and to provide a better self-service path for accelerator vendors (especially for training) to integrate their own graph-based stacks. We take inspiration from pytorch/xla on GitHub, which uses lazy tensors to target XLA, and, through the work of Alex Suhan and team, we are prototyping ways to generalize this approach and see whether it can be a useful part of core PyTorch.

Some open areas of discussion right now that we’d be interested to hear your perspective on:

  • What is the right user facing API for a core lazy tensor?

Relevant links:
  • See RFC #18 for more details and discussion.
  • See Alex Suhan’s Lazy Tensor prototype, which also links to the XLA LTC backend.


Any updates?

Since this thread has been surprisingly quiet, let me see if I can stir things up a bit:

What is the right user facing API for a core lazy tensor?

Ideally, this could be completely transparent to the end user and would enable a seamless JITing story.

Thanks for chiming in!

Seamless is definitely the ‘gold standard’ for the lazy tensor approach. There are a couple of areas where we’re still working out how to achieve that. Here are some examples:

  • Signaling the end of a training step or a point of dynamic control flow (and thus a good time to break the graph), versus automatic logic to infer this: it seems possible to automate, but manual annotation is quite simple and may be more reliable.
  • Dynamic shape issues: some of the existing PyTorch frontend APIs use vectors of ints for shapes. We’re working on ways to capture the ‘shape computations’ users might do outside of these APIs, to avoid having static shape constants burned into traces through these APIs.

Yes, I can see a number of inconvenient details getting in the way of “completely transparent”. Just to make sure I understand the first point - you’re thinking of artificial checkpoints in between the concrete materialization points? What is the motivation? (Bounding the size of the expressions which must be compiled, or increasing trace reusability?)

Regarding the second point, it almost seems that “borrowing” from the existing torch.jit.trace would help. On this note, has anyone written down any notes comparing/contrasting lazy tensors with torch.jit.trace?

One limitation I see with lazy/deferred execution is that it delays the start of the work, potentially hurting performance - especially on hybrid configurations where an accelerator (e.g. a GPU) can run in parallel with the host CPU. Any thoughts on this?

I have yet to find the time to explore this, but one idea I’ve been entertaining for a while is a tracing JIT compilation for PyTorch - along the lines of traditional tracing JIT compilation (ex. SpiderMonkey, LuaJit, Dynamo, …). Lazy tensors could be a good building block, the additional piece would be the ability to detect trace contexts early and launch execution as soon as inputs are ready. This could be done with a mix of tracing jit guards and speculative execution. Is anyone interested, or already working on something like this?

“We’re working on ways to capture the ‘shape computations’ users might do outside of these APIs, to avoid having static shape constants burned into traces through these APIs”

In the lazy tensor target IR (I assume it is TensorScript at the moment?), will you have a mechanism to make shape computation explicit in TensorScript?

This is a good point. The tradeoff depends on the scope of the graph that is captured. In some programs (or with some accelerators) it may be preferable to capture small graphs, such as the operations that make up a layer, or even just groups of element-wise + reduction operations. In these cases, you can start work earlier but can’t do as much optimization inside the compiler. Others may want to capture whole program graphs, in which case it’ll be important to overlap capture (N+1) with execution (N) in the training loop. We can accommodate both of these modes in principle, but haven’t committed to a particular user API for exposing the control just yet.
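The overlap of capture (N+1) with execution (N) can be sketched with a toy pipeline, using a single background thread as a stand-in for an accelerator. The `capture_graph`/`execute_graph` functions are invented placeholders for illustration, not LTC internals:

```python
# Toy sketch of overlapping graph capture for step N+1 with execution of
# step N, using a background thread as a stand-in for an accelerator.
from concurrent.futures import ThreadPoolExecutor

def capture_graph(step):
    # Stand-in for lazy tracing: record which ops the step would run.
    return [("matmul", step), ("relu", step)]

def execute_graph(graph):
    # Stand-in for compiled execution on the device.
    return sum(arg for _, arg in graph)

results = []
with ThreadPoolExecutor(max_workers=1) as device:
    pending = None  # future for the step currently executing
    for step in range(4):
        graph = capture_graph(step)           # trace step N+1 on the host...
        if pending is not None:
            results.append(pending.result())  # ...while step N runs "on device"
        pending = device.submit(execute_graph, graph)
    results.append(pending.result())

assert results == [0, 2, 4, 6]
```

The key property is that the host never waits for step N before it starts tracing step N+1; it only waits when it needs step N's results.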

Hmm, what parts of torch.jit.trace do you have in mind to borrow? We have thought about exposing ‘ahead of time’ tracing for cases where the model is static and the tracing overhead is significant. There is another body of work where @Chillee is developing new ‘eager compilation’ tools that are more geared towards training (fwd+bwd graphs) than torch.jit.trace was. We’re working together to see how to best align this with lazy tracing, and possibly share some of the internals.

What I was referring to here is being able to capture pure-Python manipulations of tensor shapes. Imagine someone does:

```python
shape = my_tensor.size()        # returned shape is not a lazy tensor
partial_shape = shape[:2]       # trivial math, but not captured by the lazy trace
y = torch.empty(partial_shape)  # lazy trace of new tensor 'y' treats partial_shape as a burned-in constant
```

We’d like to have a way to trace shape math in user-land too. So, we’re thinking about ways to do this - possibly by making the object returned by lazytensor.size() into a ‘lazy shape’ object, or something like that.
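One possible shape of that ‘lazy shape’ idea, sketched as a toy (class and method names are invented for illustration, not a proposed API): `size()` returns an object that records shape manipulations into the trace instead of handing back plain Python ints, so slicing like `shape[:2]` stays visible to the IR.

```python
# Sketch of the 'lazy shape' idea: size() returns an object that records
# shape manipulations instead of plain Python ints. All names invented.

class LazyShape:
    def __init__(self, dims, trace):
        self.dims = list(dims)
        self.trace = trace          # shared op log standing in for the IR

    def __getitem__(self, idx):
        # Record the access as an IR node instead of burning in constants.
        self.trace.append(("shape_getitem", idx))
        if isinstance(idx, slice):
            return LazyShape(self.dims[idx], self.trace)
        return self.dims[idx]

class LazyTensor:
    def __init__(self, dims, trace):
        self.dims = list(dims)
        self.trace = trace

    def size(self):
        self.trace.append(("size",))
        return LazyShape(self.dims, self.trace)

trace = []
t = LazyTensor([8, 16, 32], trace)
partial = t.size()[:2]   # captured in the trace, not a burned-in constant
assert partial.dims == [8, 16]
assert trace == [("size",), ("shape_getitem", slice(None, 2))]
```

The hard part, of course, is that user code expecting real `int`s (e.g. `range(shape[0])`) would still force materialization; a lazy shape object only covers the arithmetic it can intercept.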

[quote=“vinodg, post:5, topic:232, full:true”]
In the lazy tensor target IR (I assume it is TensorScript at the moment?), will you have a mechanism to make shape computation explicit in TensorScript?
[/quote]

Just to be clear on terminology, there are 2 IRs in play- The Lazy Tensor IR, which is really just ATen operations wrapped in node classes- and TorchScript IR, which is the existing IR used by the JIT compiler. We trace in Lazy IR, which is fast to construct and hash, then lower either to XLA HLO (for torch XLA), or convert to TorchScript IR for other backends. The first step to making dynamic shapes work across these backends is to capture the shape computations in the Lazy IR, so we can lower it accordingly.
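The “fast to construct and hash” property can be illustrated with a toy version of the flow described above: trace ops into lightweight IR nodes, hash the graph structurally, and reuse a cached lowering when the same graph shows up again. All names here are invented for the sketch:

```python
# Toy sketch of the two-level flow: trace into lightweight IR nodes,
# hash the graph, and reuse a cached lowering on a repeat trace.
# Names are invented for illustration, not real LTC internals.

class Node:
    def __init__(self, op, inputs):
        self.op = op
        self.inputs = inputs  # other Nodes or leaf descriptors (strings here)

    def structural_hash(self):
        child = tuple(
            i.structural_hash() if isinstance(i, Node) else i
            for i in self.inputs
        )
        return hash((self.op, child))

lowering_cache = {}

def lower(root):
    """Pretend 'lowering': compile once per distinct graph structure."""
    key = root.structural_hash()
    if key not in lowering_cache:
        # A real backend would emit XLA HLO or TorchScript IR here.
        lowering_cache[key] = f"compiled<{root.op}>"
    return lowering_cache[key]

g1 = Node("add", [Node("mul", ["x", "y"]), "z"])
g2 = Node("add", [Node("mul", ["x", "y"]), "z"])  # same structure, retraced
lower(g1)
lower(g2)
assert len(lowering_cache) == 1  # the second trace hits the cache
```

This is why dynamic shapes matter so much: if shapes are burned into the nodes, two traces that differ only in a size hash differently and miss the cache, triggering recompilation.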

[quote]
This is a good point. The tradeoff depends on the scope of the graph that is captured. In some programs (or with some accelerators) it may be preferable to capture small graphs, such as the operations that make up a layer, or even just groups of element-wise + reduction operations. In these cases, you can start work earlier but can’t do as much optimization inside the compiler. Others may want to capture whole program graphs, in which case it’ll be important to overlap capture (N+1) with execution (N) in the training loop. We can accommodate both of these modes in principle, but haven’t committed to a particular user API for exposing the control just yet.
[/quote]

I agree - breaking down the traces to bound JIT time and deferral latency makes sense. The next step could be detecting early that we’re about to execute a path we already traced/JITted and starting execution early - this is what I meant by traditional tracing JIT compilation. What do you think?

[quote]
Hmm, what parts of torch.jit.trace do you have in mind to borrow? We have thought about exposing ‘ahead of time’ tracing for cases where the model is static and the tracing overhead is significant. There is another body of work where @Chillee is developing new ‘eager compilation’ tools that are more geared towards training (fwd+bwd graphs) than torch.jit.trace was. We’re working together to see how to best align this with lazy tracing, and possibly share some of the internals.
[/quote]

The way I’m looking at this, at a very high level, is that both torch.jit.trace and lazy tensors do the same thing in slightly different ways (capture and model a piece of computation). The difference I see today is that lazy tensors trace and execute at the same time, while torch.jit.trace separates tracing from execution. At the same time, torch.jit.trace (and TorchScript) can represent more generic constructs compared to lazy tensors, including shape computation in this case. Am I mistaken?

One area of convergence I see would be the ability to do a transparent torch.jit.trace - tracing while executing.

Actually, I don’t have a good idea of how to build this. Hand-waving a bit, I could imagine building something more tightly-integrated into CPython, but, I have not been actively looking into this. I’d more generally say that I think it is a logical eventual direction, but probably not one we’ll make it to in one step.

So, torch.jit.script is the one that can capture control flow constructs, but the way it accomplishes this is by using a frontend that only supports a subset of Python/PyTorch programs, and that delta has been a huge burden for the JIT team (too hard to support it all, and too onerous for many users otherwise).

torch.jit.trace, on the other hand, is totally free of syntactical burdens on the user, since it traces execution instead of reading the AST of the program. But it burns in shape assumptions, and the resulting captured program can’t be (safely) used with input data of shapes differing from the data used to trace. Worryingly, it doesn’t even know it is unsafe, and may silently return a wrong value if used in this way.
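The shape-burning hazard can be demonstrated with a tiny hand-rolled tracer rather than torch.jit.trace itself (the `toy_trace` helper below is invented for this sketch): any value derived from the example input's shape becomes a constant in the captured program, and a mismatched input produces a wrong answer with no error.

```python
# Toy demonstration of the shape-burning hazard, using a hand-rolled
# tracer as a stand-in for torch.jit.trace. 'toy_trace' is invented.

def mean_over_batch(xs):
    return sum(xs) / len(xs)  # len(xs) is plain Python: a tracer sees a constant

def toy_trace(fn, example_input):
    """Capture fn with the example's length burned in, like a shape constant."""
    n = len(example_input)         # burned-in "shape"
    return lambda xs: sum(xs) / n  # silently wrong for other lengths

traced = toy_trace(mean_over_batch, [1.0, 2.0, 3.0, 4.0])
assert traced([2.0, 4.0, 6.0, 8.0]) == 5.0  # same length: correct
assert traced([3.0, 3.0]) != 3.0            # different length: wrong, no error
```

The lazy tensor approach sidesteps this particular failure mode by re-tracing every step, at the cost of re-tracing (and potentially recompiling) every step.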


Doesn’t torch.fx also have the same issues as torch.jit.script, in that it needs to be able to handle all of the Python AST that could appear in the forward method of nn modules?

How would Lazy IR handle control flow - isn’t that also a tracing compiler?


That’s where I am too - I don’t have an answer which doesn’t involve deep integration with the Python interpreter. And I don’t think anyone is eager to reinvent PyPy and RPython either :slight_smile:

My plan is to keep an eye on Lazy Tensors and if I ever get a chance to explore this further, maybe prototype something. If I get to that point, likely the first building block I’d look into is some form of “auto-tracing”, a combination of torch.jit.trace and lazy tensors.


How do you plan to address the re-compilation challenge in torch-xla? Tensor shapes are encoded into vectors of integers and made available in Python. For ops with dynamically shaped tensor outputs, we have no guarantee that users won’t take these Python integers and use them to decide what to do next. For soundness’ sake, we have to truncate and force execution of the LazyTensor IR graph. See https://github.com/pytorch/pytorch/issues/62320 for an example we see when capturing detection models, where tensor indexing and nonzero/masked_select are heavily used.

And for control flow: see https://github.com/pytorch/xla/issues/3043#issuecomment-890285855 for a realistic BERT AMP training scenario. The dynamic loss scaling optimizer needs to check for nan/inf in the fp16 gradients in Python, stalling the pipeline: tracing of step N+1 cannot overlap with execution of step N.
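The stall can be sketched abstractly like this (all names invented for illustration; the inf check stands in for the dynamic loss-scale check): reading a traced value from Python forces the pending graph to execute before tracing can continue, so the trace cannot extend past that point.

```python
# Toy sketch of a Python-side value check forcing a sync point in a lazy
# trace. 'lazy_op', 'item', and 'pending_graph' are invented names.
import math

pending_graph = []

def lazy_op(name, fn):
    pending_graph.append(name)  # record the op instead of running it
    return fn

def item(thunk):
    """Reading a value into Python flushes everything traced so far."""
    flushed = len(pending_graph)
    pending_graph.clear()       # the backend must run the graph now
    return thunk(), flushed

grad = lazy_op("scale_grad", lambda: float("inf"))
value, flushed_ops = item(grad)  # Python-side nan/inf check: forced sync
assert math.isinf(value)
assert flushed_ops == 1          # the trace could not extend past this point
```

A device-side conditional (a control flow op in the IR) would avoid the round trip, which is one motivation for the control flow ops mentioned later in the thread.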

I don’t want to speak for Will here, so just some personal thoughts.

The “recompilation challenge” is significant, and is arguably the most fundamental challenge in using any kind of trace-based solution (including LazyTensor) in practice. If you don’t recompile, then you error out when encountering behavior that can’t be traced (a la FX or Jax).

So, addressing recompilation is not easy :slight_smile: . But, what can we do to improve the experience? I think to talk about that, it’s useful to think about what causes recompilations, in order from easiest to hardest. The linked doc from Ailing also talks about these challenges.

Of those challenges, I think 1/2.1 (i.e. dynamic shapes without control flow/materialization of sizes) are the easiest to fix, and the ones that I think LTC already fixes (simply by virtue of not compiling to XLA lol).

In my opinion, 2.2 is hard to fix, but probably doable with a lot of effort :stuck_out_tongue: . There are 2 big problems that need to be solved there. The first is actually capturing what happens, which can be done either by tracing through things in Python (like FX) or by returning Tensors for sizes. But even when that’s done, we still need to lower this shape computation into a compiler, which is… non-trivial.

Problem 3 (control flow) is the most … difficult to address (or the easiest, depending on how you think about it). I have a lot of ideas for how to address this, but it seems difficult to balance UX decisions with actually tracing it in a manner conducive to optimizations.

Regardless, resolving any of these is likely to significantly address the recompilation problem. The other option, to be honest, is simply to rewrite your code in a static manner (which we could also provide the facilities to do). For the 2 examples you provided, I think it makes sense to 1. allow nonzero to return static shapes (also useful for vmap!), and 2. provide control flow ops (also useful for vmap!).
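The ‘nonzero with a static output shape’ idea can be sketched in plain Python (the `static_nonzero` function is an invented illustration, not the proposed API): pad the result to the input's maximum possible size and return a separate count of valid entries, so every downstream shape stays static.

```python
# Sketch of 'nonzero with static output shape': pad the index list to the
# input's length plus a valid-entry count. 'static_nonzero' is invented.

def static_nonzero(values):
    """Return (indices padded to len(values), number of valid entries)."""
    idx = [i for i, v in enumerate(values) if v != 0]
    count = len(idx)
    idx += [-1] * (len(values) - count)  # pad with a sentinel index
    return idx, count

indices, count = static_nonzero([0, 5, 0, 7])
assert indices == [1, 3, -1, -1]  # static length == len(input)
assert count == 2
```

The trade is extra memory and masked/wasted work for a graph whose shapes never change, which is exactly what keeps the trace cache from thrashing.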

TL;DR: It’s a hard problem, we’re thinking about it, and there’s a lot of possible avenues. We welcome any suggestions/opinions :slight_smile: