Implementing an OpenCL backend for PyTorch

Hello,

I recently implemented a basic set of deep learning operations and an initial training/inference library.

Although it is fairly new, it already outperforms PlaidML and Caffe/OpenCL by 150-200% on the tested networks (AlexNet, ResNet, VGG, MobileNet), in both training and inference, on both AMD and NVIDIA GPUs. It also reaches roughly 50-70% of the performance of the native stacks (CUDA+cuDNN, and HIP+MIOpen on AMD GPUs).

I want to start working on OpenCL (out-of-tree) backend for PyTorch.

I have implemented both GEMM-based and Winograd-based convolutions and multiple other DL operators (though this is only the beginning).

Unlike ROCm, it runs on Windows as well and should support RDNA and APUs (though I haven't tested that).

Is there any initial implementation of an OpenCL backend, or some kind of template backend, so that I can start the task more easily?

I really think having OpenCL support is valuable, especially since the only projects that supported it are gone: Caffe is dead, and PlaidML was killed along with multi-backend Keras.

Thanks


Maybe there could be some minimal backend that allows doing stuff like:

y = x1 + x2
y.backward()

so one could start from there?
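Concretely, I mean a smoke test along these lines (written here with CPU tensors as a stand-in; a working backend would only need to swap in device="opencl"):

```python
import torch

# CPU stand-in for the minimal test; a real backend would create these on the new device.
x1 = torch.ones(5, 5, requires_grad=True)
x2 = torch.ones(5, 5, requires_grad=True)
y = (x1 + x2).sum()   # needs only an add kernel (plus sum for a scalar loss)
y.backward()          # autograd should then produce all-ones gradients
print(x1.grad)
```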

In my view, the difficulty in bootstrapping a new architecture is twofold:

  • PyTorch has quite some infrastructure (the dispatcher, see e.g. ezyang’s blog post and podcast),
  • the operator coverage (where you have a head start, apparently).

You likely want to tackle the first before the second becomes your main problem.
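To give a rough, hypothetical idea of what the dispatcher side looks like (assuming a recent PyTorch with the Python torch.library API; the opencl_add placeholder and the choice of the PrivateUse1 key are my assumptions, and a real backend would register C++/OpenCL kernels via TORCH_LIBRARY_IMPL instead):

```python
import torch

# Hypothetical sketch: register a kernel for aten::add.Tensor under the
# PrivateUse1 dispatch key, which PyTorch reserves for out-of-tree backends.
lib = torch.library.Library("aten", "IMPL")

def opencl_add(self, other, alpha=1):
    # A real backend would launch an OpenCL kernel on device buffers here.
    raise NotImplementedError("OpenCL add kernel goes here")

lib.impl("add.Tensor", opencl_add, "PrivateUse1")
# Calls like x + y only reach this kernel once tensors can actually live on
# the new device, which is the C++ part of the work.
```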

If you wanted to pull this off (it would be quite an undertaking), you could start with your own build and do

torch.ones(5, 5, device="opencl")

This gives you

RuntimeError: 0INTERNAL ASSERT FAILED at "../c10/core/TensorOptions.h":655, please report a bug to PyTorch. This is a grandfathered Caffe2 device type opencl, it shouldn't ever convert to a DispatchKey.  File a bug describing what you were doing if you think this is in error.

After you fix this, you'll likely bump into the next one. :slight_smile: You could also take inspiration from the more recent Vulkan backend (which, as far as I understand, is special-purpose, but it is recent and also eyes APUs etc.).

This would be the first thing to resolve. Once you have that, you could tackle simple ops (y = x1 + x2).
Quite likely the autograd support might not be that device-dependent (but I didn't try it, obviously).
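The reason I'd expect that: backward formulas are themselves expressed in terms of ordinary tensor ops that go through the same dispatcher, so once the forward kernels exist, much of autograd comes along without device-specific code. A toy custom Function (run on CPU here, just to show the pattern):

```python
import torch

# Toy example: the backward is built from the same per-op kernels (mul here)
# that the dispatcher would route to whatever backend the tensors live on.
class MyMul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a * b                       # dispatches to the backend's mul

    @staticmethod
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        return grad_out * b, grad_out * a  # again just mul kernels

a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
MyMul.apply(a, b).sum().backward()
print(a.grad, b.grad)                      # equal to b and a, respectively
```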

Best regards

Thomas

P.S.: Disclaimer: I am just a random person on the internet and not anyone who has any say whether PyTorch would accept an OpenCL backend if it were there.

Thanks, I'll read the blog posts. At first glance this looks very interesting (and complicated).

My first DL framework was Caffe, which I still like a lot due to its highly readable and easy-to-modify C++ code (and, of course, its OpenCL support).

In any case, the dispatcher and other technical things are complicated as a system, but they are actually simple in comparison to optimized DL kernels.

For example, I haven't found a single open-source, general-purpose implementation of the Winograd algorithm in either CUDA or OpenCL (ROCm's kernels are actually binary blobs), and Intel's are tightly tied to Intel architectures. I finally found a paper from 2020 that described what a GPU implementation of Winograd should look like.
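For reference, the core transform is tiny even though fast kernels built around it are not. A minimal sketch of the 1D F(2,3) case, checked against PyTorch's own correlation (real kernels tile the 2D F(2x2,3x3) variant and fuse the transforms):

```python
import torch

# Winograd F(2,3) in 1D: two outputs of a 3-tap correlation computed from a
# 4-element input tile with 4 multiplications instead of 6 (Lavin & Gray).
BT = torch.tensor([[1.,  0., -1.,  0.],
                   [0.,  1.,  1.,  0.],
                   [0., -1.,  1.,  0.],
                   [0.,  1.,  0., -1.]])
G  = torch.tensor([[1.,  0.,  0.],
                   [.5,  .5,  .5],
                   [.5, -.5,  .5],
                   [0.,  0.,  1.]])
AT = torch.tensor([[1., 1.,  1.,  0.],
                   [0., 1., -1., -1.]])

d = torch.randn(4)   # input tile
g = torch.randn(3)   # filter
y_winograd = AT @ ((G @ g) * (BT @ d))   # elementwise product in transform space
y_direct = torch.nn.functional.conv1d(d.view(1, 1, 4), g.view(1, 1, 3)).view(2)
print(torch.allclose(y_winograd, y_direct, atol=1e-5))
```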

Even GEMM-based convolutions aren't in great shape: CLBlast implements one, but its performance is very poor (and it implements only forward propagation).
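For comparison, the lowering behind GEMM-based convolution is easy to state; the hard part is the GEMM itself and the memory traffic. A CPU reference sketch using PyTorch's unfold (an OpenCL backend would do the same lowering and call its own GEMM):

```python
import torch
import torch.nn.functional as F

# "GEMM-based convolution": lower the input with im2col (unfold), then do one
# big matrix multiply. Shown with PyTorch's CPU ops purely as a reference.
def conv2d_im2col(x, weight, stride=1, padding=0):
    n, c, h, w = x.shape
    out_c, _, kh, kw = weight.shape
    cols = F.unfold(x, (kh, kw), padding=padding, stride=stride)  # (n, c*kh*kw, L)
    out = weight.view(out_c, -1) @ cols                           # the GEMM
    oh = (h + 2 * padding - kh) // stride + 1
    ow = (w + 2 * padding - kw) // stride + 1
    return out.view(n, out_c, oh, ow)

x = torch.randn(2, 3, 8, 8)
weight = torch.randn(4, 3, 3, 3)
print(torch.allclose(conv2d_im2col(x, weight, padding=1),
                     F.conv2d(x, weight, padding=1), atol=1e-5))
```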

So "complex" is a relative thing :slight_smile:

If you just want to add a new "device" so that you can then register implementations for it in the dispatcher, this is the kind of change that people are making: https://github.com/pytorch/pytorch/pull/58248. I don't think you need the whole thing, though, for what you want here.
