The "Ideal" PyTorch FLOP Counter (with __torch_dispatch__)

TL;DR: I wrote a flop counter in 130 lines of Python that 1. counts FLOPS at an operator level, 2. (optionally) aggregates them in a module hierarchy, 3. captures backwards FLOPS, and 4. works in eager mode. Oh, and you can use it under arbitrary transformations (such as vmap) to compute FLOPS for, say, jacobians or hessians too!

For the impatient, here it is (note that you need PyTorch nightly to run it).

So why “Another PyTorch Flop Counter”?

There have been many flop counters built in PyTorch over the years (see flops-counter.pytorch, pytorch-OpCounter, the DeepSpeed FLOPs profiler, fvcore’s flop counter, or this PyTorch issue with 56 thumbs up). Yet… none of these allow me to answer a somewhat reasonable question:

How many flops do I need in my backwards pass?

The common rule of thumb here is 2x, but there are several things that affect this: 1. your first layer doesn’t require gradients to be propagated back to the input, and 2. activation checkpointing can change things as well (particularly more … advanced activation checkpointing schemes like the one in AOTAutograd).

So… in addition to that, what other requirements would be on my wishlist for a “perfect” FLOP counter? I’ll mark the previous approaches that satisfy each requirement.

  • Captures all operators, and not just modules (fvcore)
    • It would be nice if we didn’t just … miss it every time a user called torch.mm instead of using nn.Linear
  • Captures module hierarchies (fvcore)
    • As FVCore motivates “nn.Module is the level of abstraction where users design models. To help design efficient models, providing flops per nn.Module in a recursive hierarchy is needed.”
  • (To repeat myself) Works for the backwards pass
    • Why is this soooo hard?
  • (Ideally) Works for other fancier gradient computations too, like jacobians or hessian vector products. Also, works with activation checkpointing.
  • (Ideally) Works in eager mode, and not just for models that can be JIT/FX/AOTAutograd traced.
    • Users may have models that are … weirdly structured, or maybe they don’t have models at all. Or maybe they want to compute average FLOPS across a model with control flow. Regardless, I’d prefer not to impose a tracing restriction in my “perfect” FLOP counter.
  • Is super simple and hackable.
    • If I couldn’t implement a “perfect” FLOP counter in a couple hours, that seems like too much work. Similarly, if I needed to touch C++, that also seems like too much work.

I’ll note that there’s one requirement this approach doesn’t satisfy - computing FLOPS with almost no overhead. If we wanted to capture FLOPS across all of our models running in production, or to compute FLOPS at the same time as other low-level quantities, we’d need very minimal overhead. For that, we’d look for an approach along the lines of the PyTorch Profiler.

How would we implement such a FLOP counter?

So, historically, why have existing implementations not been able to satisfy these requirements? Well, FLOP counting is basically the act of tracking what PyTorch has done, and fundamentally, they’ve all been working with “lossy” representations of “what PyTorch has done” - approximating it with “what modules PyTorch has run”, or with “a TorchScript or FX trace of the module”.

Well, with the magic of __torch_dispatch__, no longer! Now, simply … count what FLOPS happen to your tensor. And suddenly, nearly all of the above desiderata (capture all operators, and works for backwards/jacobian/checkpointing) are instantly satisfied. The core of the code looks like this:

def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    kwargs = kwargs if kwargs else {}

    # Unwrap FlopTensors so the real operator runs on the underlying tensors.
    def unwrap(e):
        return e.elem if isinstance(e, FlopTensor) else e

    rs = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
    outs = normalize_tuple(rs)

    # If we know a FLOP formula for this operator, attribute its count to
    # every module currently on the global `parents` stack.
    if func in flop_mapping:
        global flop_counts
        flop_count = flop_mapping[func](args, outs)
        for par in parents:
            flop_counts[par][func.__name__] += flop_count

    # Re-wrap the outputs so that subsequent operators (including those in
    # the backward pass) get counted too.
    def wrap(e):
        return FlopTensor(e) if isinstance(e, torch.Tensor) else e

    rs = tree_map(wrap, rs)
    return rs
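For reference, here is a minimal sketch of what one entry in that flop_mapping might look like (the real mapping in the gist covers many more operators, and its exact signatures may differ); following the convention discussed further down the thread, it counts multiply-accumulates:

import torch

aten = torch.ops.aten

def mm_flop(args, outs):
    # aten.mm: (m, k) @ (k, n) costs m * k * n multiply-accumulates.
    m, k = args[0].shape
    k2, n = args[1].shape
    return m * k * n

flop_mapping = {
    aten.mm.default: mm_flop,
}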

This alone is already enough for it to work for forward/backwards/jacobians.

But what if we wanted a module hierarchy, like FVCore has? Well, in the forwards pass it’s easy enough - we just add module pre- and post-hooks that push/pop the module name onto a global context. But… how do we extend the module hierarchy to the backwards FLOPS? There’s no module_backward_pre_hook…

Here, the composability of __torch_dispatch__ with the rest of PyTorch once again saves the day. When we add the module pre- and post-hooks, we also run the module’s inputs and outputs through custom autograd.Functions that push/pop the module onto the global context in the backwards pass. Like so:
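The actual implementation lives in the gist; a hypothetical reconstruction of the idea behind its create_backward_push / create_backward_pop helpers (details may differ) looks roughly like this. They are identity functions in the forward pass whose backward() edits the same global parents stack: a Function applied to the module’s outputs has its backward run when gradients first flow into the module’s backward region (push), while one applied to the module’s inputs has its backward run when gradients leave that region (pop).

import torch

def create_backward_push(name):
    class PushState(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x.clone()  # identity in the forward pass

        @staticmethod
        def backward(ctx, grad_out):
            # Gradients are entering this module's backward region.
            parents.append(name)
            return grad_out
    return PushState.apply

def create_backward_pop(name):
    class PopState(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x.clone()  # identity in the forward pass

        @staticmethod
        def backward(ctx, grad_out):
            # Gradients are leaving this module's backward region.
            assert parents[-1] == name
            parents.pop()
            return grad_out
    return PopState.apply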

Now, we’re ready to test our flop counter on, say, resnet18.
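(The original post includes the printed per-module FLOP table here.) Roughly, driving the counter looks something like the sketch below; instrument_module is a hypothetical placeholder for whatever registers the hooks in the gist, so check the gist for the real entry points.

import torch
import torchvision.models as models

mod = models.resnet18()
inp = torch.randn(1, 3, 224, 224)

flop_counts.clear()        # the global {module name: Counter} mapping
instrument_module(mod)     # hypothetical helper: registers the pre/post hooks
out = mod(FlopTensor(inp))
out.sum().backward()

for name, counts in flop_counts.items():
    print(name, dict(counts))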

Notice how in the conv1 layer, the convolution_backward FLOPS are equal to the forwards FLOPS (instead of double). So, if you’d simply doubled your FLOPS for your backwards, you would have been off by 0.1 gigaflops :^)

Since our FlopTensor is just a __torch_dispatch__ tensor subclass, we can run our module however we like.

For example, you could randomly freeze half of your weights:
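A sketch of that, continuing from the snippet above (since frozen weights don’t need weight gradients, the backward FLOP count drops accordingly):

import random

# Randomly freeze roughly half of the parameters, then count again.
for p in mod.parameters():
    if random.random() < 0.5:
        p.requires_grad_(False)

flop_counts.clear()
out = mod(FlopTensor(inp))
out.sum().backward()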

Conclusions

  1. __torch_dispatch__ is awesome
  2. __torch_dispatch__ is awesome.
  3. __torch_dispatch__ is awesome.
  4. Working with the dispatcher (instead of around it) often makes your life sooo much easier, and makes it much easier to compose with other parts of PyTorch. In other words, torch_dispatch is great.
  5. Specifically, all of these things we often think of as “special cases” just get instantly resolved by using the right abstraction level. I didn’t need to write code specifically to make FlopTensor work with backwards/autograd.Function/jacobians/checkpointing - they just worked automatically once I decided to use torch_dispatch. In particular, composing with autograd.Function is what made the module hierarchy in backwards possible - I’m not sure how else I could have done it.

Note: This is not a replacement for the PyTorch profiler, which is able to capture these values with sub-microsecond overhead. In comparison, this currently adds about 30-40us per operator. There’s actually an upcoming PyTorch profiler feature that allows you to do this and other cool stuff around profiling performance, so this is primarily useful as a showcase of __torch_dispatch__ and an easy/hackable FLOPS profiler.

We’re currently writing some more resources on the long-term plan for torch_dispatch, but if you have something you think might be a good fit for __torch_dispatch__, come talk to us!


This is an updated version of the FLOP counter that uses TorchDispatchMode instead


I’m getting this error:

File ~/.../lib/python3.9/site-packages/torch/utils/_mode_utils.py:28, in _wrap_init.<locals>.wrapped(self, inner, *args, **kwargs)
     25 @functools.wraps(f)
     26 def wrapped(self, *args, inner=undef, **kwargs):
     27     if inner is undef:
---> 28         raise TypeError(
     29             f"missing inner keyword argument; instead of constructing a {meta_init_error_info.mode_class_name} "
     30             f"directly, pass the constructor to push_{meta_init_error_info.mode_name}_mode"
     31         )
     32     self.inner = inner
     33     return f(self, *args, **kwargs)

TypeError: missing inner keyword argument; instead of constructing a TorchDispatchMode directly, pass the constructor to push_torch_dispatch_mode

This is a bit too cryptic for me - what am I missing?

Interesting. Do you measure memory bandwidth as well?

Some operations, like convolution or GEMM, are mostly FLOP-bound; however, many operations are actually memory-bandwidth-bound, for example batch normalization, activations, etc. I’ve noticed that sometimes a GPU like the 2060S, which has fewer FLOPS than a 1080, can run a ResNet faster due to the large difference in memory speed.

Regarding convolution, do you take into account that running, for example, a Winograd or FFT convolution can actually be done in fewer FLOPs than a “direct”/GEMM-based convolution?

Hi! Please try updating to the latest version of PyTorch (1.13 is due out very soon if you’re able to hold on, but I think this should also be fixed in 1.12.1) and let me know if you’re still seeing this error.

Do you measure memory bandwidth as well?

No. Certainly, measuring memory bandwidth would be a great thing to add. In principle, I think it would be fairly easy to extend the FlopTensor above to also track memory bandwidth, although it may be more labor-intensive due to the wider variety of operators for which memory bandwidth is relevant.
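For example, a rough sketch of that extension (assuming a byte_counts structure analogous to flop_counts): just sum the sizes of all tensor inputs and outputs per operator. Note this is only a crude upper bound on traffic, since it ignores cache reuse and kernel fusion.

import torch
from torch.utils._pytree import tree_flatten

def bytes_touched(args, outs):
    total = 0
    for t in tree_flatten((args, outs))[0]:
        if isinstance(t, torch.Tensor):
            total += t.numel() * t.element_size()
    return total

# ...and inside __torch_dispatch__, next to the flop_mapping lookup:
# for par in parents:
#     byte_counts[par][func.__name__] += bytes_touched(args, outs)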

Regarding convolution, do you take into account that running, for example, a Winograd or FFT convolution can actually be done in fewer FLOPs than a “direct”/GEMM-based convolution?

No, and we also don’t track Strassen or anything like that. In my view, the main point of counting FLOPs is to get a quick and dirty roofline estimate of “percentage of FLOPs achieved”. If you’re managing to use Strassen or FFT successfully, then more power to you :)


Hi! I’m seeing this error on 1.12.1.

Just checked, and Horace’s code works with 1.13.0, which was released very recently.

If you are emotionally or otherwise tied to 1.12.1, the following works (but you’ll see that the code looks worse; it should actually show why we changed the API):

import torch
import torchvision.models as models
from torch.utils._python_dispatch import push_torch_dispatch_mode
from functools import partial

inp = torch.randn(8, 3, 224, 224, device='cuda')
mod = models.resnet18().cuda()
with FlopCounterMode.push(mod) as flop_counter1:
    mod(inp).sum().backward()
 
with FlopCounterMode.push(mod) as flop_counter2:
    mod(inp).sum().backward()
exit(0)
 
# the following won't work since symbolic shapes isn't in <= 1.12.1
# from torch.fx.experimental.symbolic_shapes import ShapeEnv
# from torch._subclasses import FakeTensorMode
# shape_env = ShapeEnv()
 
# with FakeTensorMode.push(shape_env):
#     inp = fake_mode.from_tensor(inp)
#     assert inp.shape[0] == 1
#     mod = models.resnet18()
#     with FlopCounterMode.push(mod):
#         with torch.no_grad():
#             mod(inp)

I think this counter estimates MACs, i.e. multiply-and-accumulate.
To compute FLOPs, you have to 2x this.
Making this fix also correctly lines up the numbers with XLA’s flop counter.

Here’s an updated gist with that fix: Horace's flop counter, but with flops metric fixed correctly · GitHub
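As a concrete illustration of the difference (my own example, not from either gist), for a plain (m, k) @ (k, n) matmul:

m, k, n = 128, 256, 512
macs = m * k * n       # what the per-operator formulas above count
flops = 2 * macs       # counting the multiply and the add separately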


@Chillee / @smth / @albanD can anyone please explain how exactly the module hierarchy is captured in the context of the backward pass? Specifically, I’m unable to understand how the two functions create_backward_push and create_backward_pop work.

I’m clear on capturing the module context in the forward pass using the two hooks, but I can’t see why create_backward_push and create_backward_pop are implemented the way they are to get the FLOPs for the backward pass.

PS: I’m aware of how autograd works, how backward graphs are created, how to use custom autograd functions, and how to manipulate inputs/outputs in the forward/backward static methods of a custom autograd function. I’m also able to follow the complete flow of the flops counter and why torch dispatch is used.

Thanks a lot!

They register a custom Function in the autograd graph at the boundaries of the Module. This way, during the backward() calls of these functions, we know when we enter and exit the part of the backward pass corresponding to this Module.

Note that you can achieve the same result using Module backward hooks and backward pre-hooks today and have a system that works exactly the same way as the forward one.
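For reference, a sketch of that hook-based variant (assuming a PyTorch version that has register_full_backward_pre_hook, and reusing the same global parents stack as before):

def instrument_backward(module, name):
    def push(mod, grad_output):
        parents.append(name)       # entering this module's backward

    def pop(mod, grad_input, grad_output):
        assert parents[-1] == name
        parents.pop()              # leaving this module's backward

    module.register_full_backward_pre_hook(push)
    module.register_full_backward_hook(pop)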


Hi @albanD, thanks a lot for the reply. That makes sense!
I had a few more doubts; if you could please help me with them, that’d be great!

  1. In the conv_backward_flop function, I see that we’re calling the conv_flop_count function based on output_mask. Can you please explain what the output_mask variable does here?

  2. In the conv_flop_count function, FLOPs for transposed convolutions are calculated in a different way. Can you please explain the difference between a transposed convolution and a normal convolution?

  3. In the conv_backward_flop function, I see that for the first case (output_mask[0]), conv_flop_count is invoked with (grad_output_shape, w_shape, grad_input_shape), but in the second case it’s invoked differently. I couldn’t follow why the parameters are passed in that order.

  4. In the second case of the conv_backward_flop function, I couldn’t follow why we are doing transpose_shape(x_shape) while passing it to the conv_flop_count function.

  5. When I traced the code for a LeNet-5 model, I saw that aten.mm is invoked twice per fully connected layer during the backward pass. I couldn’t fathom why it’s called twice for every fully connected layer in the backward pass.

Thanks a lot!

cc: @Chillee | @smth

Hi,

When you do out = mm(x, y) in the forward pass, the backward pass is given gOut and needs to compute gX and gY.
The formulas for this are gX = mm(gOut, y^T) and gY = mm(x^T, gOut), which should explain why you have 2 mms in the backward pass.
Also, since convolutions are just a special kind of mm, you get very similar formulas, and that’s where the transpose comes from.
I would recommend you check online for the difference between transposed and regular convolutions. There are blog posts with visualizations that will be much better than anything I could write!

Cheers,
Alban
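A quick numerical check of those two formulas (an illustration added here, not part of the reply above):

import torch

x = torch.randn(4, 5, requires_grad=True)
y = torch.randn(5, 3, requires_grad=True)
out = x.mm(y)
g_out = torch.randn_like(out)
out.backward(g_out)

assert torch.allclose(x.grad, g_out.mm(y.t()))  # gX = mm(gOut, y^T)
assert torch.allclose(y.grad, x.t().mm(g_out))  # gY = mm(x^T, gOut)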


Hi, thanks a lot for replying. I did some digging and was able to derive the equations for the backward pass for a fully connected layer (FCL) and Conv2d. Dropping them here in case anyone is interested!

FCL backward pass

  1. Matrix-matrix multiply: dX = dY * W.T, FLOPs = W.nrows * W.ncols
  2. Matrix-matrix multiply: dW = X.T * dY, FLOPs = W.nrows * W.ncols

Conv2d backward pass

  1. Matrix-matrix multiply: dX = F.T[CRS, K] * dY[K, NPQ], FLOPs = CRSKNPQ
  2. Matrix-matrix multiply: dW = dY[K, NPQ] * D.T[NPQ, CRS], FLOPs = KNPQCRS = CRSKNPQ

Conventions:

  • K is the number of output feature maps (number of filters).
  • N is the batch size. It’s assumed that total FLOPs are proportional to batch size, so we can safely assume it equals 1.
  • P is the height of an output feature map.
  • Q is the width of an output feature map.
  • C is the number of input feature maps (depth of input rank 3 tensor, same as Filter.Depth).
  • R is the filter height (Filter.Height).
  • S is the filter width (Filter.Width).
  • Y is the output feature map (Output).
  • F is the filter tensor (Filter).
  • D is the input data (X)
  • T means transpose of matrix
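Putting those two GEMMs into code (my own sketch, counting multiply-accumulates and keeping N explicit rather than assuming N = 1):

def conv2d_backward_macs(N, C, K, R, S, P, Q):
    d_input = C * R * S * K * N * P * Q    # dX = F.T[CRS, K] * dY[K, NPQ]
    d_weight = K * N * P * Q * C * R * S   # dW = dY[K, NPQ] * D.T[NPQ, CRS]
    return d_input + d_weight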

@Chillee @ezyang

Is it possible to run the FlopCounter with FakeTensors on a model initialized on a “meta” device? The use case is to get a sense of the flop count distribution without having to instantiate the full model on device.

Yes, this should work! It can actually also work with symbolic shapes - @eellison had a pretty cool demo of it.
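Roughly, something along these lines, using the TorchDispatchMode-based counter from earlier in the thread. Note this sketch constructs the model directly under FakeTensorMode rather than on the meta device, which achieves the same goal of never allocating real weights; constructor details may need adjusting for your PyTorch version.

import torch
import torchvision.models as models
from torch._subclasses import FakeTensorMode

fake_mode = FakeTensorMode()
with fake_mode:
    mod = models.resnet18()             # parameters are fake: no real storage
    inp = torch.randn(8, 3, 224, 224)   # fake input
    with FlopCounterMode(mod) as flop_counter:
        with torch.no_grad():
            mod(inp)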

@Chillee – thanks!

I have a large model, and it would be helpful to get a quick estimate of the FLOPs distribution with minimal setup.

Do you have a link to the demo you could share @eellison?

Something like this might help: torch_flops/torch_flops/flops_engine.py at dc72fb62934e107987fc3e9cb59d74d32b3910ef · zugexiaodui/torch_flops · GitHub

@Chillee Related question: is there a way for me to extract the post-fusion FX graph, or better, the Inductor-scheduled graph? (I know there’s TORCH_COMPILE_DEBUG=1, but I was looking for a programmatic way.)

I’m interested in extending this to byte counting while factoring in as much of the fusion stuff as I can.

@jeromeku here it is: Google Colab