State of symbolic shapes branch

State of PT2: Mar 2, 2024 edition

  • Jeffrey and Nested Int (https://docs.google.com/document/d/1-CR1hP1d-rJgQzCnIRLrwfpm4RbGhXYCesCSnnHZ6r0/edit?usp=sharing). The design process has definitely been a bit of a slog, but we had a lot of forward progress: (1) our compromise between “cpu and cuda nested ints with equivalent data should compare equal” and “nested int API should not have implicit syncs or expensive compute” is that we will report true/false only if we know, without syncing, that two tensors are equal/non-equal, and otherwise we will raise an error (and then there’s an API for explicitly doing compute / trust me these are equal), (2) we should have a separate subsystem responsible for doing this on immutable tensors which is independent from nested ints, (3) we do need a concept of symbolic variable that ranges over nested ints. Some other aspects of the original PR can be replaced with more memoization of nested ints on fake tensor. According to Jeffrey: “put something together that passes preexisting NT tests and its looking a lot cleaner.”
  • Horace, Brian, Richard and I talked about Retracing Backward (Retrace the backward in AOTAutograd - Google Docs). This is the problem where AOTAutograd wants to trace the backwards of a graph ahead of time, but it doesn't know what memory layout / subclass the tangents will be ahead of time. In the end, we came up with three main proposals. (1) Easy to implement solution: the user just somehow tells us the layout of the tangents, we use it in AOTAutograd, everything stays the same. The main problem is the UI for how to actually do this: a manual API is difficult because it's not obvious what the outputs are (graph breaks). Could imagine some sort of machine-only format like PGO. (2) Trace the graph post-autograd but pre-subclass, and run the partitioner on this graph before desugaring subclasses. Kind of invasive: although Horace proposed it, he doesn't really like it. (3) Trace the tangent-dependent backwards graph pre-dispatch, desugar everything else. This proposal relies on a key property, which is that the partitioner isn't actually partitioning the entire graph: we can think of the graph as consisting of three parts: the compute for forward outputs, the compute that depends on tangents, and a no-man's land of intermediate compute. The partitioner chooses what forward compute to recompute, and whether or not to put intermediate compute in forward or backward, but cannot touch the compute that depends on tangents. In the absence of TorchDispatchMode, only the tangent-dependent compute can meaningfully change based on tangents. So you can still run the partitioner in full fidelity, and you just need to arrange for the tangent-dependent portion of the graph to be untraced (e.g., pre-dispatch tracing) to the degree that you can retrace it later when the tangents show up. This will be pretty difficult to implement, but we agree conceptually it should work, and it gives the best UX.
  • Avik and I knocked heads about what to do about everyone continuously stubbing their toe on guard on data-dependent shapes errors (psst, read Dealing with GuardOnDataDependentSymNode errors - Google Docs). Avik's big idea is that it would be a lot easier to fix models for these errors if we could closely correlate symbols to actual variable names in source code. This is not so bad in Dynamo, but it's a bit of a lift in non-strict export since we don't have enough hooks in Python to engineer this without a lot of black magic. But we could make it easier for people to probe what size variable any given source variable corresponds to. We talked about some recent problems people ran into: sometimes people are running into lazily evaluated expressions; my opinion is that this was a mistake and we should just guard instead. We also were able to follow up on some of the problems with some fixes.
  • Brian’s been busy!
    • In DTensor land, people are worrying about communication reordering passes, like changing a collective-select into select-collective. Must guarantee every node does the same thing. Graphs on different ranks are not guaranteed to be the same! Design is needed here.
    • Even in per-parameter FSDP (aka FSDP2), we are still using the storage resize to zero trick. Jason Ansel is adding storage resizing as a native concept to inductor, so it can be run as early as possible. The resize seems to be necessary, esp in partial graph cases, so gonna have to deal with it.
    • FSDP2 is facing an interesting input aliasing problem. We receive the sharded, unsharded and unsharded padded parameters all as arguments, and unsharded aliases unsharded padded (via an as_strided call). Then, when we copy_ the result of allgather into the unsharded view, this results in an as_strided_scatter (because it writes into the unsharded padded buffer). But if these are contiguous, we ought to be a lot faster. (A small sketch of this aliasing pattern is at the end of this post.)
  • Composability sync https://www.youtube.com/watch?v=ACR1WnRScCc
    • User defined Triton v custom ops: the resolution was that we will "release" user defined Triton kernels in PT2, but we will encourage people to go straight to the custom ops API when writing new code. (However, Richard is also investigating how to get rid of some of the main pain points people have about working with custom ops, more below.)
    • C++ FX: not gonna happen, we’ll probably do other stuff.
  • Down the grapevine from Michael Suo: we’re probably going to care about non-Intel CPU performance in PT2 soon… Meta only: https://docs.google.com/document/d/1m_aQaMFF6T62z2kH8yaS9SIuVijHUR7vNn6DutSU1Xo
  • Greg (channeling Soumith) asked me a big question that is worth thinking about in 2024: how can we make PyTorch less complex? (Alternately, how can we make development in PyTorch go faster?)
  • Richard has been trying to figure out how to get rid of the schema requirement for custom ops, and he pitched me a pretty interesting idea: what if we FX trace the inner implementation of an operator, stubbing out the (simpler) bits where we actually write to data pointers, and then use that graph to get the schema? You can potentially use this to get the meta function for free (since you've just "commented out" the actual compute). See the sketch at the end of this post.
  • Sparse arch enablement is getting to the point where we need a spreadsheet of doom. Edward to organize at least the first version of it. A snapshot of the enablement issues recently: unbacked SymInts related to guard size oblivious, a Dynamo list append regression, adjust_info_B_num_bits shenanigans.
  • Flavio’s baby is coming soon, I’ve been encouraging him to focus less on E2E enablement, and just landing stuff to main that we know is needed.
  • Some communication debugging with the new PL&R team: focus on time to first batch and megawatts saved metrics; metrics as a means to an end of improving compile time. Order of operations doesn’t matter too much, do the right thing. Cross-org comms is hard.
  • Animesh close to done with C++ guards. Not a clear way to measure the benefit. Some discussion about a new problem where because we don’t inline builtin torch nn modules, we guard on their ID, which leads to spurious recompiling when torch.compile is on smaller pieces of model rather than whole model. Best solution is to trace through (this helps Jack Cao too.)
  • Export is speccing out a “torch function” granularity operator stack info, similar to the nn module stack info, which you can use to unflatten a linear call. In non-strict export can be implemented as TorchFunctionMode, but needs Dynamo to understand where torch function interposition points are.
  • Structured logs landed! Set TORCH_TRACE and use tlparse to parse the result into an HTML report. Still MVP, lots of improvement possible.
  • Notable new bugs in dynamic shapes:
  • Landed stuff
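
As promised in the FSDP2 aliasing bullet above, here is a minimal sketch of that pattern with made-up sizes: the unsharded parameter is an as_strided view of the padded buffer, so functionalizing the copy_ of the allgather output turns into an as_strided_scatter on the padded base.

```python
import torch

# Hypothetical sizes: a padded flat buffer and an unpadded view aliasing it.
unsharded_padded = torch.zeros(16)
unsharded = unsharded_padded.as_strided((12,), (1,))  # aliases the first 12 elements

allgather_out = torch.randn(12)

# Eager mutation: write the allgather result into the view...
unsharded.copy_(allgather_out)

# ...which functionalization has to express as a scatter into the base buffer,
# roughly equivalent to:
unsharded_padded_updated = torch.as_strided_scatter(
    unsharded_padded, allgather_out, size=(12,), stride=(1,)
)
```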
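
And a hedged sketch related to Richard's schema idea above: it doesn't show the FX tracing itself, but it illustrates the "comment out the actual compute" part, since running an implementation whose data-pointer writes are stubbed out under fake tensors already gives you meta-function behavior. my_custom_op_impl is invented for illustration.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

def my_custom_op_impl(x, y):
    # shape/dtype logic we'd like to reuse for the meta function
    out = x.new_empty(x.shape[0], y.shape[1])
    # ... the real kernel would write through out.data_ptr() here; that's the
    # part that gets stubbed out / "commented" away ...
    return out

# Running the implementation on fake tensors exercises only the metadata logic,
# which is exactly what a meta function needs to produce.
with FakeTensorMode():
    x = torch.empty(8, 4)
    y = torch.empty(4, 16)
    out = my_custom_op_impl(x, y)
    print(out.shape)  # torch.Size([8, 16]), computed without touching any data
```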

Are the design docs presented during PyTorch composability syncs available? Mainly interested in the PT2 dequant docs that @Chillee presented during the last meeting.

How can I learn more about the bullet

  • Export is speccing out a “torch function” granularity operator stack info, similar to the nn module stack info, which you can use to unflatten a linear call. In non-strict export can be implemented as TorchFunctionMode, but needs Dynamo to understand where torch function interposition points are.

@jeromeku Meeting notes are at Composability meeting notes - Google Docs ; sometimes people forget to make the docs public though. (Do you mean last meeting though? Maybe you’re referring to Feb 8?)

@thiagocrepaldi It’s this doc Rationalizing nn_module_stack and source_fn_stack - Google Docs


@ezyang @Chillee
Thanks! Found the aten.dequant doc in the meeting notes but it is not public.

State of PT2: Mar 7, 2024

Compiler and distributed offsite!

  • Benchmarking: Difficult to identify regressions that affect only a single benchmark, as oncalls only look at the accumulated statistics. Generally, Horace doesn't trust our benchmarks. Some ideas for making it better: measure only the CUDA kernel times to reduce variance (good for non overhead bound benchmarks); add synthetic benchmarks to test the overhead of specific subsystems, running them over and over in a loop microbenchmark style (it was important to jansel that we get a representative set of guards over all the benchmarks though); compute the hypothetical peak speed and compare against that instead of eager (best for operator level, or hard code a specific number to compare against). A bit difficult to prototype new HUD ideas, because no one is actually working on HUD.
  • Dynamo warm start: goal here is to deal with some use cases like (1) your training job was killed and you want to restart it without having to recompile, (2) you have 10k nodes and you don’t want to compile 10k times, (3) you have a custom operator that you want to run fast with PT2, (4) you are working on a model locally and you don’t want to wait to compile every time. Here’s the plan:
    • First, we need to make the compiled output of AOTAutograd serializable (maybe Sam Larsen will do this; it's a logical follow-up from making Inductor's output serializable). This gives you: run Dynamo every time (make this faster!) and then cache hit at AOTAutograd, 100% correct.
    • Next, we add a YOLO cache option, where we assume you didn't change any source code and we only check a limited set of easy-to-serialize guards (probably shape guards) to decide if a cache entry is OK. You have to explicitly ask to load from the cache, and you're expected to manually do cache invalidation yourself. Need to implement side effects like installing globals from Dynamo.
    • Finally (and we can choose not to do this in the end), we can try to do robust, always-on Dynamo caching. Combination of (1) build-system style checking whether the Python files Dynamo traced over have changed, by hashing them (see the sketch at the end of this post), and (2) Dynamo guard serialization (e.g., need to deal with ID_MATCH).
  • We had a User Empathy Day, where a bunch of us took popular OSS models and tried to torch.compile them. For some models, they worked and had speedups, but a lot of models failed. We filed a lot of bugs. Some more details: User Empathy Day - Google Docs
  • Some discussion about DTensor. Pretty interesting work from the external community: [DTensor] Open Source by leonardo0lyj · Pull Request #8 · volcengine/veScale · GitHub. There is some interest in moving the main class of DTensor to C++ to make it possible to use in eager mode without PT2. Wanchao is interested in an MoE use case for DTensor where you have irregular sharding; the theory is that this should be a subclass of the sharding spec, but this needs more investigation.
  • Horace has a proposal called “hierarchical compilation”. The motivating problem is when we have a loop over a basic block, we inline everything and uselessly recompile the same block over and over again. Instead, we want to compile it once and reuse it for all the calls. What makes this difficult is there might be Python state update, e.g., updating a loop counter or appending to a list, which you still need to apply in the outer loop. So intuitively, the idea is to recursively invoke Dynamo on the inner block, producing an opaque graph and some residual bytecode. Then, in the outer Dynamo session, you trace through this graph and bytecode however many times in the loop to update your state accordingly, and you keep going. The primary implementation complication is that at the time you do the inner Dynamo session, you need to directly use the outer Dynamo’s VariableTracker state, because there may have been updates to the Python state that you need to see, while at the same time not actually applying your inner updates to the outer variables until the very end.
  • Francisco presented simple FSDP, a very simple implementation of FSDP using only a custom autograd function for allgather, and a parametrization to allgather parameters before they are used (a rough sketch is at the end of this post). In combination with selective activation checkpointing and a small 150 LOC FSDP-specific optimization pass, they achieve memory usage and performance at parity with the complicated eager implementation. Assuming that you can deal with the compiler stack, this is a great starting point for more complicated userland ideas for scaling.
  • This spawned a discussion about a bigger tension that showed up in several contexts: hedging eager vs all-in compile. Right now, we have a compromise strategy, where we are working on torch.compile, but very much as something you put on top of existing working eager code. Francisco shows us that if you give up on eager mode performance and go all in on PT2, this is a really interesting point on the pareto curve. But on the other hand, we have projects like DTensor and FP8 where the folks we are engaging with cannot use compile, the eager perf needs to be good, and that means we need to write things in C++. I think we're still going to be doing the compromise strategy, but it is definitely worth continuing to ramp up docs and understanding about "green field PT2", as it can be the right choice when there is a champion for PT2 on a project that really benefits from a compiler.
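
A tiny sketch of the build-system style check from the warm start bullet (the helper name is made up): hash the Python sources Dynamo traced over and fold that into the cache key, so any source change invalidates the cached entry.

```python
import hashlib

def traced_sources_key(filenames):
    """Hypothetical helper: a stable digest over the Python files Dynamo traced.
    If any file changes, the key changes and the cached artifact is skipped."""
    h = hashlib.sha256()
    for fn in sorted(filenames):
        h.update(fn.encode())
        with open(fn, "rb") as f:
            h.update(f.read())
    return h.hexdigest()

# e.g. cache_key = (traced_sources_key(dynamo_traced_files), serialized_shape_guards)
```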
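
And a heavily simplified sketch of the simple FSDP idea (this is not Francisco's actual code): a custom autograd function that all-gathers the shard in forward and reduce-scatters the gradient in backward, hooked up via a parametrization so the full weight is materialized whenever the parameter is used.

```python
import torch
import torch.distributed as dist
from torch import nn

class AllGatherParam(torch.autograd.Function):
    @staticmethod
    def forward(ctx, shard):
        world = dist.get_world_size()
        full = shard.new_empty(world * shard.shape[0], *shard.shape[1:])
        dist.all_gather_into_tensor(full, shard)  # unshard for use in forward
        return full

    @staticmethod
    def backward(ctx, grad_full):
        world = dist.get_world_size()
        grad_shard = grad_full.new_empty(grad_full.shape[0] // world, *grad_full.shape[1:])
        dist.reduce_scatter_tensor(grad_shard, grad_full.contiguous())  # shard the grad
        return grad_shard

class Unshard(nn.Module):
    # parametrization: the stored parameter is the local shard, usage sees the full weight
    def forward(self, shard):
        return AllGatherParam.apply(shard)

# usage, assuming linear.weight has already been replaced by its local shard:
# nn.utils.parametrize.register_parametrization(linear, "weight", Unshard(), unsafe=True)
```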

State of PT2: Mar 16, 2024

Core offsite!


State of PT2: Mar 24, 2024

Back to work after two weeks of traveling.

  • We’re going to let PyTorch developers add calls to OpenTelemetry C++ API from inside the PyTorch codebase, to make it easier to add library level instrumentation that you might be interested in when running large training jobs. By default, OSS distribution of PyTorch will no-op all of these telemetry calls, but if you rebuild PyTorch from source you can point them to your favorite observability platform. We’ll also make calls from Python possible once we figure out the dependency situation. Meta only: Redirecting...
  • The AO folks were asking about an NF4 dtype. NF4 is a quantization scheme with 16 values which are normally rather than uniformly distributed around zero. The ask was whether or not it should get a dedicated dtype (like torch.bfloat16) in core. Unlike FP8, it doesn't have direct support in silicon, so it fails the test for core support put out in Supporting new dtypes in PyTorch - Google Docs. However, @cpuhrsch pointed out that not having an actual dtype for the type poses some UX problems for a tensor subclass implementing it. In particular, what does the dtype field of this subclass report? Right now (modeled off of FP8), this dtype field reports what floating point type the quantization scheme is seeking to simulate, typically bfloat16, even though it's actually fp8/nf4 under the hood. But this is very awkward: if I say fp8_tensor.to(torch.bfloat16), this will no-op, because Tensor.to looks at the dtype field before deciding to do the conversion, thinks "oh yes, this is already bfloat16" (because that's what it was lying about), and does nothing. (A minimal sketch of this dtype lie is at the end of this post.) An added complication to the problem is in autograd, where we expect the precision of gradients to match the precision of primals, to the point where autograd will automatically insert conversions to make things match up. So in fact, for a class like FP8Tensor, you want to lie that your dtype is the dtype of your gradient so the autograd engine doesn't reduce the precision of your gradients (which must typically be in higher precision than FP8). You could imagine an extremely BC-breaking change to the autograd engine to not require dtype matching, but this doesn't absolve you of the problem of determining what the precision of your gradients should be… and knowing what the precision of the primals was is very useful information.
  • We discussed hierarchical compilation at composability sync. Check out the very minimal design doc at Hierarchical compilation - Google Docs @anijain2305 and @laitho90 are working on phase 1.
  • William Wen talked about progress to Python 3.12 support at Dynamo team meeting. It seems changes to the eval frame API are the biggest blockers right now. Actually, once we finish 3.12, it’s time to immediately start working on 3.13, which is going beta soon
  • We’ve been working on getting a grip on all of the issues in the GitHub issues tracker. Apologies for any notification spam (and rip my inbox LOL.)
  • Jack Cao's still been working on inlining into NN modules for tracing backwards as a single graph, and one thing he's noticed is that quantization is relying on the NN module structure from Dynamo, so flattening it away is causing those tests to fail. Our current thinking is to keep the legacy behavior unless we see a backward() call, in which case we keep inlining, although it would be nice to go further and switch the behavior unless you explicitly opt into the legacy behavior… need to talk to the AO folks some more about this…
  • Michael Suo is being sent to llama land, wish him the best of luck :wink: . Horace has also been thinking about how to better grapple with large scale training. My personal take on the matter is that you probably want to do specific, purpose-built infrastructure at this scale, and then take the lessons to make generalized infrastructure adjustments.
  • AOTInductor lightweight wrapper from Python is collecting user requirements https://docs.google.com/document/d/1tP_7InSSKQ1zW1HDc2W1juJzIFhlmTPHqcTYLdgpWh0/edit although apparently Purpleberry is going to release something much simpler in the near future.
  • There’s a lot of interest in AOTAutograd recently, which might soon be an air traffic control problem for this part of the code. Brian still has a stack of DTensor fixes that will be landed soon. Todd Fiala is going to look into reducing fixed overheads from AOTAutograd prompted by this post from Jason AOTAutograd has high fixed overheads · Issue #122029 · pytorch/pytorch · GitHub and we also need to implement AOTAutograd level caching to get the next layer of caching on top of Inductor. There’s also interest in getting more hands on AOTAutograd.
  • I spent a chunk of the week working on the plan at Factor meta conversion into real tensor -> serializable metadata -> fake tensor; create fresh ShapeEnv when retracing · Issue #121085 · pytorch/pytorch · GitHub and while I completed the refactor, I am a lot more bearish on the idea of using this to solve problem (3). It's just a huge change to start having a separate ShapeEnv at each layer of the stack, and it doesn't even solve half of the reallocation problem (fake tensor repropagation in Inductor). I dusted off Preserve unbacked SymInt on SymNode by ezyang · Pull Request #120816 · pytorch/pytorch · GitHub to directly fix the thing that induced this; the "direct" fix seems to work OK now.
  • torchrec reported a major milestone in PT2 compatibility. Meta only: Redirecting... The way I'd characterize this work stream is that the "right" kinds of fixes are being landed, we are emerging into the AOTAutograd/Inductor neck of the enablement work, nothing is happening in JT/NJT land, and we are also getting to distributed land (e.g., Ivan ran into a problem with collectives that turned out to be wait being pointed to the wrong implementation). Yifu, when are we unifying the functional collectives?!
  • Notable new bugs:
  • Notable fixes:
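
To make the dtype-lying problem from the NF4 bullet concrete, here is a minimal made-up wrapper subclass in the FP8Tensor style: the packed payload is uint8, but the advertised dtype is the simulated one, which is exactly why Tensor.to(torch.bfloat16) sees nothing to do.

```python
import torch

class NF4Tensor(torch.Tensor):
    """Sketch only: a wrapper subclass that stores packed quantized data but
    advertises the dtype it is simulating (the FP8Tensor-style "lie")."""
    @staticmethod
    def __new__(cls, packed, simulated_dtype=torch.bfloat16):
        # The wrapper's metadata claims simulated_dtype, not uint8.
        return torch.Tensor._make_wrapper_subclass(
            cls, packed.shape, dtype=simulated_dtype, device=packed.device
        )

    def __init__(self, packed, simulated_dtype=torch.bfloat16):
        self.packed = packed  # the real (quantized) storage

t = NF4Tensor(torch.zeros(8, dtype=torch.uint8))
print(t.dtype)  # torch.bfloat16, so t.to(torch.bfloat16) thinks there is nothing to do
```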

State of PT2: Mar 29, 2024

I am on vacation next week, so you get this report early!


State of PT2: Apr 13, 2024

I’m feeling lazy so short report today.


State of PT2: Apr 20, 2024


State of PT2: Apr 28, 2024

  • We had another meeting with the OpenTelemetry crew to figure out what to do next. As written, the integration seems to be good for llama training (and we got some code pointers for some of the existing logging, see Meta only: https://docs.google.com/document/d/1qL7BwL5uK_AS8IqDTgULWiLwG2ssN8_RQsVndDMrqBY/edit ). It seems to be less good for more aggregate counters: James March related to us an abstraction called "wait counter" which doesn't exist in OpenTelemetry's API. Wait counters address the following use case: suppose you want a counter that ticks up while you are blocking on something. The naive way to implement this is to save the start and end time of the blocking operation, and bump the counter with the delta when the operation ends. But if the operation hangs, you will never find out about the time spent blocking, since the end event never occurred. The wait counter maintains a separate thread which continuously ticks up the counter while you are waiting, so this doesn't happen. (A rough sketch is at the end of this post.) James is concerned that we have a lot of pretty good abstractions in our internal stack already, and just doing OpenTelemetry may open the door to people just adding a lot of bad instrumentation. Something to keep an eye on.
  • FlexAttention by Driss and Horace is on main! FlexAttention lets you write custom attention functions with customized pieces like the score mod, without having to write Triton from scratch. PR: ScoreMod API by drisspg · Pull Request #121845 · pytorch/pytorch · GitHub
  • Natalia was telling me about an interesting problem where people want to customize the "gradient accumulation" operation that occurs when you implicitly reuse a variable multiple times. It's similar to how in a linear type system you would have to "dup" a value to use it multiple times, and then you would customize the backwards of the dup (a small sketch is at the end of this post).
  • Animesh was telling me about how we're making good progress on C++ guards: specifically, with C++ guards, we are beating the guard performance we have today, even when NN module guards are enabled (which we currently don't enable, because they are too slow). So this is paving the way for us enabling NN module guards by default. Animesh also pointed out to me that we don't need Jack Cao's inline-through-NN-modules logic, because we can just start inlining through NN modules, and so we only need to know how to inline through torch. Animesh was also telling me about how we're doing a bad job keeping track of internal bugs and issues; perhaps it is a good time to set up a task tracking system for the Q&A group.
  • I spent some time with Brian discussing exactly how we should be handling dynamic sizes and tensor subclasses, related to [test fix] try to re-use cached symnodes across dynamo and AOTAutograd by bdhirsh · Pull Request #120090 · pytorch/pytorch · GitHub . The general shape of the problem is that although all symbolic sizes are explicitly passed as inputs to the FX graph, when AOTAutograd creates a subclass on the fly on the inside of the traced function, it doesn’t necessarily reuse the SymNodes, and this results in “don’t have proxy for SymNode”. So the correct fix appears to be to ensure that when processing subclasses, we must also generate extra inputs as necessary in the input calling convention. Any sizes that are needed to construct the outputs should be int outputs of the compiled graph.
  • We hit another milestone this week: on Flavio's branch, we can successfully run inductor with forwards and backwards on Redirecting... This is thanks to fixing a number of Inductor codegen bugs, as well as some small improvements for backwards support specifically related to unbacked SymInts. I spent some time talking to Ivan Kobzarev about what's next. It looks like the low hanging fruit for symbolic shapes optimizations is all gone, so we are going to have to do some more in depth analysis. There are a few newer problems which Ivan will post to the WG. There's still some stuff to work out regarding streams, although the current plan is to use a simple pipeline to avoid needing to deal with this for now. There are some new paths that want to be enabled but are also suffering from GuardOnDataDependentSymNode problems; Add propagate_real_tensors mode for unbacked by ezyang · Pull Request #125115 · pytorch/pytorch · GitHub should help with this.
  • Alban related to me that MPS wants to support not actually doing anything to tensor data when you say to('mps'), since memory is unified, but this violates semantics, as to() is specified to return a fresh tensor. This is a good use for copy-on-write tensors, which Kurt has been continuously working on.
  • Notable new bugs:
  • Notable fixes:
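
A rough sketch (not the internal API) of the wait counter idea described above: a background thread keeps publishing the elapsed wait time while the blocking call is in flight, so a hang still shows up in the metrics.

```python
import threading
import time

class WaitCounter:
    """Sketch: publishes time-spent-waiting even if the awaited call never returns."""
    def __init__(self):
        self.waiting_seconds = 0.0
        self._started_at = None
        self._lock = threading.Lock()
        threading.Thread(target=self._publisher, daemon=True).start()

    def _publisher(self):
        while True:
            time.sleep(0.05)
            with self._lock:
                if self._started_at is not None:
                    # tick the counter up while we are still blocked
                    self.waiting_seconds = time.monotonic() - self._started_at

    def __enter__(self):
        with self._lock:
            self._started_at = time.monotonic()
        return self

    def __exit__(self, *exc):
        with self._lock:
            self.waiting_seconds = time.monotonic() - self._started_at
            self._started_at = None

# usage: even if the guarded block hangs forever, waiting_seconds keeps growing
wc = WaitCounter()
with wc:
    time.sleep(0.2)  # stand-in for a blocking collective
print(round(wc.waiting_seconds, 1))
```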
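
And a hedged sketch of the gradient accumulation customization Natalia described, phrased the linear-types way: make the reuse explicit as a dup whose backward decides how the two incoming gradients are combined (here, accumulating in float64 purely as an example).

```python
import torch

class Dup(torch.autograd.Function):
    """Sketch: an explicit 'dup' so the usually-implicit gradient accumulation
    for a reused tensor becomes a customizable backward."""
    @staticmethod
    def forward(ctx, x):
        return x.clone(), x.clone()

    @staticmethod
    def backward(ctx, g1, g2):
        # default behavior would be g1 + g2; customize the accumulation here
        return (g1.double() + g2.double()).to(g1.dtype)

x = torch.randn(3, requires_grad=True)
a, b = Dup.apply(x)
(2 * a + 3 * b).sum().backward()
print(x.grad)  # 2 + 3 = 5 in every position, via the customized accumulation
```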

State of PT2: May 4, 2024

  • ASPLOS was this week. According to those who went, the PT2 tutorial was well received. One of the spicier moments was when Jason Ansel had a call for improved benchmarking in academic research… and then the next talk had a bunch of questionable evaluations lol.
  • We had a minor breakthrough regarding all the flaky tests in PyTorch: actually, a lot of them are not actually flaky, they just depend on test ordering. You can use GitHub - asottile/detect-test-pollution: a tool to detect test pollution to create minimal repros when this happens! We've already produced minimal repros for two bugs this way.
  • The OpenTelemetry saga just doesn't want to die LOL. Rajesh Nishtala from AI Infra Training is looking around in our corner of the woods to see how he can help. James March has concluded that he'd much prefer it if we weren't using the OpenTelemetry APIs directly, but had some indirection in front of it which we could route to our internal infra directly. Chirag and I spent some time comparing the OpenTelemetry and fb303 APIs, and the big delta is that modern fb303 is very macro-based to avoid repeated string parsing (whereas OpenTelemetry is much more virtualized / willing to tolerate indirections for ergonomics). James has enlisted Victoria and Andrii to come up with an API; we're interested in knowing how important it is for the API to be macro-based. I plan on paying some more attention to this.
  • Some stuff I learned from 1:1s:
    • Horace feels he understands Spark now. The innovation of Spark over traditional Hadoop is that in Hadoop, you had to put everything back to disk after every step. Spark lets you keep it in memory, but now you have a fault tolerance problem, which you solve with provenance. Horace is thinking that ubiquitous provenance is a good primitive for systems to offer researchers, and then researchers can build their parallelization strategies on it. A big crux of the problem is finding a good split between systems and research, so that the systems team has a leverage point, while researchers can figure out what they need.
    • James's update on AOTAutograd caching, as told by Brian: the big finding is that it's not so easy to cache the FX graph produced by AOTAutograd, because it contains stuff in the meta dict that is not easily serializable. There's some work on fake tensor serialization that could help here (at minimum, you must have accurate alias information, as Inductor depends on it; also, symbolic shapes will be difficult), but my personal perspective is it's better to go straight to a consolidated cache for everything. We'll see what James ends up doing.
  • Unbacked SymInts!
    • What I’m most worried about this week is Ivan’s latest set of issues involving repeatedly re-viewing tensors between 1d and 2d, for non-variable batch codepath. They seem to be not so easy to resolve, in part because it’s hard to crisply define what exactly needs to work and what doesn’t. Symbolic shapes unable to reason: Ne(Mod(u0*u2 + u1*u2, u0 + u1), 0) · Issue #125307 · pytorch/pytorch · GitHub
    • Ivan’s other issue is a stack overflow that happens when you have too many chained operations on ints Redirecting... . The root cause of this is that we trace int operations into the FX graph lazily, but this means you can end up with an arbitrarily large chain of thunks that then blows your stack. The current idea for fixing this is to only do lazy tracing inside meta functions; for regular tracing, don’t do it. Ivan to see if this works.
    • Greg had a really good idea, so now I’m working on some data dependent shapes puzzlers. The idea is they’ll be designed to help you speedrun the last year of data dependent learnings we’ve done and quickly get up to speed. I’ve only written two so far though, busy busy.
    • Shaz Qadeer is exploring the data dependent shapes area. So expect some stuff from him soon!
    • Propagate real tensor is here Add propagate_real_tensors mode for unbacked by ezyang · Pull Request #125115 · pytorch/pytorch · GitHub
    • We discussed unbacked SymInts at PT2 export again, mostly rehashing some points about proper usage of propagate real tensor, and using locals of frames to give better information about what size variables mean. Avik had a really good suggestion that propagate real tensor can generate deferred runtime asserts, I should add this.
  • There was an interesting SEV that occurred when someone fixed a meta function, causing a large amount of graph breaks. The root cause was that the newly added meta function forced specialization, so the frame in question started getting recompiled and hit the cache limit, and this caused cascading problems for inner frames.
  • tlparse 0.3.17 released, the big change is color coded compile ids in the stack trie. Meta only: Redirecting...
  • It turns out we spend a lot of time generating stack trace strings in Meta prod for exceptions that don't end up showing those stack traces to users. Some patches to make this generation lazy landed. It was more involved than it seemed, because it's important to ensure that the lazy generation process is thread safe. (A small sketch of the idea is at the end of this post.)
  • Notable bugs:
  • Notable fixes:
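
A small sketch (not the actual patch) of the lazy stack trace idea: capture the raw frames cheaply up front, but only render the string if something actually asks for it, with a lock so two threads don't race on the one-time formatting.

```python
import threading
import traceback

class LazyStackTrace:
    """Sketch: defer the expensive string formatting until someone reads it."""
    def __init__(self):
        self._frames = traceback.extract_stack()  # capture now, format later
        self._formatted = None
        self._lock = threading.Lock()

    def __str__(self):
        if self._formatted is None:
            with self._lock:
                if self._formatted is None:  # double-checked so formatting happens once
                    self._formatted = "".join(traceback.format_list(self._frames))
        return self._formatted

trace = LazyStackTrace()   # cheap if the string is never needed
print(str(trace)[:60])     # formatting only happens here
```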

State of PT2: May 12, 2024

  • Data dependent shapes
  • Horace was recently thinking about whether or not we really need data dependent shapes. This was prompted by some discussions he’d been having with some researchers who were doing some MoE style stuff with DtoH syncs (aka some sort of data dependent routing)… and it was just too expensive and they were eventually going to end up just rewriting their model so that it didn’t need those syncs. We talked about this a few times over the week.
    • First, Horace was curious whether or not data dependent shapes was really needed in the recommender system use case. Here, the dynamism arises from sparsity in features: we are given some inputs and we need to partition them for model parallelism, but there may be imbalance in how many features we send to each partition. Horace was curious whether or not we knew some sort of maximum partition size, so we could simply pad out or employ some sort of device-side sparsity on the other end. Dennis said that yes, in principle, you could OOM due to imbalance, but it never actually happens in practice, and people don't want to pad: it uses more memory and makes things slower. The extra memory is not technically a problem, but people evaluate modeling changes against memory usage (even if there is headroom), so they will get annoyed at you if you unnecessarily increase the memory usage of your model. Additionally, truncation is a big no no, since that means numerical difference and that means you have to do full on evaluation (as opposed to numerically identical optimizations). And yes, the D2H syncs aka comms are the main issue, but this is sort of understood and minimized to a single all2all communicating the splits. And we know the comms are very exposed, and a lot of the optimization around sparse arch is making sure you aren't screwed over by comms.
    • So then we talked a bit about block sparsity, which is the most recent tool Horace has been looking to use in situations where you wanted data dependent symbolic shapes. Horace brought up two: block sparse attention, and block sparsity to implement MoE a la MegaBlocks (https://arxiv.org/pdf/2211.15841). There is also some data-dependent modeling going on here too: you can implement MoE by performing a boolean mask to extract out the tokens that should go to a particular expert, but this incurs a DtoH sync. Instead, you can avoid syncs altogether by permuting the input tokens, grouping the tokens that should go to various experts together (see the sketch at the end of this post). MegaBlocks shows you can do this without dropping or padding, simply by designing a block sparse matrix multiply kernel that can take a permutation with uneven assignments (padded up to some reasonable multiple of tiling) and maintain occupancy. Now, there are no data dependent shapes at all: simply an irregularly sparse data structure and some custom kernels. So what does this mean for nonzero/boolean masking? My intuition here is that folks are always welcome to write their models in a careful block sparse way, but for the rest of us, we are hoping for some sort of higher level UX which can compile down to this. What does this UX look like? At least for the MoE case, it seems like some sort of vmap on a jagged data structure (where the jaggedness represents imbalance) could be one possible way to represent this.
  • Did you know that non_blocking device transfer doesn’t work in Inductor? We just… don’t do anything about it. It should likely be handled similarly to communication ops.
  • Jason Ansel brought up a spicy topic: is functionalization more trouble than it's worth? This was precipitated by set_ functionalization taking a long time. In particular, Jason pointed out that Inductor largely does understand how to deal with mutation. We discussed this a bit: it's probably still worth functionalizing most things by default, but for a certain limited subset of things, we may want to consider permitting mutating operations to propagate past AOTAutograd, and fix up all our FX passes to treat them as strict barriers.
  • AI Infra is working on a higher level observability library (which could interface on top of OpenTelemetry), that will have some higher level abstractions that we’ve found quite useful in our data preprocessing stack! More Meta only context: Redirecting...
  • H2 planning is on the horizon. Gregory is interested in working out how we could avoid putting too many commitments on our roadmap.
  • Notable bugs:
  • Notable fixes:
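
As promised in the MegaBlocks bullet, a tiny sketch of why the permutation trick avoids the DtoH sync that boolean masking forces: sorting by expert id keeps every intermediate at a static size, and the per-expert counts stay on device for a grouped/block-sparse kernel to consume. All names and shapes here are made up.

```python
import torch

tokens = torch.randn(8, 16)               # (num_tokens, hidden)
assignment = torch.randint(0, 4, (8,))    # expert id per token

# Boolean masking (tokens[assignment == e]) has a data-dependent output size,
# which forces a device-to-host sync just to allocate the result.

# The permutation formulation keeps everything statically shaped:
order = torch.argsort(assignment)         # permutation grouping tokens by expert
grouped = tokens[order]                   # expert 0's tokens, then expert 1's, ...
counts = torch.bincount(assignment, minlength=4)  # per-expert sizes, stays on device

# A grouped / block-sparse matmul kernel would consume (grouped, counts) directly;
# the inverse permutation scatters results back to the original token order.
restored = torch.empty_like(grouped)
restored[order] = grouped
```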