Obviously, to turn on PT2 for some code, you use `torch.compile`. But you can also turn PT2 off! To exclude a block of code from PT2 processing, use the `torch._dynamo.disable` decorator. Example:
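```python
import torch
import torch._dynamo

# Dynamo will not trace into f; calls to it run eagerly
@torch._dynamo.disable
def f(x, y):
    return x + y

def forward(x, y):
    x = x * 2    # traced into the first graph
    r = f(x, y)  # graph break: f runs outside of PT2
    r = r * y    # traced into a second graph
    return r

fn_compiled = torch.compile(forward)
x = torch.randn(3)
y = torch.randn(3)
print(fn_compiled(x, y))
```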
If you run this code with `TORCH_LOGS=dynamo,graph` set in the environment, you will see a trace like this:
```
[2023-04-11 08:04:08,684] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward
[2023-04-11 08:04:08,691] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function debug_wrapper
[2023-04-11 08:04:17,088] torch._dynamo.output_graph: [INFO] Step 2: done compiler function debug_wrapper
[2023-04-11 08:04:17,094] torch._dynamo.output_graph.__graph: [DEBUG] TRACED GRAPH
__compiled_fn_0 <eval_with_key>.3 opcode         name    target                   args       kwargs
-------------  ------  -----------------------  ---------  --------
placeholder    l_x_    L_x_                     ()         {}
call_function  mul     <built-in function mul>  (l_x_, 2)  {}
output         output  output                   ((mul,),)  {}
[2023-04-11 08:04:17,100] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <resume in forward>
[2023-04-11 08:04:17,102] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing <resume in forward> (RETURN_VALUE)
[2023-04-11 08:04:17,103] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function debug_wrapper
[2023-04-11 08:04:18,528] torch._dynamo.output_graph: [INFO] Step 2: done compiler function debug_wrapper
[2023-04-11 08:04:18,529] torch._dynamo.output_graph.__graph: [DEBUG] TRACED GRAPH
__compiled_fn_2 <eval_with_key>.11 opcode         name       target                   args               kwargs
-------------  ---------  -----------------------  -----------------  --------
placeholder    l_stack0_  L_stack0_                ()                 {}
placeholder    l_y_       L_y_                     ()                 {}
call_function  mul        <built-in function mul>  (l_stack0_, l_y_)  {}
output         output     output                   ((mul,),)          {}
```
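Dynamo has split your logic into two graphs: `__compiled_fn_0` computes `x * 2`, and `__compiled_fn_2`, which resumes `forward` after the call to `f`, computes `r * y`. The addition inside `f` was never compiled at all.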
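One way to confirm this split without reading logs is a custom backend that records every FX graph Dynamo hands it. The sketch below assumes the callable-backend protocol of `torch.compile` (a function that takes a `torch.fx.GraphModule` plus example inputs and returns a callable); `make_recording_backend`, `has_add`, and the `*_plain`/`*_disabled` functions are hypothetical names for illustration. It compares the disabled `f` against an undecorated copy.

```python
import operator

import torch
import torch._dynamo


def make_recording_backend(graphs):
    # Hypothetical helper: a backend that stores each FX graph it receives
    # and runs it unmodified (no real optimization happens here).
    def recording_backend(gm: torch.fx.GraphModule, example_inputs):
        graphs.append(gm)
        return gm.forward
    return recording_backend


@torch._dynamo.disable
def f_disabled(x, y):
    return x + y


def f_plain(x, y):
    return x + y


def forward_disabled(x, y):
    x = x * 2
    r = f_disabled(x, y)  # graph break: runs eagerly, outside any graph
    return r * y


def forward_plain(x, y):
    x = x * 2
    r = f_plain(x, y)  # inlined into the traced graph
    return r * y


def has_add(gm):
    # True if any call_function node in the graph is an addition.
    return any(
        node.op == "call_function" and node.target in (operator.add, torch.add)
        for node in gm.graph.nodes
    )


x, y = torch.randn(3), torch.randn(3)

disabled_graphs, plain_graphs = [], []
out_disabled = torch.compile(
    forward_disabled, backend=make_recording_backend(disabled_graphs)
)(x, y)
out_plain = torch.compile(
    forward_plain, backend=make_recording_backend(plain_graphs)
)(x, y)

print(len(disabled_graphs), any(has_add(g) for g in disabled_graphs))
print(len(plain_graphs), any(has_add(g) for g in plain_graphs))
```

With the decorator you should see two graphs, neither containing the addition; without it, Dynamo inlines `f_plain` and produces a single graph that does contain it. Both versions compute the same values.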
If a block of code is causing PT2 to crash, consider disabling PT2 on it. This is always sound; the only cost is lost optimization, in the same way a graph break would have cost you optimization. Oh, and don't forget to send us a bug report; we should never be crashing on user code!
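When the offending code lives in a library you cannot edit, `torch._dynamo.disable` can also be called as an ordinary function to wrap it at the call site instead of being used as a decorator. In this sketch, `third_party_normalize` is a hypothetical stand-in for such a function (it does not actually crash PT2). The example only shows the wrapping pattern and checks that the compiled result matches eager execution; `backend="eager"` is used just to keep compilation fast.

```python
import torch
import torch._dynamo


def third_party_normalize(t):
    # Hypothetical library function that we pretend makes PT2 crash.
    return (t - t.mean()) / (t.std() + 1e-5)


# Wrap it without touching its source; calling the wrapper inside compiled
# code causes a graph break, and the wrapped function runs eagerly.
safe_normalize = torch._dynamo.disable(third_party_normalize)


def model(x):
    return safe_normalize(x * 3) + 1


compiled_model = torch.compile(model, backend="eager")
x = torch.randn(8)
print(torch.allclose(compiled_model(x), model(x)))
```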