Greetings! First-time poster in this forum, after incorrectly starting the topic in the user forum.
I won’t repeat those postings here (I hope you will read them there), but in a nutshell: the MLIR Sparsifier team is very interested in connecting torch.sparse tensors with the MLIR sparse tensor types. A first step towards this would be propagating sparsity information through the FX traced graph, without yet lowering the ops to sparse ops or expanding the arguments into their actual 1:N implementation arrays.
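To make the "1:N implementation arrays" concrete: a single sparse tensor is backed by several dense tensors. For the CSR layout used in the example below, these are the compressed row pointers, the column indices, and the nonzero values:

```python
import torch

a = torch.tensor([[0., 1.], [2., 0.]]).to_sparse_csr()
print(a.crow_indices())  # compressed row pointers: tensor([0, 1, 2])
print(a.col_indices())   # column indices of the nonzeros: tensor([1, 0])
print(a.values())        # nonzero values: tensor([1., 2.])
```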
I made a very quick-and-dirty prototype for this, which generates output like the following for the example given in the original posting.
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, l_x_: "f32[64, 64]:torch.sparse_csr"):  # ADDED!
            # File: biknet.py:27, code: return x.sum()
            sum_1: "f32[]" = torch.ops.aten.sum.default(l_x_);  l_x_ = None
            return (sum_1,)

Graph signature: ExportGraphSignature(
    input_specs=[
        InputSpec(
            kind=<InputKind.USER_INPUT: 1>,
            arg=TensorArgument(name='l_x_'),
            target=None,
            layout=torch.sparse_csr)  # ADDED!
    ],
    output_specs=[
        OutputSpec(
            kind=<OutputKind.USER_OUTPUT: 1>,
            arg=TensorArgument(name='sum_1'),
            target=None)
    ])
This will hopefully enable me to prototype further in torch-mlir, report back here on whether this is a viable approach, and then continue work on the actual feature request.
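As a hint of how a consumer such as torch-mlir could pick this up, here is a hypothetical helper (assuming the prototype attaches a layout field to InputSpec as shown above) that recovers the layout of every user input, falling back to the dense torch.strided default when no sparsity was propagated:

```python
import torch
from torch.export import ExportedProgram

def input_layouts(prog: ExportedProgram) -> dict:
    """Map each user input name to its propagated layout (hypothetical helper)."""
    return {
        spec.arg.name: getattr(spec, "layout", torch.strided)
        for spec in prog.graph_signature.input_specs
    }
```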