Following this discussion https://github.com/pytorch/pytorch/issues/488#issuecomment-1209498491
I have two questions:
1) Is the private use device implemented in the latest release, or only in nightly builds?
2) Regarding: "We can build a very simple torch_function mode to translate any opencl string into privateuseone automatically before it hits the backend."
Can you please elaborate? Is this something that is already implemented, or does it need to be implemented on the PyTorch side? Or is there something I, as a backend developer, need to do?
Thanks a lot!
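For reference, the torch_function mode idea quoted above could look roughly like the following. It is a minimal sketch, not the approach taken in the linked example repository. It assumes a PyTorch 2.x build where torch.overrides.TorchFunctionMode can be used directly as a context manager, and DeviceAliasMode is a hypothetical name. It rewrites any string argument whose device-type part equals the alias, so an unrelated string argument that happens to be "ocl" would also be rewritten, and it leaves torch.device objects untouched.

```python
import torch
from torch.overrides import TorchFunctionMode


class DeviceAliasMode(TorchFunctionMode):
    """Hypothetical helper: rewrites device strings such as "ocl" or "ocl:0"
    to a target device type before the call reaches the backend."""

    def __init__(self, alias="ocl", target="privateuseone"):
        super().__init__()
        self.alias = alias
        self.target = target

    def _rewrite(self, value):
        # Only plain strings are rewritten; torch.device objects pass through.
        if isinstance(value, str):
            kind, sep, index = value.partition(":")
            if kind == self.alias:
                return self.target + sep + index
        return value

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Positional strings cover calls like tensor.to("ocl"); keyword
        # arguments cover factory calls like torch.empty(3, device="ocl").
        args = tuple(self._rewrite(a) for a in args)
        kwargs = {k: self._rewrite(v) for k, v in kwargs.items()}
        # The mode is not re-entered while its handler runs, so this call
        # goes straight to the normal implementation.
        return func(*args, **kwargs)
```

Inside `with DeviceAliasMode():`, a call such as `torch.empty(3, device="ocl")` would reach the backend as `device="privateuseone"`. Because the alias can point at any device type, the mode can also be tried without an OpenCL backend, e.g. `DeviceAliasMode(alias="mydev", target="cpu")`.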
albanD
August 18, 2022, 11:17pm
Hey!
Here is the full example I was talking about there: bdhirsh/pytorch_open_registration_example on GitHub, an example of using PyTorch's open device registration API.
This should clarify 2) and your other questions in that issue.
For 1), yes, I think everything is already there in 1.12.1.
cc @bdhirsh: is there any very recent bug fix that is required?
Nope, I think everything you need should be in the 1.12.1 branch.
Hi, I saw this line in the example, which uses something relatively new:
torch.register_privateuse1_backend('foo')
Can you please elaborate what it does?
Edit: now I see it is torch.utils.rename_privateuse1_backend, and it looks like it does what I thought: it allows using a custom device name. Neat!
It looks like exactly what I was looking for. When is it going to be in a stable release?
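A minimal sketch of what the rename changes, assuming a PyTorch release that includes torch.utils.rename_privateuse1_backend (per the reply below, that is a release after 1.13). parses_as_device is a hypothetical helper used only to show that "ocl" becomes a valid device string; the rename is process-global, so it is normally done once at startup.

```python
import torch


def parses_as_device(name):
    """Hypothetical helper: True if torch.device() accepts `name`."""
    try:
        torch.device(name)
        return True
    except RuntimeError:
        return False


# Before renaming (in a fresh process), "ocl" is not a known device type.
print(parses_as_device("ocl:0"))

# Assumes a PyTorch version that ships this function (it is not in 1.13).
torch.utils.rename_privateuse1_backend("ocl")

# The PrivateUse1 device can now be addressed by its custom name.
dev = torch.device("ocl:0")
print(dev.type, dev.index)
```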
albanD
November 11, 2022, 7:46pm
I’m afraid it will only be in the next release; it didn't make it into 1.13.
It is a very nice-to-have feature, but not having it does not limit functionality, and everything currently works well without it. I just rename ocl -> privateuseone at the parameter level :-)
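Renaming at the parameter level can be kept in one place with a small helper. This is a hypothetical sketch (to_backend_device is not part of PyTorch). It compares the device-type part of the string exactly instead of doing a substring replace, so "ocl:1" becomes "privateuseone:1" while strings like "cpu" or "cuda:0" pass through unchanged.

```python
def to_backend_device(device: str, alias: str = "ocl",
                      builtin: str = "privateuseone") -> str:
    """Translate an alias device string to the built-in PrivateUse1 name.

    "ocl" -> "privateuseone", "ocl:1" -> "privateuseone:1";
    any other device string is returned unchanged.
    """
    # Split "type:index" into its parts; sep and index are "" if no index.
    kind, sep, index = device.partition(":")
    if kind == alias:
        return builtin + sep + index
    return device
```

In a training script this would be applied once, e.g. `device = to_backend_device(args.device)`, before the string is passed to any torch call.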
albanD
November 11, 2022, 9:36pm
Awesome!
Feel free to add something like
import torch
if hasattr(torch.utils, "rename_privateuse1_backend"):
    torch.utils.rename_privateuse1_backend("ocl")
so that users on recent versions get the "ocl" name, while you can still ship it for any version of PT.
Yep, I do something similar, though a hasattr check is cleaner than the try/except I use:
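import os
import torch

# parser is the script's argparse.ArgumentParser (defined earlier, not shown)
args = parser.parse_args()
torch.manual_seed(args.seed)
device = args.device
if device.startswith('ocl'):
    # Load the out-of-tree OpenCL backend library
    if os.name == 'nt':
        torch.ops.load_library(r"build\pt_ocl.dll")
    else:
        torch.ops.load_library("build/libpt_ocl.so")
    try:
        # Newer PyTorch: users can write "ocl" directly
        torch.utils.rename_privateuse1_backend('ocl')
    except AttributeError:
        # Older PyTorch (e.g. 1.13): fall back to the built-in name
        device = device.replace('ocl', 'privateuseone')
train_kwargs = {'batch_size': args.batch_size}
test_kwargs = {'batch_size': args.test_batch_size}
if device != 'cpu':
    cuda_kwargs = {'num_workers': 1,
                   'shuffle': True}
    train_kwargs.update(cuda_kwargs)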
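Following the hasattr suggestion, the setup in the snippet above could be factored into a helper that uses feature detection instead of try/except. This is a sketch: setup_ocl_device is a hypothetical name, the library paths are the ones from the snippet, and the load_library flag exists only so the function can be exercised without the OpenCL build present.

```python
import os

import torch


def setup_ocl_device(device: str, load_library: bool = True) -> str:
    """Hypothetical helper: return the device string to use for `device`.

    For "ocl"/"ocl:N" it optionally loads the backend library, then keeps the
    "ocl" name if this PyTorch can rename PrivateUse1, else falls back to
    "privateuseone"/"privateuseone:N". Other devices pass through unchanged.
    """
    kind, sep, index = device.partition(":")
    if kind != "ocl":
        return device  # cpu, cuda, ... are left untouched
    if load_library:
        lib = r"build\pt_ocl.dll" if os.name == "nt" else "build/libpt_ocl.so"
        torch.ops.load_library(lib)
    if hasattr(torch.utils, "rename_privateuse1_backend"):
        torch.utils.rename_privateuse1_backend("ocl")
        return device
    return "privateuseone" + sep + index
```

With this, the script body reduces to `device = setup_ocl_device(args.device)`, and the version-dependent branch lives in one place.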