RFC Proposal for CUDA-Accelerated Dynamic Time Warping (DTW) Implementation in PyTorch

Hi PyTorch Community,

I’m excited to share an idea that I believe could significantly enhance the performance of time-series analysis in PyTorch. I’ve developed a CUDA-accelerated implementation of the Dynamic Time Warping (DTW) algorithm and would love to get your feedback.

Why CUDA-Accelerated DTW?

Dynamic Time Warping (DTW) is essential for measuring similarity between temporal sequences. However, it is computationally intensive: the classic algorithm fills an N×M cost matrix for every pair of sequences, which becomes a bottleneck with large datasets or real-time applications. By leveraging CUDA, we can significantly speed up DTW computations, making it feasible for high-frequency trading, real-time signal processing, and more.
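
For reference, classic DTW fills an (N+1)×(M+1) cost matrix with a simple recurrence; here is a minimal NumPy version (just my illustration of the recurrence, not the proposed implementation):

import numpy as np

def dtw_reference(x, y):
    """Classic DTW distance between two 1-D sequences (O(len(x) * len(y)))."""
    n, m = len(x), len(y)
    D = (x[:, None] - y[None, :]) ** 2           # pairwise squared distances
    R = np.full((n + 1, m + 1), np.inf)
    R[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            R[i, j] = D[i - 1, j - 1] + min(R[i - 1, j - 1],  # match
                                            R[i - 1, j],      # insertion
                                            R[i, j - 1])      # deletion
    return R[n, m]

Each cell depends on its top, left, and diagonal neighbours, so the work per pair is quadratic and only partially parallel, which is exactly where a GPU sweep over the cost matrix helps.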

Key Benefits

  1. Performance: Accelerates DTW computations by utilizing GPU parallelism.
  2. Scalability: Efficiently handles larger datasets, suitable for enterprise-level applications.
  3. Seamless Integration: Designed to integrate smoothly with existing PyTorch workflows.

Implementation Overview

The CUDA-accelerated DTW implementation includes:

  • CUDA kernels for forward and backward DTW computations
  • Custom PyTorch function integrating with autograd (see the sketch after this list)
  • Numba-accelerated CPU fallback for non-CUDA systems
  • High-level PyTorch module for easy integration
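
To make the autograd piece concrete, here is a rough sketch of how the custom function and the high-level module could be wired together. The names (_DTWFunction, DTW, dtw_forward, dtw_backward) are placeholders for illustration, not the final API:

import torch

class _DTWFunction(torch.autograd.Function):
    # Placeholder wrapper: dtw_forward / dtw_backward stand in for the CUDA
    # kernels (or the Numba CPU fallback) described above.
    @staticmethod
    def forward(ctx, D):
        R = dtw_forward(D)            # cost matrix, shape (B, N+1, M+1)
        ctx.save_for_backward(D, R)
        return R[:, -1, -1]           # one DTW distance per batch element

    @staticmethod
    def backward(ctx, grad_output):
        D, R = ctx.saved_tensors
        E = dtw_backward(D, R)        # gradient of the distance w.r.t. D, shape (B, N, M)
        return grad_output.view(-1, 1, 1) * E

class DTW(torch.nn.Module):
    """Placeholder high-level module operating on a batch of distance matrices D."""
    def forward(self, D):
        return _DTWFunction.apply(D)

The idea would be that the same module dispatches to the CUDA kernels on GPU tensors and to the Numba CPU fallback otherwise.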

Here’s a glimpse of the CUDA kernel for forward computation:

from numba import cuda

@cuda.jit
def compute_dtw_cuda(D, max_i, max_j, R):
    b = cuda.blockIdx.x     # one block per batch element
    tid = cuda.threadIdx.x
    I = tid                 # this thread owns row I of the cost matrix
    for j in range(1, max_j + 1):
        for i in range(1, max_i + 1):
            if I == i:
                r0 = R[b, i - 1, j - 1]
                r1 = R[b, i - 1, j]
                r2 = R[b, i, j - 1]
                R[b, i, j] = D[b, i - 1, j - 1] + min(r0, r1, r2)
        cuda.syncthreads()  # finish column j before moving to the next one
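
For context, the host side that this kernel assumes might look roughly like the sketch below (my sketch only; dtw_forward_numba is a placeholder name): R is filled with infinity except R[:, 0, 0] = 0, one block is launched per batch element, and one thread per row (thread 0 stays idle because rows are 1-indexed).

import numpy as np
from numba import cuda

def dtw_forward_numba(D_host):
    """Sketch of a host-side driver: D_host is (B, N, M); returns (B,) DTW distances."""
    B, N, M = D_host.shape
    R_host = np.full((B, N + 1, M + 1), np.inf, dtype=D_host.dtype)
    R_host[:, 0, 0] = 0.0                  # only the origin cell starts at zero

    D = cuda.to_device(D_host)
    R = cuda.to_device(R_host)
    # one block per batch element; threads 1..N each own a row (assumes N + 1 <= 1024)
    compute_dtw_cuda[B, N + 1](D, N, M, R)
    return R.copy_to_host()[:, N, M]

One thing worth flagging in the RFC: R[b, i, j] also depends on R[b, i-1, j] from the same column, so the kernel will likely need to sweep the matrix by anti-diagonals (a wavefront), or synchronize between rows, so that all three dependencies of a cell are finished before it is read.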

I’m currently drafting an RFC (Request for Comments) to formally propose this feature. I would appreciate any feedback, suggestions, or use cases you think should be considered.


Thank you for sharing your idea.
You can use my code for short sequences.
But I want to adapt it to compare a huge number of sequences with lengths of about 1000.

https://github.com/amirbroker/cudaDTW/tree/main

from numba import cuda

@cuda.jit
def dtwClc(seq, trg, dtw_matrix, lenEachSeq, lenEachTrg, lenDTW, INDList, dtw):
    tx = cuda.threadIdx.x   # thread index within the block
    ty = cuda.blockIdx.x    # one block per (seq, trg) pair

    if ty >= lenDTW:
        return

    id1, id2 = INDList[ty]  # which sequence and which target this block compares

    for eachELeachTaq in range(1, lenEachTrg + 1):
        for eachELeachSeq in range(1, lenEachSeq + 1):
            if lenEachSeq == tx:
                # squared distance between the two samples
                cost = (seq[id1, eachELeachSeq - 1] - trg[id2, eachELeachTaq - 1]) ** 2
                dtw_matrix[eachELeachSeq, eachELeachTaq] = cost + min(
                    dtw_matrix[eachELeachSeq - 1, eachELeachTaq],      # insertion
                    dtw_matrix[eachELeachSeq, eachELeachTaq - 1],      # deletion
                    dtw_matrix[eachELeachSeq - 1, eachELeachTaq - 1],  # match
                )
            cuda.syncthreads()

    dtw[ty] = dtw_matrix[-1, -1] ** 0.5  # final DTW distance for this pair

######################
import numpy as np
from datetime import datetime

seq = cuda.to_device(np.random.rand(50,80))
trg = cuda.to_device(np.random.rand(100,100))

startTime = datetime.now()
lenSeq = seq.shape[0]
lenTrg = trg.shape[0]

lenEachSeq = seq.shape[1]
lenEachTrg = trg.shape[1]

INDList = cuda.to_device(np.array([(i,j) for i in range(lenSeq) for j in range(lenTrg)]))

dtw_matrix = np.zeros((lenEachSeq+1, lenEachTrg+1))
dtw_matrix[0, :] = np.inf
dtw_matrix[:, 0] = np.inf
dtw_matrix[0, 0] = 0

dtw = cuda.to_device(np.zeros([lenSeq * lenTrg]))

lenDTW = dtw.shape[0]

dtwClc[lenDTW, 1024](seq, trg, dtw_matrix, lenEachSeq, lenEachTrg, lenDTW, INDList, dtw)
dtw = dtw.copy_to_host()

print(dtw)
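
To scale this to a huge number of pairs, one change that would probably help is giving every (seq, trg) pair its own cost matrix (indexed by the block id ty inside the kernel), so the blocks don't all write into the same dtw_matrix. A rough sketch of the allocation side, assuming the whole (lenDTW, m+1, n+1) array fits in GPU memory:

import numpy as np
from numba import cuda

m, n = lenEachSeq, lenEachTrg

# one private cost matrix per (seq, trg) pair, same init as before
init = np.zeros((m + 1, n + 1))
init[0, :] = np.inf
init[:, 0] = np.inf
init[0, 0] = 0
dtw_matrices = cuda.to_device(np.tile(init, (lenDTW, 1, 1)))

# inside dtwClc, index dtw_matrices[ty, ...] instead of the shared dtw_matrix

For sequences of length about 1000 this array gets large (roughly num_pairs × 1001 × 1001 × 8 bytes), so the pairs would likely have to be processed in chunks.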

Thanks @amirbroker. Yes, I'm working on handling a huge number of sequences with PyTorch's distributed systems.

I’ll attach a glimpse of the CUDA kernel I’m working on below.
Will update this thread once it’s done. Thanks :)

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <algorithm>
#include <limits>

namespace at {
    namespace native {
        //cuda kernel to compute the forward DTW cost matrix
        template <typename scalar_t>
        __global__ void compute_dtw_cuda_kernel(
            const scalar_t* __restrict__ D, //Input distance matrix
            scalar_t* __restrict__ R, //cost matrix to be computed
            int64_t max_i, //Number of rows in the distance matrix
            int64_t max_j, //Number of columns in the distance matrix
            int64_t B
        ){
            int64_t b = blockIdx.x; //batch index
            int64_t tid = threadIdx.x; //thread index
            int64_t I = tid; //row index

            for(int64_t j = 1; j<=max_j; j++){
                for(int64_t i = 1; i<=max_i; i++){
                    if(I==i){
                        //retrieve values from the previous cells in the cost matrix
                        scalar_t r0 = R[b *(max_i + 1)*(max_j + 1)+(i - 1)*(max_j + 1)+(j -1)];
                        scalar_t r1 = R[b *(max_i + 1)*(max_j + 1)+(i - 1)*(max_j + 1)+j];
                        scalar_t r2 = R[b*(max_i + 1)*(max_j + 1)+i*(max_j + 1)+(j-1)];

                        //compute the cost for the current cell
                        R[b*(max_i + 1)*(max_j + 1)+i*(max_j + 1)+j] = D[b*max_i*max_j+(i-1)*max_j+(j-1)]+min(r0, min(r1, r2));
                    }
                }
                __syncthreads(); //synchronize threads within a block
            }
        }

        //cuda kernel to compute the backward DTW Gradients
        template <typename scalar_t>
        __global__ void compute_dtw_backwards_cuda_kernel(
            const scalar_t* __restrict__ D, // Input distance matrix
            const scalar_t* __restrict__ R, // forward cost matrix
            scalar_t* __restrict__ E, //gradient matrix to be computed
            int64_t max_i, //number of rows in the distance matrix
            int64_t max_j, //number of columns in the distance matrix
            int64_t B //batch size
        ){
            int64_t b = blockIdx.x; // batch index
            int64_t tid = threadIdx.x; //thread index
            int64_t I = tid; //row index

            for(int64_t j = max_j; j>0; j--){
                for(int64_t i = max_i; i>0; i--){
                    if(I==i){
                        //compute the edge weights from the three downstream cells
                        scalar_t w0 = R[b*(max_i + 1)*(max_j + 1)+(i+1)*(max_j + 1)+j]     - R[b*(max_i + 1)*(max_j + 1)+i*(max_j + 1)+j] - D[b*max_i*max_j + max_j*i + j];
                        scalar_t w1 = R[b*(max_i + 1)*(max_j + 1)+i*(max_j + 1)+(j + 1)]   - R[b*(max_i + 1)*(max_j + 1)+i*(max_j + 1)+j] - D[b*max_i*max_j + i*max_j + (j+1)];
                        scalar_t w2 = R[b*(max_i + 1)*(max_j + 1)+(i+1)*(max_j + 1)+(j+1)] - R[b*(max_i + 1)*(max_j + 1)+i*(max_j + 1)+j] - D[b*max_i*max_j + (i+1)*max_j + (j+1)];

                        //accumulate the gradient for the current cell from its three downstream neighbours
                        E[b*(max_i + 1)*(max_j + 1) + i*(max_j + 1) + j] =
                              E[b*(max_i + 1)*(max_j + 1) + (i+1)*(max_j + 1) + j]       * w0
                            + E[b*(max_i + 1)*(max_j + 1) + i*(max_j + 1) + (j + 1)]     * w1
                            + E[b*(max_i + 1)*(max_j + 1) + (i+1)*(max_j + 1) + (j + 1)] * w2;
                    }
                }
                __syncthreads(); //Synchronize threads within a block
            }
        }
        //Function to launch the forward DTW CUDA kernel
        Tensor dtw_forward_cuda(const Tensor& D){
            auto D_ = D.contiguous(); //Ensure contiguous memory layout
            auto dtype = D.scalar_type(); //get data type

            int64_t B = D.size(0); //batch size
            int64_t N = D.size(1); //number of rows
            int64_t M = D.size(2); //number of columns
            //rows and columns are 1-indexed inside the kernel, so tid must reach N
            int threads_per_block = static_cast<int>(std::max(N, M)) + 1;

            //initialize the cost matrix with infinity; only the origin cell starts at zero
            auto R = at::empty({B, N+1, M+1}, D.options().dtype(dtype));
            R.fill_(std::numeric_limits<float>::infinity());
            R.select(1, 0).select(1, 0).fill_(0); // R[:, 0, 0] = 0

            // launch the CUDA kernel for forward DTW
            AT_DISPATCH_FLOATING_TYPES(D.scalar_type(), "dtw_forward_cuda", ([&] {
                compute_dtw_cuda_kernel<scalar_t><<<B, threads_per_block>>>(
                    D_.data_ptr<scalar_t>(),
                    R.data_ptr<scalar_t>(),
                    N,
                    M,
                    B
                );
            }));

            //return the last element of the cost matrix (DTW distance)
            return R.select(1, N).select(1, M);
        }
        // (backward launch function omitted in this glimpse)
    } // namespace native
} // namespace at
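
Once this launcher is exposed to Python (e.g. through a pybind11 / torch-extension binding, which isn't part of this glimpse), calling it could look roughly like the following; dtw_ext is just a placeholder for whatever the bound module ends up being called:

import torch

# hypothetical usage, assuming the C++/CUDA code above is built as a torch
# extension that exposes dtw_forward_cuda under the module name dtw_ext
x = torch.randn(8, 50, 4, device="cuda")   # batch of 8 sequences, length 50, 4 features
y = torch.randn(8, 60, 4, device="cuda")   # batch of 8 sequences, length 60, 4 features

D = torch.cdist(x, y, p=2) ** 2            # (8, 50, 60) squared pairwise distances
dist = dtw_ext.dtw_forward_cuda(D)         # (8,) DTW distances, one per pair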