Custom kernels for the Intel XPU backend in the LibTorch C++ API

I am using the C++ API of LibTorch 2.5.1 (to be clear, I am not using any Python code) and have implemented a few self-written CUDA and HIP kernels, e.g.,

/**
   @brief Compute Greville abscissae
*/
template <typename real_t>
__global__ void
greville_kernel(torch::PackedTensorAccessor64<real_t, 1> greville,
                const torch::PackedTensorAccessor64<real_t, 1> knots,
                int64_t ncoeffs, short_t degree, bool interior)
{
  for (int64_t k = blockIdx.x * blockDim.x + threadIdx.x;
       k < ncoeffs - (interior ? 2 : 0); k += blockDim.x * gridDim.x) {
    for (short_t l = 1; l <= degree; ++l)
      greville[k] += knots[k + (interior ? 1 : 0) + l];
    greville[k] /= real_t(degree);
  }
}

that I call from within my regular C++ code as follows

// CUDA
int blockSize, minGridSize, gridSize;
cudaOccupancyMaxPotentialBlockSize(&minGridSize, &blockSize, (const void *)greville_kernel<real_t>, 0, 0);
gridSize = (ncoeffs_[j] + blockSize - 1) / blockSize;
greville_kernel<<<gridSize, blockSize>>>(greville, knots, ncoeffs_[j], degrees_[j], interior);

// HIP
int blockSize, minGridSize, gridSize;
static_cast<void>(hipOccupancyMaxPotentialBlockSize(&minGridSize, &blockSize, (const void *)greville_kernel<real_t>, 0, 0));
gridSize = (ncoeffs_[j] + blockSize - 1) / blockSize;
greville_kernel<<<gridSize, blockSize>>>(greville, knots, ncoeffs_[j], degrees_[j], interior);

The code is implemented as a header-only library (the CUDA/HIP kernels live in a regular .hpp file), and the main application is compiled with nvcc or hipcc, respectively.

I managed to compile my code with Intel GPU support enabled by following the installation instructions given here: Getting Started on Intel GPU — PyTorch 2.5 documentation, and by pointing CMake to the libtorch library and header files in the Python site-packages directory. My code works fine except for the custom kernels.

I would appreciate some help in implementing the above (and similar) custom kernels in SYCL and calling them from the C++ code. I am familiar with CUDA/HIP programming but not yet with SYCL.

Thanks for the message. Would it be OK to make this a public post so that more people can help?


Please do make it public. Any help is appreciated.


template<typename real_t>
struct GrevilleKernel
{
  GrevilleKernel(torch::PackedTensorAccessor64<real_t, 1> greville,
                 const torch::PackedTensorAccessor64<real_t, 1> knots,
                 int64_t ncoeffs, short_t degree, bool interior)
    : greville(greville), knots(knots), ncoeffs(ncoeffs), degree(degree), interior(interior) {}

  void operator()(sycl::nd_item<1> item) const {
    // This function will be compiled into a GPU device kernel.

    // item.get_group(0) is blockIdx.x
    // item.get_local_range(0) is blockDim.x
    // item.get_local_id(0) is threadIdx.x
    // item.get_group_range(0) is gridDim.x
    int64_t k = item.get_group(0) * item.get_local_range(0) + item.get_local_id(0);
    for (; k < ncoeffs - (interior ? 2 : 0); k += item.get_local_range(0) * item.get_group_range(0)) {
      for (short_t l = 1; l <= degree; ++l)
        greville[k] += knots[k + (interior ? 1 : 0) + l];
      greville[k] /= real_t(degree);
    }
  }

private:
  torch::PackedTensorAccessor64<real_t, 1> greville;
  const torch::PackedTensorAccessor64<real_t, 1> knots;
  int64_t ncoeffs;
  short_t degree;
  bool interior;
};


constexpr int dim = 1;
constexpr int blockSize = 256;
// The global range is the total number of work items to launch. SYCL does not expose a gridSize
// counterpart directly; it is implied by global_range / local_range.
auto gridSize = (ncoeffs_[j] + blockSize - 1) / blockSize;
sycl::range<dim> global_range(gridSize * blockSize);
// local_range plays the role of blockSize (work items per work-group).
sycl::range<dim> local_range(blockSize);

typedef GrevilleKernel<real_t> greville_kernel_t;
// Create sycl kernel functor.
auto greville_kernel = greville_kernel_t(greville, knots, ncoeffs_[j], degrees_[j], interior);
// The kernel functor is passed to the sycl_kernel_submit function.
auto cgf = [&](::sycl::handler& cgh) {
  cgh.parallel_for<greville_kernel_t>(
      sycl::nd_range<dim>(global_range, local_range),
      greville_kernel);
};

// at::xpu::getCurrentXPUStream().queue() returns the current SYCL queue used by the XPU backend.
sycl::queue& q = at::xpu::getCurrentXPUStream().queue();
// Submit the kernel to the queue.
q.submit(cgf);
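
Note that q.submit() enqueues the kernel asynchronously. If the host reads greville right afterwards, synchronize first; a minimal sketch (the explicit wait is an addition on top of the snippet above):

sycl::event e = q.submit(cgf);
// Block the host until the kernel has completed (alternatively, call q.wait()).
e.wait();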

@mmoelle1, so far the SYCL spec has not defined a cudaOccupancyMaxPotentialBlockSize-like feature. However, users can leverage other SYCL runtime APIs to implement something similar. As for the detailed kernel implementation above, we pre-defined some heuristic numbers. Keep in mind that SYCL is more object-oriented, so the code does not map line-for-line onto CUDA.
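
For example, here is a minimal sketch that derives a launch configuration from the device limits instead of cudaOccupancyMaxPotentialBlockSize (the helper pick_launch_config and the block size of 256 are illustrative assumptions, not part of PyTorch or the SYCL spec):

#include <sycl/sycl.hpp>
#include <algorithm>
#include <cstdint>
#include <utility>

// Hypothetical helper: derive (gridSize, blockSize) from the device limits.
inline std::pair<int64_t, int64_t>
pick_launch_config(sycl::queue& q, int64_t nwork) {
  const sycl::device dev = q.get_device();
  // Largest work-group size the device supports.
  const auto max_wg = static_cast<int64_t>(
      dev.get_info<sycl::info::device::max_work_group_size>());
  // Heuristic block size, clamped to the device limit.
  const int64_t blockSize = std::min<int64_t>(256, max_wg);
  // Enough work-groups to cover all work items (same formula as the CUDA code).
  const int64_t gridSize = (nwork + blockSize - 1) / blockSize;
  return {gridSize, blockSize};
}

For a per-kernel bound closer to what the CUDA occupancy API reports, you could also query sycl::info::kernel_device_specific::work_group_size on the compiled kernel bundle.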

Regarding libtorch, do you mean the standalone libtorch package? For the torch package, we have exposed the runtime-related C++ header files, so there should be no feature gap. By the way, we are working on enabling the C++ SYCL extension: xpu: support sycl with torch.utils.cpp_extension APIs by dvrogozh · Pull Request #132945 · pytorch/pytorch · GitHub. FYI.

Here are some links for your reference and better understanding.