I am a bit confused about graph breaks with dynamic shapes. I was reading the Dynamic Shapes documentation, specifically the section on the Guard Model.
For this program snippet:
def f(x, y):
    z = torch.cat([x, y])
    if z.size(0) > 2:
        return z.mul(2)
    else:
        return z.add(2)
The documentation says:
"The final IR we will compile with TorchInductor will either be torch.cat([x, y]).add(2) or torch.cat([x, y]).mul(2) (with the condition flattened away)."
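To check my own understanding of "flattened away", here is a toy sketch in plain Python (no torch; `trace_f` and its return values are hypothetical, not Dynamo's API). Tracing runs the Python `if` once with the concrete sizes, records those sizes as a guard, and keeps only the op on the branch that was actually taken:

```python
from typing import Callable, List, Tuple

# Hypothetical toy tracer: tensor sizes are plain ints and "ops" are strings.
def trace_f(len_x: int, len_y: int) -> Tuple[Callable[[int, int], bool], List[str]]:
    """Trace f(x, y) for concrete input lengths; return (guard, straight-line graph)."""
    graph = ["cat"]                # z = torch.cat([x, y])
    z_len = len_x + len_y          # the size is a concrete int at trace time
    if z_len > 2:                  # this Python `if` runs only while tracing
        graph.append("mul(2)")
    else:
        graph.append("add(2)")
    # The guard re-checks that the inputs have the same sizes as at trace time,
    # so the branch outcome is implied by it; the graph itself has no condition.
    def guard(a: int, b: int) -> bool:
        return (a, b) == (len_x, len_y)
    return guard, graph
```

In this toy model both resulting graphs are branch-free, e.g. `trace_f(2, 2)` yields `["cat", "mul(2)"]` and `trace_f(1, 1)` yields `["cat", "add(2)"]`.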
As far as I can tell there is no graph break here; instead we get two separate graphs altogether.
I tested this with the following code:
import torch

@torch.compile(mode="reduce-overhead")
def f(x, y):
    z = torch.cat([x, y])
    if z.size(0) > 2:
        return z.mul(2)
    else:
        return z.add(2)

for i in range(5):
    print(f"Iteration::{i} Running the multiply case...")
    x = torch.randn(2)
    y = torch.randn(2)
    z = f(x, y)  # cat has size 4 > 2, so the mul branch is taken
    print(z)
    print(f"Iteration::{i} Running the add case...")
    x = torch.rand(1)
    y = torch.rand(1)
    z = f(x, y)  # cat has size 2, so the add branch is taken
    print(z)
The code above is simple. To dig a bit deeper, I inspected the cache entries of f:
from torch._dynamo.eval_frame import _debug_get_cache_entry_list, innermost_fn
cache_entries = _debug_get_cache_entry_list(innermost_fn(f))
cache_entries
This shows there are two cache entries for f:
[<torch._C._dynamo.eval_frame._CacheEntry at 0x7fac1c3b0b30>,
<torch._C._dynamo.eval_frame._CacheEntry at 0x7fac1c322af0>]
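My mental model of how one of these entries gets picked, as a toy sketch with hypothetical names (`ToyCache` is not a Dynamo class): entries are tried in order and the first one whose guard passes is executed; if none passes, the function is recompiled. In the guard dump further below, the entry checking size=[1] (compiled later) is listed first, so I have assumed newest-first ordering here.

```python
from typing import Callable, List, Optional, Tuple

Guard = Callable[[Tuple[int, ...]], bool]

class ToyCache:
    """Hypothetical per-function cache: an ordered list of (guard, compiled_fn_name)."""

    def __init__(self) -> None:
        self.entries: List[Tuple[Guard, str]] = []

    def add(self, guard: Guard, name: str) -> None:
        self.entries.insert(0, (guard, name))  # newest entry is checked first

    def lookup(self, sizes: Tuple[int, ...]) -> Optional[str]:
        for guard, name in self.entries:
            if guard(sizes):
                return name                    # first passing guard wins
        return None                            # cache miss: would trigger a recompile

cache = ToyCache()
cache.add(lambda s: s == (2, 2), "__compiled_fn_1")  # x and y of size [2]
cache.add(lambda s: s == (1, 1), "__compiled_fn_3")  # x and y of size [1]
```

In this toy model, the only thing distinguishing the two entries is the size each guard checks.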
Digging further, I printed the guard and code of each cache entry:
import inspect

import depyf

for i, cache_entry in enumerate(cache_entries):
    print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
    print("Stats for cache entry::", i)
    guard, code = cache_entry.check_fn, cache_entry.code
    print(guard)
    print("--------------------------------------")
    print("The following source code printed using inspect.getsource: remains the same\n"
          "as the original source code.")
    print(inspect.getsource(code))
    print("--------------------------------------")
    print("The following is the bytecode printed as source code using depyf.decompile.")
    print(depyf.decompile(code))
    print("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<")
I get the following:
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
Stats for cache entry:: 0
TREE_GUARD_MANAGER:
+- RootGuardManager
| +- DEFAULT_DEVICE: utils_device.CURRENT_DEVICE == None # _dynamo/output_graph.py:459 in init_ambient_guards
| +- GLOBAL_STATE: ___check_global_state()
| +- GuardManager: source=L['x'], accessed_by=DictGetItemGuardAccessor(x)
| | +- TENSOR_MATCH: check_tensor(L['x'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[1], stride=[1]) # z = torch.cat([x, y]) # mp/ipykernel_1672598/1182209580.py:7 in f
| | +- NO_HASATTR: hasattr(L['x'], '_dynamo_dynamic_indices') == False # z = torch.cat([x, y]) # mp/ipykernel_1672598/1182209580.py:7 in f
| | +- NO_TENSOR_ALIASING: check_no_aliasing(L['x'], L['y'])
| +- GuardManager: source=L['y'], accessed_by=DictGetItemGuardAccessor(y)
| | +- TENSOR_MATCH: check_tensor(L['y'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[1], stride=[1]) # z = torch.cat([x, y]) # mp/ipykernel_1672598/1182209580.py:7 in f
| | +- NO_HASATTR: hasattr(L['y'], '_dynamo_dynamic_indices') == False # z = torch.cat([x, y]) # mp/ipykernel_1672598/1182209580.py:7 in f
| | +- NO_TENSOR_ALIASING: check_no_aliasing(L['x'], L['y'])
| +- GuardManager: source=G, accessed_by=GlobalsGuardAccessor
| | +- GuardManager: source=G['torch'], accessed_by=DictGetItemGuardAccessor(torch)
| | | +- ID_MATCH: ___check_obj_id(G['torch'], 140380820346000) # z = torch.cat([x, y]) # mp/ipykernel_1672598/1182209580.py:7 in f
| | | +- GuardManager: source=G['torch'].cat, accessed_by=GetAttrGuardAccessor(cat)
| | | | +- ID_MATCH: ___check_obj_id(G['torch'].cat, 140380810422096) # z = torch.cat([x, y]) # mp/ipykernel_1672598/1182209580.py:7 in f
--------------------------------------
The following source code printed using inspect.getsource: remains the same
as the original source code.
@torch.compile(mode="reduce-overhead")
def f(x, y):
    z = torch.cat([x, y])
    if z.size(0) > 2:
        return z.mul(2)
    else:
        return z.add(2)
--------------------------------------
The following is the bytecode printed as source code using depyf.decompile.
def f(x, y):
    return __compiled_fn_3(x, y)
<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
Stats for cache entry:: 1
TREE_GUARD_MANAGER:
+- RootGuardManager
| +- DEFAULT_DEVICE: utils_device.CURRENT_DEVICE == None # _dynamo/output_graph.py:459 in init_ambient_guards
| +- GLOBAL_STATE: ___check_global_state()
| +- GuardManager: source=L['x'], accessed_by=DictGetItemGuardAccessor(x)
| | +- TENSOR_MATCH: check_tensor(L['x'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[2], stride=[1]) # z = torch.cat([x, y]) # mp/ipykernel_1672598/1182209580.py:7 in f
| | +- NO_HASATTR: hasattr(L['x'], '_dynamo_dynamic_indices') == False # z = torch.cat([x, y]) # mp/ipykernel_1672598/1182209580.py:7 in f
| | +- NO_TENSOR_ALIASING: check_no_aliasing(L['x'], L['y'])
| +- GuardManager: source=L['y'], accessed_by=DictGetItemGuardAccessor(y)
| | +- TENSOR_MATCH: check_tensor(L['y'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[2], stride=[1]) # z = torch.cat([x, y]) # mp/ipykernel_1672598/1182209580.py:7 in f
| | +- NO_HASATTR: hasattr(L['y'], '_dynamo_dynamic_indices') == False # z = torch.cat([x, y]) # mp/ipykernel_1672598/1182209580.py:7 in f
| | +- NO_TENSOR_ALIASING: check_no_aliasing(L['x'], L['y'])
| +- GuardManager: source=G, accessed_by=GlobalsGuardAccessor
| | +- GuardManager: source=G['torch'], accessed_by=DictGetItemGuardAccessor(torch)
| | | +- ID_MATCH: ___check_obj_id(G['torch'], 140380820346000) # z = torch.cat([x, y]) # mp/ipykernel_1672598/1182209580.py:7 in f
| | | +- GuardManager: source=G['torch'].cat, accessed_by=GetAttrGuardAccessor(cat)
| | | | +- ID_MATCH: ___check_obj_id(G['torch'].cat, 140380810422096) # z = torch.cat([x, y]) # mp/ipykernel_1672598/1182209580.py:7 in f
--------------------------------------
The following source code printed using inspect.getsource: remains the same
as the original source code.
@torch.compile(mode="reduce-overhead")
def f(x, y):
    z = torch.cat([x, y])
    if z.size(0) > 2:
        return z.mul(2)
    else:
        return z.add(2)
--------------------------------------
The following is the bytecode printed as source code using depyf.decompile.
def f(x, y):
    return __compiled_fn_1(x, y)
<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
I see that the two cached versions of f call __compiled_fn_3 and __compiled_fn_1, respectively.
What I do not understand:
- Why isn't there a graph break?
- The two guards in this case (listed above) look the same to me. Which part of the code decides which cached version is executed?
- The documentation says the graph is compiled "(with the condition flattened away)", but where is the condition actually evaluated? I assumed it would be inside the guard. (Ever since guard management moved to C++, I find the guard output harder to understand. In this case, where is the condition?)
I have another example in the same vein:
import torch

def conditional(x, y):
    if x.shape[0] > 2:
        return x + y
    else:
        return x - y

conditional_torch_compiled = torch.compile(conditional, mode="reduce-overhead")

x = torch.tensor([[1., 2.], [3., 4.], [5., 6.]], device='cuda')  # shape (3, 2)
x_1 = torch.tensor([[1., 2.], [3., 4.]], device='cuda')          # shape (2, 2)
x_2 = torch.tensor([[1., 2.]], device='cuda')                    # shape (1, 2)
y = torch.tensor([1., 2.], device='cuda')                        # shape (2,)
Case 1: x.shape[0] > 2
conditional_torch_compiled(x, y)
Guard and CUDA graph logs:
I0827 15:37:01.507000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3633] [1/0] produce_guards
V0827 15:37:01.507000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3815] [1/0] track_symint L['x'].size()[0] 3 None
V0827 15:37:01.508000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3815] [1/0] track_symint L['x'].size()[1] 2 None
V0827 15:37:01.508000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3815] [1/0] track_symint L['x'].stride()[0] 2 None
V0827 15:37:01.509000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3815] [1/0] track_symint L['x'].stride()[1] 1 None
V0827 15:37:01.510000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3815] [1/0] track_symint L['x'].storage_offset() 0 None
V0827 15:37:01.510000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3815] [1/0] track_symint L['y'].size()[0] 2 None
V0827 15:37:01.510000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3815] [1/0] track_symint L['y'].stride()[0] 1 None
V0827 15:37:01.511000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3815] [1/0] track_symint L['y'].storage_offset() 0 None
V0827 15:37:01.511000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3979] [1/0] Skipping guard L['x'].size()[0] == 3
V0827 15:37:01.512000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3979] [1/0] Skipping guard L['x'].size()[1] == 2
V0827 15:37:01.512000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3979] [1/0] Skipping guard L['x'].stride()[0] == 2
V0827 15:37:01.513000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3979] [1/0] Skipping guard L['x'].stride()[1] == 1
V0827 15:37:01.513000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3979] [1/0] Skipping guard L['x'].storage_offset() == 0
V0827 15:37:01.515000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3979] [1/0] Skipping guard L['y'].size()[0] == 2
V0827 15:37:01.515000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3979] [1/0] Skipping guard L['y'].stride()[0] == 1
V0827 15:37:01.516000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3979] [1/0] Skipping guard L['y'].storage_offset() == 0
V0827 15:37:01.517000 140504121249792 torch/_dynamo/guards.py:2168] [1/0] [__guards] GUARDS:
V0827 15:37:01.517000 140504121249792 torch/_dynamo/guards.py:2147] [1/0] [__guards]
V0827 15:37:01.517000 140504121249792 torch/_dynamo/guards.py:2147] [1/0] [__guards] TREE_GUARD_MANAGER:
V0827 15:37:01.517000 140504121249792 torch/_dynamo/guards.py:2147] [1/0] [__guards] +- RootGuardManager
V0827 15:37:01.517000 140504121249792 torch/_dynamo/guards.py:2147] [1/0] [__guards] | +- DEFAULT_DEVICE: utils_device.CURRENT_DEVICE == None # _dynamo/output_graph.py:459 in init_ambient_guards
V0827 15:37:01.517000 140504121249792 torch/_dynamo/guards.py:2147] [1/0] [__guards] | +- GLOBAL_STATE: ___check_global_state()
V0827 15:37:01.517000 140504121249792 torch/_dynamo/guards.py:2147] [1/0] [__guards] | +- GuardManager: source=L['x'], accessed_by=DictGetItemGuardAccessor(x)
V0827 15:37:01.517000 140504121249792 torch/_dynamo/guards.py:2147] [1/0] [__guards] | | +- TENSOR_MATCH: check_tensor(L['x'], Tensor, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.float32, device=0, requires_grad=False, size=[3, 2], stride=[2, 1]) # if x.shape[0] > 2: # mp/ipykernel_1924114/1524425904.py:2 in conditional
V0827 15:37:01.517000 140504121249792 torch/_dynamo/guards.py:2147] [1/0] [__guards] | | +- NO_HASATTR: hasattr(L['x'], '_dynamo_dynamic_indices') == False # if x.shape[0] > 2: # mp/ipykernel_1924114/1524425904.py:2 in conditional
V0827 15:37:01.517000 140504121249792 torch/_dynamo/guards.py:2147] [1/0] [__guards] | | +- NO_TENSOR_ALIASING: check_no_aliasing(L['x'], L['y'])
V0827 15:37:01.517000 140504121249792 torch/_dynamo/guards.py:2147] [1/0] [__guards] | +- GuardManager: source=L['y'], accessed_by=DictGetItemGuardAccessor(y)
V0827 15:37:01.517000 140504121249792 torch/_dynamo/guards.py:2147] [1/0] [__guards] | | +- TENSOR_MATCH: check_tensor(L['y'], Tensor, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.float32, device=0, requires_grad=False, size=[2], stride=[1]) # return x + y # mp/ipykernel_1924114/1524425904.py:3 in conditional
V0827 15:37:01.517000 140504121249792 torch/_dynamo/guards.py:2147] [1/0] [__guards] | | +- NO_HASATTR: hasattr(L['y'], '_dynamo_dynamic_indices') == False # return x + y # mp/ipykernel_1924114/1524425904.py:3 in conditional
V0827 15:37:01.517000 140504121249792 torch/_dynamo/guards.py:2147] [1/0] [__guards] | | +- NO_TENSOR_ALIASING: check_no_aliasing(L['x'], L['y'])
V0827 15:37:01.517000 140504121249792 torch/_dynamo/guards.py:2147] [1/0] [__guards]
I0827 15:37:01.519000 140504121249792 torch/_inductor/cudagraph_trees.py:362] [__cudagraphs] recording cudagraph tree for graph without symints
Case 2: x.shape[0] == 2
conditional_torch_compiled(x_1, y)
Guard and CUDA graph logs:
I0827 15:42:06.628000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3633] [1/1] produce_guards
V0827 15:42:06.628000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3815] [1/1] track_symint L['x'].size()[0] s0 RelaxedUnspecConstraint(warn_only=True)
V0827 15:42:06.629000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3815] [1/1] track_symint L['x'].size()[1] 2 None
V0827 15:42:06.629000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3815] [1/1] track_symint L['x'].stride()[0] 2 None
V0827 15:42:06.630000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3815] [1/1] track_symint L['x'].stride()[1] 1 None
V0827 15:42:06.630000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3815] [1/1] track_symint L['x'].storage_offset() 0 None
V0827 15:42:06.630000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3815] [1/1] track_symint L['y'].size()[0] 2 None
V0827 15:42:06.631000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3815] [1/1] track_symint L['y'].stride()[0] 1 None
V0827 15:42:06.631000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3815] [1/1] track_symint L['y'].storage_offset() 0 None
V0827 15:42:06.631000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3979] [1/1] Skipping guard L['x'].size()[1] == 2
V0827 15:42:06.632000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3979] [1/1] Skipping guard L['x'].stride()[0] == 2
V0827 15:42:06.632000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3979] [1/1] Skipping guard L['x'].stride()[1] == 1
V0827 15:42:06.632000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3979] [1/1] Skipping guard L['x'].storage_offset() == 0
V0827 15:42:06.633000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3979] [1/1] Skipping guard L['y'].size()[0] == 2
V0827 15:42:06.633000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3979] [1/1] Skipping guard L['y'].stride()[0] == 1
V0827 15:42:06.634000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3979] [1/1] Skipping guard L['y'].storage_offset() == 0
V0827 15:42:06.637000 140504121249792 torch/_dynamo/guards.py:2168] [1/1] [__guards] GUARDS:
V0827 15:42:06.637000 140504121249792 torch/_dynamo/guards.py:2147] [1/1] [__guards]
V0827 15:42:06.637000 140504121249792 torch/_dynamo/guards.py:2147] [1/1] [__guards] TREE_GUARD_MANAGER:
V0827 15:42:06.637000 140504121249792 torch/_dynamo/guards.py:2147] [1/1] [__guards] +- RootGuardManager
V0827 15:42:06.637000 140504121249792 torch/_dynamo/guards.py:2147] [1/1] [__guards] | +- DEFAULT_DEVICE: utils_device.CURRENT_DEVICE == None # _dynamo/output_graph.py:459 in init_ambient_guards
V0827 15:42:06.637000 140504121249792 torch/_dynamo/guards.py:2147] [1/1] [__guards] | +- GLOBAL_STATE: ___check_global_state()
V0827 15:42:06.637000 140504121249792 torch/_dynamo/guards.py:2147] [1/1] [__guards] | +- GuardManager: source=L['x'], accessed_by=DictGetItemGuardAccessor(x)
V0827 15:42:06.637000 140504121249792 torch/_dynamo/guards.py:2147] [1/1] [__guards] | | +- TYPE_MATCH: ___check_type_id(L['x'], 94672449896992) # if x.shape[0] > 2: # mp/ipykernel_1924114/1524425904.py:2 in conditional
V0827 15:42:06.637000 140504121249792 torch/_dynamo/guards.py:2147] [1/1] [__guards] | | +- TENSOR_MATCH: check_tensor(L['x'], Tensor, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.float32, device=0, requires_grad=False, size=[None, 2], stride=[2, 1]) # if x.shape[0] > 2: # mp/ipykernel_1924114/1524425904.py:2 in conditional
V0827 15:42:06.637000 140504121249792 torch/_dynamo/guards.py:2147] [1/1] [__guards] | | +- NO_HASATTR: hasattr(L['x'], '_dynamo_dynamic_indices') == False # if x.shape[0] > 2: # mp/ipykernel_1924114/1524425904.py:2 in conditional
V0827 15:42:06.637000 140504121249792 torch/_dynamo/guards.py:2147] [1/1] [__guards] | | +- NO_TENSOR_ALIASING: check_no_aliasing(L['x'], L['y'])
V0827 15:42:06.637000 140504121249792 torch/_dynamo/guards.py:2147] [1/1] [__guards] | +- GuardManager: source=L['y'], accessed_by=DictGetItemGuardAccessor(y)
V0827 15:42:06.637000 140504121249792 torch/_dynamo/guards.py:2147] [1/1] [__guards] | | +- TENSOR_MATCH: check_tensor(L['y'], Tensor, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.float32, device=0, requires_grad=False, size=[2], stride=[1]) # return x - y # mp/ipykernel_1924114/1524425904.py:5 in conditional
V0827 15:42:06.637000 140504121249792 torch/_dynamo/guards.py:2147] [1/1] [__guards] | | +- NO_HASATTR: hasattr(L['y'], '_dynamo_dynamic_indices') == False # return x - y # mp/ipykernel_1924114/1524425904.py:5 in conditional
V0827 15:42:06.637000 140504121249792 torch/_dynamo/guards.py:2147] [1/1] [__guards] | | +- NO_TENSOR_ALIASING: check_no_aliasing(L['x'], L['y'])
V0827 15:42:06.637000 140504121249792 torch/_dynamo/guards.py:2147] [1/1] [__guards] +- LAMBDA_GUARD: 2 <= L['x'].size()[0] <= 2 # _dynamo/output_graph.py:451 in init_ambient_guards
V0827 15:42:06.637000 140504121249792 torch/_dynamo/guards.py:2147] [1/1] [__guards]
I0827 15:42:06.639000 140504121249792 torch/_inductor/cudagraph_trees.py:364] [__cudagraphs] recording cudagraph tree for symint key 2
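A toy sketch of where I suspect the range in the LAMBDA_GUARD comes from. It assumes (as the Dynamic Shapes docs describe for 0/1 specialization) that a dynamic size symbol s0 starts with a lower bound of 2, and that evaluating `x.shape[0] > 2` as False during tracing adds the bound s0 <= 2. `ToyRange` is hypothetical and not the ShapeEnv API:

```python
import math
from dataclasses import dataclass

@dataclass
class ToyRange:
    """Hypothetical value range for a symbolic size such as s0."""
    lo: int = 2             # assumed default: dynamic sizes are >= 2 (0/1 specialization)
    hi: float = math.inf

    def record_gt(self, bound: int, outcome: bool) -> None:
        """Refine the range after the tracer branches on `s0 > bound`."""
        if outcome:
            self.lo = max(self.lo, bound + 1)  # s0 > bound  implies  s0 >= bound + 1
        else:
            self.hi = min(self.hi, bound)      # not (s0 > bound)  implies  s0 <= bound

    def guard_str(self, name: str) -> str:
        return f"{self.lo} <= {name} <= {self.hi}"

r = ToyRange()
r.record_gt(2, outcome=False)  # x_1.shape[0] == 2, so the else branch was traced
print(r.guard_str("L['x'].size()[0]"))
```

In this toy model, once both bounds meet at 2 the printed range has the same form as the logged guard.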
Case 3: x.shape[0] < 2
conditional_torch_compiled(x_2, y)
Guard and CUDA graph logs:
I0827 15:56:04.477000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3633] [1/2] produce_guards
V0827 15:56:04.477000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3815] [1/2] track_symint L['x'].size()[0] 1 RelaxedUnspecConstraint(warn_only=True)
V0827 15:56:04.478000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3815] [1/2] track_symint L['x'].size()[1] 2 None
V0827 15:56:04.478000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3815] [1/2] track_symint L['x'].stride()[0] 2 None
V0827 15:56:04.479000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3815] [1/2] track_symint L['x'].stride()[1] 1 None
V0827 15:56:04.479000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3815] [1/2] track_symint L['x'].storage_offset() 0 None
V0827 15:56:04.480000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3815] [1/2] track_symint L['y'].size()[0] 2 None
V0827 15:56:04.480000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3815] [1/2] track_symint L['y'].stride()[0] 1 None
V0827 15:56:04.480000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3815] [1/2] track_symint L['y'].storage_offset() 0 None
V0827 15:56:04.481000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3979] [1/2] Skipping guard L['x'].size()[0] == 1
V0827 15:56:04.481000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3979] [1/2] Skipping guard L['x'].size()[1] == 2
V0827 15:56:04.482000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3979] [1/2] Skipping guard L['x'].stride()[0] == 2
V0827 15:56:04.482000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3979] [1/2] Skipping guard L['x'].stride()[1] == 1
V0827 15:56:04.482000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3979] [1/2] Skipping guard L['x'].storage_offset() == 0
V0827 15:56:04.483000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3979] [1/2] Skipping guard L['y'].size()[0] == 2
V0827 15:56:04.483000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3979] [1/2] Skipping guard L['y'].stride()[0] == 1
V0827 15:56:04.484000 140504121249792 torch/fx/experimental/symbolic_shapes.py:3979] [1/2] Skipping guard L['y'].storage_offset() == 0
V0827 15:56:04.485000 140504121249792 torch/_dynamo/guards.py:2168] [1/2] [__guards] GUARDS:
V0827 15:56:04.485000 140504121249792 torch/_dynamo/guards.py:2147] [1/2] [__guards]
V0827 15:56:04.485000 140504121249792 torch/_dynamo/guards.py:2147] [1/2] [__guards] TREE_GUARD_MANAGER:
V0827 15:56:04.485000 140504121249792 torch/_dynamo/guards.py:2147] [1/2] [__guards] +- RootGuardManager
V0827 15:56:04.485000 140504121249792 torch/_dynamo/guards.py:2147] [1/2] [__guards] | +- DEFAULT_DEVICE: utils_device.CURRENT_DEVICE == None # _dynamo/output_graph.py:459 in init_ambient_guards
V0827 15:56:04.485000 140504121249792 torch/_dynamo/guards.py:2147] [1/2] [__guards] | +- GLOBAL_STATE: ___check_global_state()
V0827 15:56:04.485000 140504121249792 torch/_dynamo/guards.py:2147] [1/2] [__guards] | +- GuardManager: source=L['x'], accessed_by=DictGetItemGuardAccessor(x)
V0827 15:56:04.485000 140504121249792 torch/_dynamo/guards.py:2147] [1/2] [__guards] | | +- TENSOR_MATCH: check_tensor(L['x'], Tensor, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.float32, device=0, requires_grad=False, size=[1, 2], stride=[2, 1]) # if x.shape[0] > 2: # mp/ipykernel_1924114/1524425904.py:2 in conditional
V0827 15:56:04.485000 140504121249792 torch/_dynamo/guards.py:2147] [1/2] [__guards] | | +- NO_HASATTR: hasattr(L['x'], '_dynamo_dynamic_indices') == False # if x.shape[0] > 2: # mp/ipykernel_1924114/1524425904.py:2 in conditional
V0827 15:56:04.485000 140504121249792 torch/_dynamo/guards.py:2147] [1/2] [__guards] | | +- NO_TENSOR_ALIASING: check_no_aliasing(L['x'], L['y'])
V0827 15:56:04.485000 140504121249792 torch/_dynamo/guards.py:2147] [1/2] [__guards] | +- GuardManager: source=L['y'], accessed_by=DictGetItemGuardAccessor(y)
V0827 15:56:04.485000 140504121249792 torch/_dynamo/guards.py:2147] [1/2] [__guards] | | +- TENSOR_MATCH: check_tensor(L['y'], Tensor, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.float32, device=0, requires_grad=False, size=[2], stride=[1]) # return x - y # mp/ipykernel_1924114/1524425904.py:5 in conditional
V0827 15:56:04.485000 140504121249792 torch/_dynamo/guards.py:2147] [1/2] [__guards] | | +- NO_HASATTR: hasattr(L['y'], '_dynamo_dynamic_indices') == False # return x - y # mp/ipykernel_1924114/1524425904.py:5 in conditional
V0827 15:56:04.485000 140504121249792 torch/_dynamo/guards.py:2147] [1/2] [__guards] | | +- NO_TENSOR_ALIASING: check_no_aliasing(L['x'], L['y'])
V0827 15:56:04.485000 140504121249792 torch/_dynamo/guards.py:2147] [1/2] [__guards]
I0827 15:56:04.488000 140504121249792 torch/_inductor/cudagraph_trees.py:362] [__cudagraphs] recording cudagraph tree for graph without symints
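To frame the "three graphs" question, here is a toy simulation of my guess at the recompilation policy. It is a sketch under stated assumptions, not Dynamo's code: (1) the first compile is static; (2) once a different size shows up, the dimension is treated as dynamic ("automatic dynamic"); (3) sizes 0 and 1 still get a static, specialized entry; (4) a dynamic entry is bounded by the branch `x.shape[0] > 2`. All names are hypothetical.

```python
from typing import List, Optional, Tuple

# Hypothetical entry for dim 0 of x: ("static", n) or ("dynamic", lo, hi).
Entry = Tuple

def branch_range(n: int) -> Tuple[int, float]:
    """Range implied by tracing `x.shape[0] > 2` for a dynamic size (assumed lo=2)."""
    return (3, float("inf")) if n > 2 else (2, 2)

def matches(entry: Entry, n: int) -> bool:
    if entry[0] == "static":
        return entry[1] == n
    _, lo, hi = entry
    return lo <= n <= hi

def simulate(sizes: List[int]) -> List[Entry]:
    """Return the entries created (in compile order) for a sequence of x.shape[0] values."""
    entries: List[Entry] = []
    first: Optional[int] = None
    dynamic = False
    for n in sizes:
        if any(matches(e, n) for e in entries):
            continue                      # cache hit: no recompilation
        if first is None:
            first = n
        elif n != first:
            dynamic = True                # size changed: treat the dim as dynamic (assumed)
        if dynamic and n >= 2:
            entries.append(("dynamic",) + branch_range(n))
        else:
            entries.append(("static", n))  # first compile, or a specialized 0/1 size
    return entries

print(simulate([3, 2, 1]))
```

Under these assumptions the sizes 3, 2, 1 give one static entry, one dynamic entry bounded to [2, 2], and one static entry for size 1. With the sizes from my first example (2, 1, 2, 1, ...) the same toy gives two static entries.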
- In this example, why are there three different graphs? Also, one of them logs
recording cudagraph tree for symint key 2
while the other two log
recording cudagraph tree for graph without symints
What is the difference between the two?
- The guards are the same for cases 1 and 3, but case 2 has an extra line:
LAMBDA_GUARD: 2 <= L['x'].size()[0] <= 2
(I assumed such a condition would be present in all the guards.)
- Who decides which cached version to use? Statements like the following are also unclear to me:
Skipping guard L['x'].size()[0] == 1
Skipping guard L['x'].size()[0] == 2
Skipping guard L['x'].size()[0] == 3
Could anyone please clarify these points?