Difference between the graph break reasons: `Dynamic control flow is not supported at the moment.` and `generic_jump TensorVariable()`

I was trying to understand why graph breaks happen and came across a few in models from the PyTorch benchmarks. I enabled graph break logging with:

export TORCH_LOGS="graph_breaks"

Case 1: BartForCausalLM (huggingface.py)
Command:

pytorch$ ./benchmarks/dynamo/huggingface.py --performance --training --amp --backend=inductor --only=BartForCausalLM

I get the following:

torch/_dynamo/symbolic_convert.py:527] [2/0] [__graph_breaks] Graph break: from user code at:
torch/_dynamo/symbolic_convert.py:527] [2/0] [__graph_breaks]   File "/media/disk1/abhishek/pytorch/benchmarks/dynamo/huggingface.py", line 564, in torch_dynamo_resume_in_forward_and_backward_pass_at_562
torch/_dynamo/symbolic_convert.py:527] [2/0] [__graph_breaks]     pred = mod(**cloned_inputs)
torch/_dynamo/symbolic_convert.py:527] [2/0] [__graph_breaks]   File "/home/abhishek/pytorch-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
torch/_dynamo/symbolic_convert.py:527] [2/0] [__graph_breaks]     return forward_call(*args, **kwargs)
torch/_dynamo/symbolic_convert.py:527] [2/0] [__graph_breaks]   File "/home/abhishek/pytorch-venv/lib/python3.10/site-packages/transformers/models/bart/modeling_bart.py", line 2245, in forward
torch/_dynamo/symbolic_convert.py:527] [2/0] [__graph_breaks]     outputs = self.model.decoder(
torch/_dynamo/symbolic_convert.py:527] [2/0] [__graph_breaks]   File "/home/abhishek/pytorch-benchmarks/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
torch/_dynamo/symbolic_convert.py:527] [2/0] [__graph_breaks]     return forward_call(*args, **kwargs)
torch/_dynamo/symbolic_convert.py:527] [2/0] [__graph_breaks]   File "/home/abhishek/pytorch-venv/lib/python3.10/site-packages/transformers/models/bart/modeling_bart.py", line 1451, in forward
torch/_dynamo/symbolic_convert.py:527] [2/0] [__graph_breaks]     if dropout_probability < self.layerdrop:
...
...
torch/_dynamo/symbolic_convert.py:527] [2/0] [__graph_breaks] torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands

So it complains about this `if` condition in modeling_bart.py (line 1451 in the trace above): `if dropout_probability < self.layerdrop:`

I tried to reproduce it with a minimal example:

import torch

@torch.compile(mode="reduce-overhead")
def f(x, y):
    # Branch on a tensor value that is only known at runtime
    z = torch.rand([])
    if z > 0.5:
        return x + y
    else:
        return x - y

x = torch.tensor([1.0, 2.0], device="cuda")
y = torch.tensor([1.0, 2.0], device="cuda")
f(x, y)

Setting only TORCH_LOGS="graph_breaks" printed nothing. With TORCH_LOGS="+dynamo" I got:

torch/_dynamo/symbolic_convert.py:774] [5/0] [__trace_source]         z = torch.rand([])
torch/_dynamo/symbolic_convert.py:797] [5/0] [__trace_bytecode] TRACE LOAD_GLOBAL torch []
torch/_dynamo/symbolic_convert.py:797] [5/0] [__trace_bytecode] TRACE LOAD_ATTR rand [PythonModuleVariable(<module 'torch' from '/media/abhishek/pytorch-venv/lib/python3.10/site-packages/torch/__init__.py'>)]
torch/_dynamo/symbolic_convert.py:797] [5/0] [__trace_bytecode] TRACE BUILD_LIST 0 [TorchInGraphFunctionVariable(<built-in method rand of type object at 0x7fe1fda66500>)]
torch/_dynamo/symbolic_convert.py:797] [5/0] [__trace_bytecode] TRACE CALL_FUNCTION 1 [TorchInGraphFunctionVariable(<built-in method rand of type object at 0x7fe1fda66500>), ListVariable(length=0)]
torch/_dynamo/symbolic_convert.py:797] [5/0] [__trace_bytecode] TRACE STORE_FAST z [TensorVariable()]
torch/_dynamo/symbolic_convert.py:774] [5/0] [__trace_source] TRACE starts_line /tmp/ipykernel_1940666/3830298969.py:4 in f
torch/_dynamo/symbolic_convert.py:774] [5/0] [__trace_source]         if z > 0.5:
torch/_dynamo/symbolic_convert.py:797] [5/0] [__trace_bytecode] TRACE LOAD_FAST z []
torch/_dynamo/symbolic_convert.py:797] [5/0] [__trace_bytecode] TRACE LOAD_CONST 0.5 [TensorVariable()]
torch/_dynamo/symbolic_convert.py:797] [5/0] [__trace_bytecode] TRACE COMPARE_OP > [TensorVariable(), ConstantVariable()]
torch/_dynamo/symbolic_convert.py:797] [5/0] [__trace_bytecode] TRACE POP_JUMP_IF_FALSE 26 [TensorVariable()]
torch/_dynamo/symbolic_convert.py:322] [5/0] generic_jump triggered compile
torch/_dynamo/output_graph.py:971] [5/0] COMPILING GRAPH due to GraphCompileReason(reason='generic_jump TensorVariable()', user_stack=[<FrameSummary file /tmp/ipykernel_1940666/3830298969.py, line 4 in f>], graph_break=True)

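Why does the minimal example report `generic_jump TensorVariable()` instead of the `Dynamic control flow is not supported at the moment.` error seen in the benchmark?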

Similarly,

Case 2: detectron2_fasterrcnn_r_101_c4 (torchbench.py)

pytorch$ ./benchmarks/dynamo/torchbench.py --performance --inference --amp --backend=inductor --only=detectron2_fasterrcnn_r_101_c4

This gives the following graph break:

torch/_dynamo/symbolic_convert.py:527] [8/0] [__graph_breaks] Graph break: from user code at:
torch/_dynamo/symbolic_convert.py:527] [8/0] [__graph_breaks]   File "/home/abhishek/pytorch-venv/lib/python3.10/site-packages/detectron2/modeling/meta_arch/rcnn.py", line 208, in torch_dynamo_resume_in_inference_at_203
torch/_dynamo/symbolic_convert.py:527] [8/0] [__graph_breaks]     proposals, _ = self.proposal_generator(images, features, None)
torch/_dynamo/symbolic_convert.py:527] [8/0] [__graph_breaks]   File "/home/abhishek/pytorch-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
torch/_dynamo/symbolic_convert.py:527] [8/0] [__graph_breaks]     return forward_call(*args, **kwargs)
torch/_dynamo/symbolic_convert.py:527] [8/0] [__graph_breaks]   File "/home/abhishek/pytorch-venv/lib/python3.10/site-packages/detectron2/modeling/proposal_generator/rpn.py", line 477, in forward
torch/_dynamo/symbolic_convert.py:527] [8/0] [__graph_breaks]     proposals = self.predict_proposals(
torch/_dynamo/symbolic_convert.py:527] [8/0] [__graph_breaks]   File "/home/abhishek/pytorch-venv/lib/python3.10/site-packages/detectron2/modeling/proposal_generator/rpn.py", line 503, in predict_proposals
torch/_dynamo/symbolic_convert.py:527] [8/0] [__graph_breaks]     return find_top_rpn_proposals(
torch/_dynamo/symbolic_convert.py:527] [8/0] [__graph_breaks]   File "/home/abhishek/pytorch-venv/lib/python3.10/site-packages/detectron2/modeling/proposal_generator/proposal_utils.py", line 106, in find_top_rpn_proposals
torch/_dynamo/symbolic_convert.py:527] [8/0] [__graph_breaks]     if not valid_mask.all():
...
...
torch/_dynamo/symbolic_convert.py:527] [8/0] [__graph_breaks]     self.dispatch_table[inst.opcode](self, inst)
torch/_dynamo/symbolic_convert.py:527] [8/0] [__graph_breaks]   File "/home/abhishek/pytorch-venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in inner
torch/_dynamo/symbolic_convert.py:527] [8/0] [__graph_breaks]     raise exc.UserError(
torch/_dynamo/symbolic_convert.py:527] [8/0] [__graph_breaks] torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands

It complains about this `if` condition in detectron2's proposal_utils.py (line 106 in the trace above): `if not valid_mask.all():`

I then tried to create a minimal reproduction:

import torch

@torch.compile(mode="reduce-overhead")
def f(x, y):
    # Branch on the result of a tensor reduction
    if (x == y).all():
        return x + y
    else:
        return x - y

x = torch.tensor([1.0, 2.0], device="cuda")
y = torch.tensor([1.0, 2.0], device="cuda")
f(x, y)

The situation is the same as in Case 1: TORCH_LOGS="graph_breaks" printed nothing, but "+dynamo" showed the following:

torch/_dynamo/symbolic_convert.py:774] [3/0] [__trace_source] TRACE starts_line /tmp/ipykernel_1940666/2411286153.py:3 in f
torch/_dynamo/symbolic_convert.py:774] [3/0] [__trace_source]         if (x == y).all():
torch/_dynamo/symbolic_convert.py:797] [3/0] [__trace_bytecode] TRACE LOAD_FAST x []
torch/_dynamo/symbolic_convert.py:797] [3/0] [__trace_bytecode] TRACE LOAD_FAST y [LazyVariableTracker()]
torch/_dynamo/symbolic_convert.py:797] [3/0] [__trace_bytecode] TRACE COMPARE_OP == [LazyVariableTracker(), LazyVariableTracker()]
torch/_dynamo/output_graph.py:2029] [3/0] create_graph_input L_x_ L['x']
torch/_dynamo/variables/builder.py:2268] [3/0] wrap_to_fake L['x'] (2,) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>], constraint_sizes=[None], view_base_context=None, tensor_source=LocalSource(local_name='x', cell_or_freevar=False), shape_env_to_source_to_symbol_cache={}) <class 'torch.Tensor'>
torch/_dynamo/output_graph.py:2029] [3/0] create_graph_input L_y_ L['y']
torch/_dynamo/variables/builder.py:2268] [3/0] wrap_to_fake L['y'] (2,) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>], constraint_sizes=[None], view_base_context=None, tensor_source=LocalSource(local_name='y', cell_or_freevar=False), shape_env_to_source_to_symbol_cache={}) <class 'torch.Tensor'>
torch/_dynamo/symbolic_convert.py:797] [3/0] [__trace_bytecode] TRACE LOAD_ATTR all [TensorVariable()]
torch/_dynamo/symbolic_convert.py:797] [3/0] [__trace_bytecode] TRACE CALL_FUNCTION 0 [GetAttrVariable()]
torch/_dynamo/symbolic_convert.py:797] [3/0] [__trace_bytecode] TRACE POP_JUMP_IF_FALSE 20 [TensorVariable()]
torch/_dynamo/symbolic_convert.py:322] [3/0] generic_jump triggered compile
torch/_dynamo/output_graph.py:971] [3/0] COMPILING GRAPH due to GraphCompileReason(reason='generic_jump TensorVariable()', user_stack=[<FrameSummary file /tmp/ipykernel_1940666/2411286153.py, line 3 in f>], graph_break=True)

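Why do I again get `generic_jump TensorVariable()` here rather than the `Dynamic control flow is not supported at the moment.` error from the benchmark?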