Offsets in MPS kernel indexing

I am trying to understand kernel indexing for the MPS backend in PyTorch. I have two questions:

  1. Although kernel_index_offsets receives a num_offsets argument, it does not use it. Am I overlooking something here?
  2. Why is nOffsets set to 3 in the following line:

Apologies if this question is out of place here. I am interested in contributing to PyTorch (through the General MPS op coverage tracking issue, pytorch/pytorch#77764) and am working toward an understanding of the codebase.

Hi,

For your questions:

  1. Yes, it is indeed unused: this kernel assumes there are exactly 3 offsets (note that data_offsets has type uint3).

  2. The kernel was originally designed for binary operations, i.e. operations with two inputs, so the three offsets correspond to the input (self), other, and output, respectively.
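To make both answers concrete, here is a minimal Python model of what a kernel like kernel_index_offsets plausibly computes. It is a sketch under stated assumptions, not the actual Metal source: the names index_offsets, binary_add, strides and iter_shape are hypothetical, strides are counted in elements rather than bytes, and dimension 0 is assumed to be the fastest-varying one. Each thread's linear index is split into per-dimension coordinates, and each coordinate is multiplied by the stride of all three operands (self, other, output) at once. Because the fixed-size triple plays the role of uint3, a separate num_offsets value has nothing left to control.

```python
from typing import List, Tuple

# A triple of offsets (self, other, output); in this model the fixed length
# of 3 stands in for Metal's uint3.
Offsets3 = Tuple[int, int, int]


def index_offsets(strides: List[Offsets3], iter_shape: List[int],
                  num_offsets: int = 3) -> List[Offsets3]:
    """Hypothetical CPU model: one offset triple per linear thread index.

    strides[dim] holds the (self, other, output) element strides for dim.
    Dimension 0 is assumed to vary fastest. num_offsets is accepted but
    ignored, since the triple already fixes the operand count at 3.
    """
    numel = 1
    for size in iter_shape:
        numel *= size

    result: List[Offsets3] = []
    for thread_index in range(numel):
        idx = thread_index
        off = [0, 0, 0]
        for dim, size in enumerate(iter_shape):
            coord = idx % size   # coordinate along this dimension
            idx //= size         # move on to the next (slower) dimension
            for k in range(3):   # same coordinate, each operand's own stride
                off[k] += coord * strides[dim][k]
        result.append((off[0], off[1], off[2]))
    return result


def binary_add(a: List[int], b: List[int], out: List[int],
               offsets: List[Offsets3]) -> None:
    """Consume the triples the way a binary op would: out = self + other."""
    for self_off, other_off, out_off in offsets:
        out[out_off] = a[self_off] + b[other_off]


# Example: self and output are contiguous 2x3 tensors; other has shape (3,)
# and is broadcast across the 2 rows (stride 0 along that dimension).
iter_shape = [3, 2]                  # fastest-varying dimension first
strides = [(1, 1, 1), (3, 0, 3)]     # per-dimension (self, other, output)

offsets = index_offsets(strides, iter_shape)
a = [0, 1, 2, 3, 4, 5]
b = [10, 20, 30]
out = [0] * 6
binary_add(a, b, out, offsets)
print(offsets)  # [(0, 0, 0), (1, 1, 1), (2, 2, 2), (3, 0, 3), (4, 1, 4), (5, 2, 5)]
print(out)      # [10, 21, 32, 13, 24, 35]
```

In the example, other has stride 0 along the row dimension, so its offset cycles through 0, 1, 2 while the self and output offsets advance linearly; that per-operand independence is exactly what the three slots are for.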
