Connecting PyTorch sparse tensors with MLIR

Greetings! First-time poster in this forum, after incorrectly starting the topic in the user forum.

I won’t repeat those postings here (since I hope you will read them there), but in a nutshell: the MLIR Sparsifier team is very interested in connecting torch.sparse tensors with the MLIR sparse tensor types. A very first step towards this would be propagating sparsity information through the FX traced graph (without yet lowering the ops to sparse ops, or the arguments to their actual 1:N implementation arrays).
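To make the discussion concrete, here is a minimal sketch (hypothetical names, not the exact code from the original posting) of the kind of model and sparse input involved:

import torch

# A module whose forward() simply sums its (sparse) input.
class SumNet(torch.nn.Module):
    def forward(self, x):
        return x.sum()

# A 64x64 CSR tensor as the sparse input.
sparse_input = torch.eye(64).to_sparse_csr()

# With the prototype below, exporting this yields an ExportedProgram in
# which the torch.sparse_csr layout of the argument is recorded.
prog = torch.export.export(SumNet(), (sparse_input,))
print(prog)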

I made a very quick-and-dirty prototype for this; for the example given in the original posting, it generates something like the following.

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, l_x_: "f32[64, 64]:torch.sparse_csr"):   # ADDED!
            # File: biknet.py:27, code: return x.sum()
            sum_1: "f32[]" = torch.ops.aten.sum.default(l_x_);  l_x_ = None
            return (sum_1,)
           
Graph signature: ExportGraphSignature(
  input_specs=[
      InputSpec(
           kind=<InputKind.USER_INPUT: 1>,
           arg=TensorArgument(name='l_x_'),
           target=None,
           layout=torch.sparse_csr)       # ADDED!
  ],
  output_specs=[
     OutputSpec(
         kind=<OutputKind.USER_OUTPUT: 1>,
         arg=TensorArgument(name='sum_1'),
         target=None)
 ])
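For illustration, this is roughly how a downstream consumer (such as the torch-mlir FX importer) could read the annotation back out of the exported program (reusing prog from the sketch above; 'layout' is the field the prototype adds to InputSpec):

# Walk the input specs and pick up any recorded sparse layout.
for spec in prog.graph_signature.input_specs:
    layout = getattr(spec, "layout", None)
    if layout is not None:
        print(spec.arg.name, layout)   # e.g. l_x_ torch.sparse_csr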

This will hopefully enable me to prototype further in torch-mlir, so I can report back here on whether this is a viable approach and then continue work on the actual feature request.


I have posted a quick-and-dirty prototype “PR” that implements part of the requested feature (getting the type into the forward() parameter list in the FX graph). This PR is of course not meant to be submitted as-is, but hopefully some of the core developers can provide guidelines or assistance in turning the idea into a production-quality implementation.

In the meantime, the sparse support on the torch-mlir side is making great progress. With a simple “wrapper” exporter (that builds the FX graph for dense and then annotates sparse arguments afterwards), something like

class MatMulNet(torch.nn.Module):

    def __init__(self):
        super(MatMulNet, self).__init__()

    def forward(self, x, y):
        return torch.matmul(x, y)

m = export_and_import(MatMulNet(), A_coo, B_dense)

actually goes through the PyTorch graph exporter all the way down to something that can be further processed by MLIR.


#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton) }>
module {
  func.func @main(%arg0: !torch.vtensor<[64,64],f32,#sparse>, 
                  %arg1: !torch.vtensor<[64,64],f32>) -> !torch.vtensor<[64,64],f32> {
    %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[64,64],f32,#sparse>, !torch.vtensor<[64,64],f32> -> !torch.vtensor<[64,64],f32>
    return %0 : !torch.vtensor<[64,64],f32>
  }
}
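For reference, here is a minimal sketch of what such a wrapper exporter could look like (a hypothetical helper using the layout annotation from the prototype above; the actual export_and_import in torch-mlir differs in its details):

import torch

def sparse_export(model, args):
    # Record which arguments are sparse, and with which layout.
    layouts = [a.layout if a.layout != torch.strided else None for a in args]
    # Trace/export with dense stand-ins of the same shape and dtype.
    dense_args = tuple(a.to_dense() if l is not None else a
                       for a, l in zip(args, layouts))
    prog = torch.export.export(model, dense_args)
    # Re-attach the sparse layouts to the input specs afterwards
    # ('layout' is the field added by the prototype; this assumes a model
    # without parameters/buffers, so specs line up 1:1 with the arguments).
    for spec, layout in zip(prog.graph_signature.input_specs, layouts):
        if layout is not None:
            spec.layout = layout
    return prog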

Minor update: we now have sufficient machinery in torch-mlir to run a simple PyTorch model “end-to-end” with sparse tensors as input. Take, for example, the following code that uses MatMulNet. We get the same results when running with the normal PyTorch engine and with torch-mlir execution (the latter operating on the underlying numpy arrays):

    net = MatMulNet()
    a = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0],
                      [0, 0, 0, 0, 0, 0, 0, 0],
                      [0, 0, 2, 0, 0, 0, 0, 0],
                      [0, 0, 0, 0, 0, 0, 0, 0],
                      [0, 0, 0, 0, 0, 0, 0, 0],
                      [0, 0, 0, 0, 0, 0, 0, 3],
                      [0, 0, 0, 0, 0, 0, 0, 4],
                      [0, 0, 0, 0, 0, 0, 0, 5]], dtype=torch.float32)
    sparse_input = a.to_sparse_csr()
    res0 = net(a, a)
    res1 = net(sparse_input, a)
    res2 = sparse_jit(net, sparse_input, a)   # uses TORCH-MLIR +sparse

all yield the following numpy data:

[[ 1.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  4.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0. 15.]
 [ 0.  0.  0.  0.  0.  0.  0. 20.]
 [ 0.  0.  0.  0.  0.  0.  0. 25.]]
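As a quick sanity check (a sketch, assuming res2 comes back from the torch-mlir path as a plain numpy array), the three results can be compared directly:

import numpy as np

# Densify res1 first, in case the sparse PyTorch path returned a sparse tensor.
r1 = res1.to_dense() if res1.layout != torch.strided else res1
assert np.allclose(res0.numpy(), r1.numpy())
assert np.allclose(res0.numpy(), res2)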