Implementing an OpenCL backend for PyTorch

Hello,

I recently implemented a basic set of deep learning operations and an initial training/inference library.

Although it is fairly new, it already outperforms PlaidML and Caffe/OpenCL by 150-200% on the tested networks (AlexNet, ResNet, VGG, MobileNet), in both training and inference, on AMD and NVIDIA GPUs. It also reaches ~50% to 70% of the performance of native CUDA+cuDNN/HIP+MIOpen on AMD GPUs.

I want to start working on an out-of-tree OpenCL backend for PyTorch.

I have implemented both GEMM- and Winograd-based convolutions and multiple other DL operators (although it is only the beginning).

Unlike ROCm, it runs on Windows as well and should support RDNA and APUs (although I haven't tested that).

Is there an initial implementation of an OpenCL backend, or some kind of template backend, that would make it easier to start on this task?

I really think OpenCL support is valuable, especially given that the only projects that supported it are gone: Caffe is dead, and PlaidML was killed when multi-backend Keras was killed.

Thanks


Maybe there is some minimal backend that allows doing things like:

y=x1+x2
y.backward()

So that I could start from there?

In my view, the difficulty in bootstrapping a new architecture is twofold:

  • PyTorch has quite a bit of infrastructure (the dispatcher; see e.g. ezyang's blog post and podcast),
  • the operator coverage (where you apparently have a head start).

You likely want to tackle the first before the second becomes your main problem.

If you wanted to pull this off (it will be quite an undertaking), you could start with your own build and do

torch.ones(5,5, device="opencl")

This gives you

RuntimeError: 0INTERNAL ASSERT FAILED at "../c10/core/TensorOptions.h":655, please report a bug to PyTorch. This is a grandfathered Caffe2 device type opencl, it shouldn't ever convert to a DispatchKey.  File a bug describing what you were doing if you think this is in error.

After you fix this, you'll likely bump into the next one. :slight_smile: You could also take inspiration from the more recent Vulkan backend (which, as far as I understand, is special-purpose, but is recent and also targets APUs etc.).

This would be the first thing to resolve. Once you have that, you could tackle simple ops (y = x1 + x2).
Quite likely, autograd support is not very device-dependent (but I haven't tried, obviously).
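To illustrate why autograd is probably not very device-dependent: the backward of an addition just hands the incoming gradient to both inputs, and nothing in that rule cares where the data is stored. The toy scalar autograd below is a hypothetical sketch (the `Value` class is invented for illustration and is not how PyTorch's autograd is implemented), covering exactly the `y = x1 + x2; y.backward()` case plus a multiplication:

```python
class Value:
    """Toy scalar with reverse-mode autograd (hypothetical, not PyTorch's API)."""

    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents
        self._backward_fn = None  # propagates self.grad to the parents

    def __add__(self, other):
        out = Value(self.data + other.data, parents=(self, other))

        def backward_fn():
            # d(a + b)/da = d(a + b)/db = 1: the gradient flows straight through.
            # Nothing here depends on where .data lives (CPU, CUDA, OpenCL...).
            self.grad += out.grad
            other.grad += out.grad

        out._backward_fn = backward_fn
        return out

    def __mul__(self, other):
        out = Value(self.data * other.data, parents=(self, other))

        def backward_fn():
            # d(a * b)/da = b, d(a * b)/db = a
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad

        out._backward_fn = backward_fn
        return out

    def backward(self):
        # Topologically sort the graph, then run each node's backward in reverse.
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node._parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0
        for node in reversed(order):
            if node._backward_fn is not None:
                node._backward_fn()
```

For example, with `x1 = Value(2.0)` and `x2 = Value(3.0)`, calling `(x1 + x2).backward()` leaves `x1.grad == x2.grad == 1.0`. A device backend mainly has to supply the forward and backward kernels; the graph bookkeeping stays the same.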

Best regards

Thomas

P.S.: Disclaimer: I am just a random person on the internet and not anyone who has any say in whether PyTorch would accept an OpenCL backend if it were there.


Thanks. I'll read the blog posts. At first glance they look very interesting (and complicated).

My first DL framework was Caffe, which I still like a lot due to its highly readable, easy-to-modify C++ code (and, of course, its OpenCL support).

In any case, the dispatcher and other infrastructure are complicated at the system level, but actually simple compared to optimized DL kernels.

For example, I couldn't find a single open-source, general-purpose implementation of the Winograd algorithm in either CUDA or OpenCL (ROCm's are actually binary blobs), and Intel's are tightly tied to Intel architecture. Finally I found a paper in 2020 that described what a GPU implementation of Winograd should look like.
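For readers unfamiliar with the trick: Winograd's minimal filtering computes several outputs of a small filter with fewer multiplications by transforming the input tile and the filter, multiplying element-wise, and transforming back. The sketch below is the textbook 1D F(2,3) case (2 outputs of a 3-tap filter from 4 inputs using 4 multiplications instead of 6), in plain Python with hypothetical function names; it is not the DLPrimitives kernel, and a GPU version works on 2D tiles such as F(2x2,3x3) with many channels batched into GEMMs.

```python
def direct_conv1d(d, g):
    """Reference 'valid' 1D correlation (what DL frameworks call convolution)."""
    k = len(g)
    return [sum(d[i + j] * g[j] for j in range(k)) for i in range(len(d) - k + 1)]


def winograd_f23_tile(d, g):
    """F(2,3): 2 outputs of a 3-tap filter from a 4-element input tile."""
    d0, d1, d2, d3 = d
    g0, g1, g2 = g
    # Filter transform (can be precomputed once per filter).
    u = (g0, (g0 + g1 + g2) / 2, (g0 - g1 + g2) / 2, g2)
    # Input transform.
    v = (d0 - d2, d1 + d2, d2 - d1, d1 - d3)
    # Element-wise product: the only 4 multiplications.
    m0, m1, m2, m3 = (ui * vi for ui, vi in zip(u, v))
    # Output transform.
    return [m0 + m1 + m2, m1 - m2 - m3]


def winograd_conv1d(d, g):
    """Tile the input into overlapping 4-element tiles with stride 2."""
    assert len(g) == 3, "this sketch only implements F(2,3)"
    n_out = len(d) - len(g) + 1
    padded = list(d) + [0] * (n_out % 2)  # pad so the last tile has 4 inputs
    out = []
    for i in range(0, n_out, 2):
        out.extend(winograd_f23_tile(padded[i:i + 4], g))
    return out[:n_out]
```

For instance, `winograd_conv1d([1, 2, 3, 4, 5, 6], [1, 0, -1])` matches `direct_conv1d` on the same inputs. The savings grow with 2D tiles, which is where the transforms and data layout start to dominate the engineering effort.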

Even GEMM-based convolutions aren't very good: CLBlast implements one, but its performance is very poor (and it implements only forward propagation).
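For context, the classic way to turn convolution into GEMM is im2col: unfold every receptive field into a column, reshape the filters into a matrix, and do a single matrix multiplication. The pure-Python sketch below (stride 1, no padding, one image, forward only; all function names are hypothetical) shows the idea. It is not CLBlast's or DLPrimitives' implementation; real kernels avoid materializing the column matrix and tile the GEMM for the GPU.

```python
def im2col(x, r, s):
    """x: [C][H][W] -> matrix with C*R*S rows and OH*OW columns."""
    c_n, h, w = len(x), len(x[0]), len(x[0][0])
    oh, ow = h - r + 1, w - s + 1
    cols = []
    for c in range(c_n):
        for i in range(r):
            for j in range(s):
                # One row per (channel, filter row, filter col); one column per output pixel.
                cols.append([x[c][y + i][xx + j] for y in range(oh) for xx in range(ow)])
    return cols, oh, ow


def gemm(a, b):
    """Naive matrix multiply: (N x K) @ (K x M)."""
    k, m = len(b), len(b[0])
    return [[sum(row[p] * b[p][j] for p in range(k)) for j in range(m)] for row in a]


def conv2d_gemm(x, w):
    """x: [C][H][W], w: [K][C][R][S] -> [K][OH][OW] via im2col + GEMM."""
    k_n, c_n, r, s = len(w), len(w[0]), len(w[0][0]), len(w[0][0][0])
    cols, oh, ow = im2col(x, r, s)
    # Flatten each filter in the same (c, i, j) order used by im2col.
    w_mat = [[w[k][c][i][j] for c in range(c_n) for i in range(r) for j in range(s)]
             for k in range(k_n)]
    out = gemm(w_mat, cols)  # K x (OH*OW)
    return [[row[y * ow:(y + 1) * ow] for y in range(oh)] for row in out]


def conv2d_direct(x, w):
    """Reference direct convolution with the same layout."""
    k_n, c_n, r, s = len(w), len(w[0]), len(w[0][0]), len(w[0][0][0])
    oh, ow = len(x[0]) - r + 1, len(x[0][0]) - s + 1
    return [[[sum(x[c][y + i][xx + j] * w[k][c][i][j]
                  for c in range(c_n) for i in range(r) for j in range(s))
              for xx in range(ow)] for y in range(oh)] for k in range(k_n)]
```

The backward passes (gradients with respect to the input and the weights) can be expressed the same way with transposed GEMMs plus a col2im step, which is the part the post says CLBlast does not provide.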

So complexity is relative :slight_smile:


If you just want to add a new "device" so that you can register your implementations for it in the dispatcher, this is the kind of change people are making: https://github.com/pytorch/pytorch/pull/58248. I don't think you need the whole thing for what you want here, though.
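Conceptually, the dispatcher picks a kernel from the operator plus a dispatch key derived from the tensor's device, so a new device type is unusable until it maps to some key. The toy Python model below is only a hypothetical illustration of that two-step lookup; the names `DEVICE_TO_KEY`, `KERNELS`, `register` and `dispatch` are invented, and the real dispatcher is C++ with kernels registered via `TORCH_LIBRARY_IMPL`.

```python
# Hypothetical toy model of dispatch: device type -> dispatch key,
# then (operator, dispatch key) -> kernel. Not the real c10 data structures.
DEVICE_TO_KEY = {"cpu": "CPU", "cuda": "CUDA"}
KERNELS = {}


def register(op, key):
    """Decorator that registers a kernel for (op, dispatch key)."""
    def decorator(fn):
        KERNELS[(op, key)] = fn
        return fn
    return decorator


def dispatch(op, device_type, *args):
    key = DEVICE_TO_KEY.get(device_type)
    if key is None:
        # Roughly analogous to a device type that cannot be converted to a
        # DispatchKey: there is nothing to dispatch on.
        raise RuntimeError(f"device type {device_type!r} has no dispatch key")
    kernel = KERNELS.get((op, key))
    if kernel is None:
        raise NotImplementedError(f"no kernel registered for {op} on {key}")
    return kernel(*args)


@register("add", "CPU")
def add_cpu(a, b):
    return [x + y for x, y in zip(a, b)]


# Out-of-tree backend: map the device type to a spare key, then register kernels.
DEVICE_TO_KEY["opencl"] = "PrivateUse1"


@register("add", "PrivateUse1")
def add_opencl(a, b):
    # A real backend would enqueue an OpenCL kernel here.
    return [x + y for x, y in zip(a, b)]
```

The design point this illustrates: the mapping from device to key and the per-key kernel table are separate, so registering kernels alone is not enough if the device type itself has no key.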


I started going over this tutorial: Extending dispatcher for a new backend in C++ — PyTorch Tutorials 1.9.1+cu102 documentation

I created a custom function + backward for CPU, and it looks OK: I can link, call and run the custom fwd/bwd on a CPU tensor.

Now I want to extend it to the private-use dispatch key so I can start prototyping the OpenCL backend. What is not clear to me is the following:

How do I create a tensor for the PrivateUse dispatch key, or copy one to it?

In Python I can call torch.randn(10).to('cuda'); how do I do the same for the private key? What am I missing?

Or do I need to modify the PyTorch source for an out-of-tree backend?

OK… here is the progress:

TL;DR: I managed to run AlexNet inference using an OpenCL/DLPrimitives-based PyTorch backend!

Details:

  1. I changed the mapping of the opencl device to the PrivateUse1 dispatch key just to be able to do anything. Currently it is the only change I needed to start working on the out-of-tree backend. I found that, despite the suggestion to use these dispatch keys, I need to have a device mapped to one, and that is impossible without modifying the PyTorch sources: Set temporary opencl device to PrivateUse1 dispatch key · artyom-beilis/pytorch@eb74af1 · GitHub
  2. Another item missing from the tutorial on out-of-source backends (Extending dispatcher for a new backend in C++, PyTorch Tutorials 1.9.1+cu102 documentation) is the need to implement c10::impl::DeviceGuardImplInterface and register it via c10::impl::DeviceGuardImplRegistrar.
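A device guard makes a given device current for the duration of an operation and restores the previous one afterwards; the per-backend hooks it relies on are what DeviceGuardImplInterface provides. The sketch below models that pattern in Python with invented names (`OpenCLGuardImpl`, `register_guard_impl`, `DeviceGuard`). It is an assumption-laden illustration only: the real interface is C++, is registered through the registrar, and also covers things such as streams.

```python
class OpenCLGuardImpl:
    """Hypothetical per-backend hooks, loosely modeled on the idea of
    DeviceGuardImplInterface: query and exchange the current device."""

    def __init__(self, device_count):
        self.device_count = device_count
        self.current = 0

    def get_device(self):
        return self.current

    def exchange_device(self, index):
        """Make `index` current and return the previously current device."""
        if not 0 <= index < self.device_count:
            raise ValueError(f"invalid device index {index}")
        previous, self.current = self.current, index
        return previous


GUARD_REGISTRY = {}  # device type -> guard implementation


def register_guard_impl(device_type, impl):
    GUARD_REGISTRY[device_type] = impl


class DeviceGuard:
    """RAII-style guard: switch device on entry, restore the previous one on exit."""

    def __init__(self, device_type, index):
        self.impl = GUARD_REGISTRY[device_type]  # KeyError if nothing registered
        self.index = index
        self.previous = None

    def __enter__(self):
        self.previous = self.impl.exchange_device(self.index)
        return self

    def __exit__(self, exc_type, exc, tb):
        self.impl.exchange_device(self.previous)
        return False  # do not swallow exceptions
```

Without a registered implementation for a device type, the guard has nothing to call, which is why this registration step is needed alongside the kernels themselves.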

So far I have implemented only a handful of ops, mostly for forward computation (GitHub - artyom-beilis/pytorch_dlprim: DLPrimitives-OpenCL out of tree backend for pytorch), but I managed to run computations and get correct results on a pretrained AlexNet.

$ python validate_network.py --model alexnet --device cuda --benchmark *.ppm
cat.ppm,281,tabby,0.249897,-0.674675,-2.994513,-1.568204,-2.399394,3.196111,-3.784611,...
dog.ppm,207,golden retriever,-4.164610,-5.017385,3.193234,-6.757652,-1.752393,-1.439135,-3.598867,...
parrot.ppm,87,African grey,-0.318788,5.249665,-6.590664,-4.953464,-3.192156,2.550208,-2.042364,...

$ python validate_network.py --model alexnet --device opencl:1 --benchmark *.ppm 
Accessing device #1:GeForce GTX 960 on NVIDIA CUDA
cat.ppm,281,tabby,0.249900,-0.674674,-2.994513,-1.568204,-2.399395,3.196111,-3.784608,...
dog.ppm,207,golden retriever,-4.164612,-5.017380,3.193236,-6.757651,-1.752393,-1.439135,-3.598869,...
parrot.ppm,87,African grey,-0.318790,5.249666,-6.590665,-4.953461,-3.192155,2.550208,-2.042365,...

Performance is not brilliant but not horrible (although this net is way too simple). GTX 960, AlexNet, batch size 16, image 224x224:

  • PyTorch CUDA/cuDNN: 11.685 ms
  • PyTorch OpenCL/DLPrimitives: 23.966 ms
  • DLPrimitives microframework: 22.401 ms
  • Caffe/cuDNN: 16.1812 ms
  • Caffe/OpenCL: 41.072 ms
  • Caffe/OpenCL+DLPrimitives: 28.618 ms
  • Keras/cuDNN: 23.341 ms
  • Keras/PlaidML: 44.041 ms

Now, one of the issues I currently have is synchronous execution, which imposes a significant penalty on every operation. I need to understand some things, and for that I'll open another thread, since it isn't directly related to OpenCL.
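The cost of synchronous execution comes from the host waiting for every operation to finish before issuing the next one, instead of queuing work and waiting only when a result is actually needed. The Python sketch below models that difference with a hypothetical in-order `CommandQueue` backed by a worker thread; it only mimics the spirit of an OpenCL command queue and does not use any OpenCL API.

```python
import queue
import threading


class CommandQueue:
    """Hypothetical in-order asynchronous queue: enqueue() returns
    immediately, finish() blocks until all queued work has run."""

    def __init__(self):
        self._tasks = queue.Queue()
        self._errors = []
        threading.Thread(target=self._run, daemon=True).start()

    def _run(self):
        while True:
            fn = self._tasks.get()
            try:
                fn()  # stands in for "launch the kernel"
            except Exception as exc:  # keep the worker alive, report on finish()
                self._errors.append(exc)
            finally:
                self._tasks.task_done()

    def enqueue(self, fn):
        self._tasks.put(fn)  # returns without waiting for fn to run

    def finish(self):
        self._tasks.join()  # wait for all previously enqueued work
        if self._errors:
            raise self._errors.pop(0)


def run_ops(q, ops, synchronous):
    for op in ops:
        q.enqueue(op)
        if synchronous:
            q.finish()  # host stalls after every op: the per-operation penalty
    q.finish()  # asynchronous mode waits once, e.g. when results are read back
```

Because the queue is in-order, the asynchronous mode still executes the ops in the same order; it just stops the host from idling between them.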

Small update: I implemented GPU memory caching + asynchronous execution and got performance virtually identical to my static-graph DLPrimitives execution.

Now it works efficiently on all the GPUs I tested: AMD 6600XT, NVIDIA 960 and Intel 530.
I also fixed the PyTorch benchmark, which by accident didn't include the copy-to-GPU time; run times on the 960 are now ~15 ms for PyTorch CUDA/cuDNN and ~22 ms for DLPrimitives.
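The idea behind GPU memory caching is that freeing a tensor does not return its buffer to the driver; the buffer is kept and handed out again for a later request of a similar size, so allocation stops being a per-operation driver call. The sketch below is a hypothetical size-bucketed cache in plain Python (`CachingAllocator` and its rounding policy are invented for illustration, not PyTorch's or DLPrimitives' actual allocator). It assumes a single in-order queue, where a freed buffer can be reused safely because later commands run after earlier ones; with several queues you would also need to track events.

```python
from collections import defaultdict


class CachingAllocator:
    """Toy caching allocator: freed blocks are kept per rounded size and
    reused instead of being returned to the driver."""

    def __init__(self, raw_alloc, round_to=512):
        self._raw_alloc = raw_alloc      # expensive call, e.g. creating a device buffer
        self._round_to = round_to
        self._free = defaultdict(list)   # rounded size -> cached handles
        self.raw_allocations = 0

    def _round(self, size):
        # Round up to a multiple of round_to so similar sizes share a bucket.
        return -(-size // self._round_to) * self._round_to

    def malloc(self, size):
        rounded = self._round(size)
        if self._free[rounded]:
            return rounded, self._free[rounded].pop()  # cache hit: no driver call
        self.raw_allocations += 1
        return rounded, self._raw_alloc(rounded)

    def free(self, block):
        rounded, handle = block
        self._free[rounded].append(handle)  # keep it for the next request
```

For example, freeing a 1000-byte block and then asking for 900 bytes reuses the same 1024-byte buffer without a second raw allocation.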


And now some more progress:

And performance:

Benchmarks

All benchmarks were done on a GTX 960 (4 GB) to allow comparison with native CUDA speed.

Test

The test includes copying data to/from the device and the forward computation.

| Framework (time, ms) | alexnet | resnet18 | resnet50 | vgg16 | mobilenet |
|---|---|---|---|---|---|
| pytorch/cuda | 15.253 | 38.745 | 114.348 | 169.038 | 46.110 |
| pytorch/opencl | 22.989 | 50.272 | 167.050 | 258.751 | 82.044 |
| dlprimitives | 22.688 | 49.193 | 158.789 | 238.802 | 82.080 |
| keras/tf2-cuda | 29.104 | 74.215 | 161.704 | 158.084 | 88.851 |
| keras/plaidml | 43.004 | 91.533 | - | - | 45.693 |
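As a quick sanity check of these numbers, the relative speed of pytorch/opencl versus pytorch/cuda on this GTX 960 can be computed directly from the forward-pass table; it comes out at roughly 56% (mobilenet) to 77% (resnet18). The snippet only reuses the published timings; no new measurements are involved.

```python
# Forward-pass times in ms (incl. host<->device copies), copied from the "Test" table.
cuda = {"alexnet": 15.253, "resnet18": 38.745, "resnet50": 114.348,
        "vgg16": 169.038, "mobilenet": 46.110}
opencl = {"alexnet": 22.989, "resnet18": 50.272, "resnet50": 167.050,
          "vgg16": 258.751, "mobilenet": 82.044}

# Speed ratio = time_cuda / time_opencl (1.0 would mean parity with CUDA/cuDNN).
relative = {net: cuda[net] / opencl[net] for net in cuda}

for net, ratio in relative.items():
    print(f"{net:10s} {ratio:.0%} of pytorch/cuda speed")
```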

Full Train

Training includes I/O to/from the device, zeroing gradients, forward, backward and the optimizer update step. Adam is used as the optimizer.

| Framework (time, ms) | alexnet | resnet18 | resnet50 | vgg16 | mobilenet |
|---|---|---|---|---|---|
| pytorch/cuda | 107.108 | 129.456 | 388.951 | N/A | 177.434 |
| pytorch/opencl | 147.814 | 213.319 | 651.216 | N/A | 382.590 |
| dlprimitives | 106.033 | 198.092 | 605.541 | 1107.756 | 344.599 |
| keras/tf2-cuda | 90.005 | 183.447 | 501.362 | 550.063 | 322.416 |
| keras/plaidml | 222.166 | 507.116 | - | - | 571.438 |
  • vgg16 with batch size 16 failed to run on PyTorch due to lack of memory.
  • Some setups were not tested with PlaidML due to lack of performance/memory.
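The training measurement above covers zeroing gradients, forward, backward and an Adam update. For readers who want to see what the optimizer step itself computes, here is the standard Adam update (Kingma & Ba) on a single scalar parameter in plain Python. The function name and the quadratic test loss are hypothetical; a backend applies the same element-wise formulas to every parameter tensor, which is why the step is essentially a handful of pointwise kernels.

```python
import math


def adam_minimize(grad_fn, x, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8, steps=1):
    """Run `steps` Adam updates on a scalar parameter x."""
    m = v = 0.0
    for t in range(1, steps + 1):
        # "backward": a fresh gradient each step, which is what zeroing the
        # accumulated gradients achieves in a PyTorch training loop
        g = grad_fn(x)
        m = beta1 * m + (1 - beta1) * g      # first-moment moving average
        v = beta2 * v + (1 - beta2) * g * g  # second-moment moving average
        m_hat = m / (1 - beta1 ** t)         # bias corrections
        v_hat = v / (1 - beta2 ** t)
        x -= lr * m_hat / (math.sqrt(v_hat) + eps)  # optimizer update step
    return x
```

A property that makes this easy to check: on the very first step the bias-corrected ratio `m_hat / sqrt(v_hat)` is about 1 in magnitude, so the parameter moves by almost exactly `lr`.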

Looks very nice :slight_smile:


Just a random guy, but thanks for all the effort going into implementing an OpenCL backend! I'm definitely going to try it out.


This is a valuable project that should be supported by Intel and other NVIDIA competitors ASAP!


Agreed :slight_smile:

On a more serious note, I'd really love to see some community contributions. My time is very limited.

Many things can be trivially implemented in terms of reduction/pointwise operations, so no OpenCL/GPU programming knowledge is even required.
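To make the point concrete: once a backend offers a generic element-wise primitive and a generic reduction primitive, many operators are just compositions of the two. In the sketch below, `pointwise` and `reduction` are hypothetical Python stand-ins for such kernels (not DLPrimitives' API), and `mse_loss` and `softmax` are written using nothing else.

```python
import math
from functools import reduce


def pointwise(fn, *tensors):
    """Apply fn element by element (what an element-wise GPU kernel does)."""
    return [fn(*xs) for xs in zip(*tensors)]


def reduction(fn, init, tensor):
    """Fold a tensor into a single value (what a reduction kernel does)."""
    return reduce(fn, tensor, init)


# Operators composed only from the two primitives above.
def mse_loss(pred, target):
    sq = pointwise(lambda p, t: (p - t) ** 2, pred, target)
    return reduction(lambda a, b: a + b, 0.0, sq) / len(sq)


def softmax(x):
    m = reduction(max, -math.inf, x)              # subtract max for numerical stability
    e = pointwise(lambda v: math.exp(v - m), x)
    s = reduction(lambda a, b: a + b, 0.0, e)
    return pointwise(lambda v: v / s, e)
```

The backward passes decompose the same way (e.g. the MSE gradient is the pointwise `2 * (p - t) / n`), so contributors can add operators without writing new GPU kernels.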

They both want to create vendor lock-in on their GPUs.

Honestly, if we can get this working, it's going to make life much easier on Intel hardware. As it stands, it's a proper nightmare.

It would also mean you could build a project that works on all hardware, which is such a nice solution to the hardware curse.

They both want to create vendor lock-in on their GPUs.

Yes and no. I assume both Intel and AMD try to get the most valuable solution for minimal money.

AMD went down the road of implementing HIP as a CUDA replacement (a huge mistake IMHO), but at least they released an OpenCL version of MIOpen. Unfortunately, it isn't really open source, since some kernels are shipped only in binary form. Additionally, it is limited to Linux/ROCm.

Intel has oneDNN, and at some point I wanted to integrate/use it in my backend, but I discovered that the kernels I had written myself actually run faster than theirs.

Why? Because they decided to optimize their code for channels-last only, while PyTorch, ONNX and many other frameworks actually use channels-first: Channel First Convolution Performance on OpenCL driver · Issue #1194 · oneapi-src/oneDNN · GitHub
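The layout difference matters because it changes which index is contiguous in memory. In channels-first (NCHW) the width index varies fastest, while in channels-last (NHWC) the channel index does, so a kernel tuned for one layout sees strided accesses on the other, and using it on channels-first data means either converting layouts or accepting less efficient access. The offsets below are the standard row-major formulas; the helper names are hypothetical.

```python
def offset_nchw(n, c, h, w, C, H, W):
    """Channels-first (PyTorch/ONNX default): w is the fastest-varying index."""
    return ((n * C + c) * H + h) * W + w


def offset_nhwc(n, c, h, w, C, H, W):
    """Channels-last: c is the fastest-varying index."""
    return ((n * H + h) * W + w) * C + c


def nchw_to_nhwc(buf, N, C, H, W):
    """Reorder a flat NCHW buffer into a flat NHWC buffer."""
    out = [0] * len(buf)
    for n in range(N):
        for c in range(C):
            for h in range(H):
                for w in range(W):
                    out[offset_nhwc(n, c, h, w, C, H, W)] = buf[offset_nchw(n, c, h, w, C, H, W)]
    return out
```

For example, two neighboring channels of the same pixel are `H * W` elements apart in NCHW but adjacent in NHWC, which is exactly the kind of difference a layout-specialized kernel is tuned around.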

I don't know how much of it is ignorance and how much is an attempt to do what is possible with very limited resources (and yes, neither AMD nor Intel has a team or investment even comparable to NVIDIA's).

And most of the industry just doesn't care and lives with whatever NVIDIA has to offer, because the alternatives are far from competitive.

I hope the community can join the effort, because building it all is really tough: too many cases, kernels and other things.