The "Ideal" PyTorch FLOP Counter (with __torch_dispatch__)

TL;DR: I wrote a flop counter in 130 lines of Python that 1. counts FLOPS at an operator level, 2. (optionally) aggregates them in a module hierarchy, 3. captures backwards FLOPS, and 4. works in eager mode. Oh, and you can use it under arbitrary transformations (such as vmap) to compute FLOPS for, say, jacobians or hessians too!

For the impatient, here it is (note that you need PyTorch nightly to run it).

So why “Another PyTorch Flop Counter”?

There have been many flop counters built in PyTorch over the years (see flops-counter.pytorch, pytorch-OpCounter, the DeepSpeed FLOPs profiler, fvcore’s flop counter, or this PyTorch issue with 56 thumbs up). Yet… none of these allow me to answer a somewhat reasonable question:

How many flops do I need in my backwards pass?

The common rule of thumb here is 2x the forward pass, but several things affect this: 1. your first layer doesn’t need to propagate gradients back to the input, and 2. activation checkpointing can mess things up as well (particularly more … advanced activation checkpointing schemes like the one in AOTAutograd).

So… in addition to that, what other requirements would be on my wishlist for a “perfect FLOP counter”? I’ll mark the previous approaches that satisfy these requirements.

  • Captures all operators, and not just modules (fvcore)
    • It would be nice if we didn’t just … miss every time a user called torch.mm instead of using nn.Linear
  • Captures module hierarchies (fvcore)
    • As FVCore motivates “nn.Module is the level of abstraction where users design models. To help design efficient models, providing flops per nn.Module in a recursive hierarchy is needed.”
  • (To repeat myself) Works for the backwards pass
    • Why is this soooo hard?
  • (Ideally) Works for other fancier gradient computations too, like jacobians or hessian vector products. Also, works with activation checkpointing.
  • (Ideally) Works in eager mode, and not just models that can be JIT/FX/AOTAutograd traced.
    • Users may have models that are … weirdly structured, or maybe they don’t have models at all. Or maybe they want to compute average FLOPS across a model with control flow. Either way, my “perfect” FLOP counter shouldn’t impose a tracing restriction.
  • Is super simple and hackable.
    • If I couldn’t implement a “perfect” FLOP counter in a couple hours, that seems like too much work. Similarly, if I needed to touch C++, that also seems like too much work.

I’ll note that there’s one requirement here that we don’t fit - computing FLOPS with almost no overhead. If we wanted to capture FLOPS across all of our models running in production, or wanted to compute FLOPS at the same time that we’re computing other low-level quantities, we’d need to compute FLOPS with very minimal overhead (which this doesn’t fit). There, we’d look for an approach along the lines of the PyTorch Profiler.

How would we implement such a FLOP counter?

So, historically, why have existing implementations not been able to satisfy these requirements? Well, FLOP counting is basically the act of tracking what PyTorch has done, and fundamentally, they’ve all just been working with “lossy” representations of “what PyTorch has done”. These approaches tried to approximate “what PyTorch has done” with “what modules PyTorch has run”, or “a Torchscript or FX trace of the module”.

Well, with the magic of __torch_dispatch__, no longer! Now, simply … count what FLOPS happen to your tensor. And suddenly, nearly all of the above desiderata (capture all operators, and works for backwards/jacobian/checkpointing) are instantly satisfied. The core of the code looks like this:

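@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    kwargs = kwargs or {}  # kwargs may be None, which `**` cannot unpack

    def unwrap(e):
        return e.elem if isinstance(e, FlopTensor) else e

    # Run the real operator on the underlying (unwrapped) tensors.
    rs = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
    outs = normalize_tuple(rs)

    if func in flop_mapping:
        # Charge this operator's FLOPS to every module currently on the stack.
        flop_count = flop_mapping[func](args, outs)
        for par in parents:
            flop_counts[par][func.__name__] += flop_count

    def wrap(e):
        return FlopTensor(e) if isinstance(e, torch.Tensor) else e

    # Re-wrap the outputs so that later operators are counted too.
    return tree_map(wrap, rs)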

This alone is already enough for it to work for forward/backwards/jacobians.
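To make the idea concrete, here is a minimal, self-contained sketch of the same counting trick. Note the assumptions: it uses `TorchDispatchMode` from `torch.utils._python_dispatch` (available in recent PyTorch releases) rather than the `FlopTensor` subclass above, it only knows about `mm` and `addmm`, and it counts one multiply-add as one FLOP. The names `TinyFlopCounter`, `mm_flop`, and `addmm_flop` are made up for illustration.

```python
from collections import defaultdict

import torch
from torch.utils._python_dispatch import TorchDispatchMode

aten = torch.ops.aten


def mm_flop(args, outs):
    # (m, k) @ (k, n) needs m*n*k multiply-adds; each is counted as one FLOP here.
    a, b = args[0], args[1]
    m, k = a.shape
    n = b.shape[1]
    return m * n * k


def addmm_flop(args, outs):
    # addmm(bias, a, b): only the matmul part is counted in this sketch.
    return mm_flop(args[1:3], outs)


# Illustrative table; a real counter would cover many more operators.
flop_mapping = {aten.mm.default: mm_flop, aten.addmm.default: addmm_flop}


class TinyFlopCounter(TorchDispatchMode):
    def __init__(self):
        super().__init__()
        self.flop_counts = defaultdict(int)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        out = func(*args, **(kwargs or {}))
        if func in flop_mapping:
            self.flop_counts[str(func)] += flop_mapping[func](args, (out,))
        return out


layer = torch.nn.Linear(8, 4)
x = torch.randn(2, 8)  # the input does not require grad

with TinyFlopCounter() as counter:
    layer(x).sum().backward()  # forward and backward are both dispatched here

print(dict(counter.flop_counts))
```

Because the backward pass is just more operators hitting the dispatcher, nothing extra is needed to count it.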

But what if we wanted a module hierarchy, like FVCore has? Well, in the forwards pass, it’s easy enough - we just add module pre and post hooks that push/pop modules onto the global context. But… how do we extend the module hierarchy to the backwards FLOPS? There’s no module_backward_pre_hook…

Here, the composability of __torch_dispatch__ with the rest of PyTorch once again saves the day. When we add the module pre and post hooks, we also run the module’s inputs and outputs through a custom autograd.Function that pushes/pops the module onto the global context in the backwards pass. Like so:
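A rough sketch of that trick, under some simplifying assumptions: every tracked module takes tensor positional inputs and returns a single tensor, and the helper names (`backward_marker`, `track`, `events`) are made up here for illustration. The `events` log exists only to show the order in which things fire during backward.

```python
import torch

parents = ["Global"]  # stack of modules we are currently "inside"
events = []           # backward-time log, only for demonstration


def backward_marker(name, action):
    """Identity in forward; calls `action(name)` when gradients flow through."""

    class Marker(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x.clone()

        @staticmethod
        def backward(ctx, grad):
            action(name)
            return grad

    return Marker.apply


def push(name):
    parents.append(name)
    events.append(("push", name))


def pop(name):
    assert parents[-1] == name
    parents.pop()
    events.append(("pop", name))


def track(module, name):
    def pre_hook(mod, inputs):
        parents.append(name)
        # Backward reaches the module's inputs last, so they carry the "pop".
        return tuple(backward_marker(name, pop)(t) for t in inputs)

    def post_hook(mod, inputs, output):
        parents.pop()
        # Backward reaches the module's output first, so it carries the "push".
        return backward_marker(name, push)(output)

    module.register_forward_pre_hook(pre_hook)
    module.register_forward_hook(post_hook)


model = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.Linear(3, 2))
track(model[0], "fc1")
track(model[1], "fc2")

# requires_grad=True so that fc1's "pop" marker actually runs in backward.
x = torch.randn(5, 4, requires_grad=True)
model(x).sum().backward()
print(events)
```

Any operator dispatched between a module's push and pop in the backward pass can then be attributed to that module, exactly as in the forward pass.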

Now, we’re ready to test our flop counter on say, resnet18.

Notice how in the conv1 layer, the convolution_backward FLOPS are equal to the forwards FLOPS (instead of double). So, if you’d simply doubled your forwards FLOPS to estimate your backwards, you would have been off by .1 gigaflops :^)
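A back-of-the-envelope check of that number, as a sketch. The assumptions here are not spelled out above: a single 224x224 image, torchvision's resnet18 `conv1` shape (3 to 64 channels, 7x7 kernel, stride 2, padding 3, so a 112x112 output), and one multiply-add counted as one FLOP. `conv2d_macs` is a hypothetical helper.

```python
def conv2d_macs(batch, c_in, c_out, kernel_hw, out_hw):
    """Multiply-adds for a dense 2D convolution (no groups, bias ignored)."""
    kh, kw = kernel_hw
    oh, ow = out_hw
    return batch * c_out * oh * ow * c_in * kh * kw


# Assumed conv1 shapes: 3 -> 64 channels, 7x7 kernel, 112x112 output, batch of 1.
forward = conv2d_macs(1, 3, 64, (7, 7), (112, 112))

weight_grad = forward  # the gradient w.r.t. the weights costs about one forward
input_grad = 0         # skipped: the input image does not require grad

backward = weight_grad + input_grad
rule_of_thumb = 2 * forward

print(forward)                   # 118013952, about 0.118 GFLOPS
print(rule_of_thumb - backward)  # 118013952, the overestimate from doubling
```

Under those assumptions the overestimate from blindly doubling is roughly 0.1 GFLOPS, which lines up with the figure quoted above.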

Since our FlopTensor is just a __torch_dispatch__ tensor subclass, we can run our module however we like.

For example, you could randomly freeze half of your weights:
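One way to do that, as a sketch: `freeze_random_half` is a hypothetical helper, and it freezes half of the parameter tensors rather than half of the individual scalar weights.

```python
import random

import torch


def freeze_random_half(model, seed=0):
    """Set requires_grad=False on a random half of the model's parameter tensors."""
    params = list(model.parameters())
    rng = random.Random(seed)
    frozen_idx = set(rng.sample(range(len(params)), k=len(params) // 2))
    for i, p in enumerate(params):
        p.requires_grad_(i not in frozen_idx)
    return len(frozen_idx)


model = torch.nn.Sequential(
    torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2)
)
num_frozen = freeze_random_half(model)

# Backward still runs; autograd simply skips gradients for the frozen tensors.
model(torch.randn(4, 8)).sum().backward()
```

Running the same forward/backward under the flop counter should then show how many backwards FLOPS the frozen tensors save.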

Conclusions

  1. __torch_dispatch__ is awesome.
  2. __torch_dispatch__ is awesome.
  3. __torch_dispatch__ is awesome.
  4. Working with the dispatcher (instead of around it) often makes your life sooo much easier, and makes it much easier to compose with other parts of PyTorch. In other words, __torch_dispatch__ is great.
  5. Specifically, all of these things we often think of as “special cases” just get instantly resolved by using the right abstraction level. I didn’t need to write any special code to make FlopTensor work with backwards/autograd.Function/jacobians/checkpointing - they just worked automatically once I decided to use __torch_dispatch__. In particular, composing with autograd.Function is what made the module hierarchy in backwards possible - I’m not sure how else I could have done it.

Note: This is not a replacement for the PyTorch profiler, which is able to capture these values with sub-microsecond overhead. In comparison, this currently adds about 30-40us per operator. There’s actually an upcoming PyTorch profiler feature that allows you to do this and other cool stuff around profiling performance, so this is primarily useful as a showcase of __torch_dispatch__ and an easy/hackable FLOPS profiler.

We’re currently writing some more resources on the long-term plan for __torch_dispatch__, but if you have something you think might be a good fit for __torch_dispatch__, come talk to us!
