Does dynamo trigger real kernel execution?

I think the answer is NO, according to the Dynamo Overview page in the PyTorch 2.7 documentation:

Dynamo hooks into the frame evaluation API in CPython (PEP 523) to dynamically modify Python bytecode right before it is executed. It rewrites Python bytecode to extract sequences of PyTorch operations into an FX Graph which is then compiled with a customizable backend.

The above statement implies two steps:

  1. Dynamo converts the Python bytecode to an FX graph without real kernel execution; TORCH_LOGS="+dynamo" shows what the FX graph looks like.
  2. The FX graph is then compiled, and only then is real kernel execution triggered (see the backend sketch right after this list).
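
To make the two steps concrete, here is a minimal sketch of a custom backend (the backend function, its name, and the prints are mine for illustration; only the backend signature of (gm, example_inputs) comes from the docs). Dynamo hands the captured FX graph to the backend at the end of step 1, and real kernels only run when the returned callable is invoked in step 2:

import torch

def my_backend(gm: torch.fx.GraphModule, example_inputs):
    # End of step 1: the FX graph exists, but no real kernels have run yet.
    print("graph captured:", len(list(gm.graph.nodes)), "nodes")
    def run(*args):
        # Step 2: invoking the compiled artifact triggers real kernel execution.
        return gm.forward(*args)
    return run

f = torch.compile(lambda x: torch.sin(x) * 2, backend=my_backend)
f(torch.ones(3))  # prints the node count first, then executes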

But in my experiment it looks as if real kernel execution happens while generating the FX graph (step 1 above), which does not align with the statement quoted above.

In the experiment below, I inspected the FX graph to guess what happens behind the scenes.

Since I only care about Dynamo, I chose the backend 'eager' in the code below, so AOTAutograd/Inductor are not involved.

import torch

class TestNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 2)

    def forward(self, x):
        x = x * 2
        # Data-dependent branch: taken only if the runtime sum is positive.
        if x.sum() > 0:
            x = torch.sin(x)
        # Second data-dependent branch, on the runtime mean.
        if x.mean() > 0:
            x = torch.cos(x)
        x = self.linear(x)
        return x

m = TestNet().cuda()
m = torch.compile(m, backend="eager")
inputs = torch.ones(3).cuda() * -1.0  # all negative: neither branch is taken
m(inputs)

With TORCH_LOGS="+dynamo" python -u eagerbackend.py, I do not see anything about sin or cos in the log. (sin and cos do show up in the log if we remove the * -1.0 when creating inputs.)
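
One way I double-checked this without reading the whole +dynamo log (a sketch along the same lines as above, reusing TestNet; the backend name is mine) is a backend that just prints each graph Dynamo hands it:

def print_backend(gm, example_inputs):
    print(gm.graph)   # dump the ops captured in this (sub)graph
    return gm.forward

m2 = torch.compile(TestNet().cuda(), backend=print_backend)
m2(torch.ones(3).cuda() * -1.0)
# The printed graphs contain mul/sum, mean, and linear, but no sin or cos.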

How could Dynamo know that sin and cos are not in the FX graph for these specific inputs? My only guess is that Dynamo triggers real kernel execution for x.sum(), x.mean(), etc., so it can generate the FX graph without 'sin' and 'cos'. But that does not align with the statement that Dynamo generates the FX graph before the bytecode is executed.

What really happens behind the scenes? Thanks.

Hey!

@anijain2305 is the expert on this, but my mental model is as follows:

While it interprets the bytecode itself, it does trace the effect of a lot of the non-trivial bytecode (with fake objects). So it does execute some things in some cases.
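
If it helps, the "fake objects" idea can be poked at directly with FakeTensorMode (an internal torch._subclasses API, so subject to change). Metadata like shape and dtype flows through ops without any real data:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    x = torch.ones(3)        # a FakeTensor: shape/dtype/device, no storage
    y = (x * 2).sum()        # ops propagate metadata, no real kernel runs
    print(y.shape, y.dtype)  # torch.Size([]) torch.float32
    # bool(y > 0) would raise here: the result is data-dependent,
    # which is exactly the situation that forces a graph break.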

In this case, what happens is different: the data-dependent control flow you have triggers a "graph break", meaning that Dynamo cannot capture a single graph of all of this; it creates a collection of smaller graphs instead and preserves the branching in Python as-is.
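
You can see those breaks explicitly with torch._dynamo.explain (assuming a recent PyTorch 2.x; reusing TestNet and the inputs from the question):

import torch

ex = torch._dynamo.explain(TestNet().cuda())(torch.ones(3).cuda() * -1.0)
print(ex.graph_count, ex.graph_break_count)  # several small graphs, not one
print(ex.break_reasons)                      # the data-dependent branches on sum()/mean()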

That’s why the logs are so long in this case: there are many small graphs being created.
Why you don’t see the sin/cos inside these logs is a good question though :smiley: @anijain2305 might know.

Do you mean the execution with fake objects (no real data) to deduce the tensor shapes? Thanks.

It does, yes, and it deduces all the other Tensor metadata as well.

@anijain2305, could you help? Thanks.