Is there a way to pass the compilation target to a backend other than checking the device type of the torch.compile() inputs?
The Dynamo backend I am working on supports multiple targets, so it needs some way to tell them apart. If PyTorch supported all of these device types, the code would look like this:
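    from torch._dynamo.backends.common import device_from_inputs

    # Hypothetical: assumes "d1" and "d2" were valid torch.device types
    device_type = device_from_inputs(example_inputs).type
    if device_type == "cpu":
        ...  # compile for CPU
    elif device_type == "cuda":
        ...  # compile for CUDA
    elif device_type == "d1":
        ...  # compile for D1
    elif device_type == "d2":
        ...  # compile for D2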
The problem is that, although the backend can compile PyTorch models for these devices and run the compiled models on them, some of the devices (e.g. D1 and D2) are not supported by PyTorch as ATen devices. As a result, I cannot create tensors on those devices to pass to torch.compile(). Is adding support for these target devices to PyTorch itself the only way to make this distinction?
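One possible way to avoid relying on the input device is to bind the target into the backend itself. torch.compile accepts any callable with the signature `(gm, example_inputs)` as its `backend`, so a factory can return one backend per target while the example inputs stay on a device PyTorch already knows (e.g. CPU). The following is only a sketch: `make_target_backend`, `compile_for_target` and `SUPPORTED_TARGETS` are hypothetical names, the real lowering is left out, and it is assumed that the backend itself handles moving data to D1/D2 at run time.

```python
from typing import Any, Callable, List, Sequence

# Hypothetical list of targets the backend can compile for.
# "d1" and "d2" are plain strings here, not ATen device types.
SUPPORTED_TARGETS = ("cpu", "cuda", "d1", "d2")


def compile_for_target(gm: Callable, example_inputs: Sequence[Any], target: str) -> Callable:
    """Hypothetical per-target compiler.

    A real implementation would lower the FX graph `gm` for `target` and
    return a callable that runs there. This placeholder just calls the
    original graph and tags the result with the target it was built for.
    """
    def compiled_fn(*args):
        return gm(*args)

    compiled_fn.target = target
    return compiled_fn


def make_target_backend(target: str) -> Callable:
    """Build a Dynamo backend with the compilation target bound in by a closure."""
    if target not in SUPPORTED_TARGETS:
        raise ValueError(f"unsupported target: {target!r}")

    def backend(gm, example_inputs: List[Any]):
        # The target comes from the closure, not from the device of
        # example_inputs, so the inputs can live on any device PyTorch knows.
        return compile_for_target(gm, example_inputs, target)

    # Distinct name per target, which helps when debugging.
    backend.__name__ = f"my_backend_{target}"
    return backend
```

Usage would then look like `opt_model = torch.compile(model, backend=make_target_backend("d1"))`, with `model` and its inputs left on CPU. Because the target travels with the backend object rather than with the tensors, nothing in this path requires ATen to know about D1 or D2.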