POC for properly supporting torch.dtype in the JIT

For your viewing pleasure, here is a patch that could serve as a POC of what it takes to properly support dtypes (edit: in the meantime I fixed the errors uncovered in the ONNX, CUDA fuser, and TensorExpr fuser tests, so it passes most of test_jit*.py; I replaced the link with one to a tree off PyTorch master that has everything):

I should say that I have no idea how to deal with the backward-compat-breaking nature of the change other than bumping some serialization version and refusing to work with the old stuff (and how would I even do the versioning for my own branch?). That likely makes it unattractive for PyTorch.
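As a rough sketch of the gate I have in mind (the version number and function name are made up for illustration, not anything in the PyTorch source):

```python
# Hypothetical sketch: bump the produced file-format version and refuse
# archives written before the dtype change. The threshold value 5 is
# invented for illustration.
MIN_SUPPORTED_FILE_FORMAT_VERSION = 5  # first version with real dtypes

def check_archive_version(version: int) -> None:
    """Reject archives from before dtype arguments stopped being ints."""
    if version < MIN_SUPPORTED_FILE_FORMAT_VERSION:
        raise RuntimeError(
            f"archive has file format version {version}, but this build "
            f"requires >= {MIN_SUPPORTED_FILE_FORMAT_VERSION} because dtype "
            "arguments are no longer serialized as int"
        )
```
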
Any advice?

Probably some cleanups could be had from not needing to pretend around torch.dtype.
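To illustrate the pretending: today a dtype entering the JIT is represented by its c10::ScalarType index, so e.g. torch.float32 travels as the bare int 6. A toy model (the mapping is a small excerpt of the ScalarType numbering, shown here only for illustration):

```python
# Toy model of the status quo: a dtype crossing into the JIT degrades to
# its c10::ScalarType index (subset of the numbering, for illustration).
SCALAR_TYPE_INDEX = {
    "torch.uint8": 0,    # Byte
    "torch.int64": 4,    # Long
    "torch.float32": 6,  # Float
    "torch.float64": 7,  # Double
}

def dtype_to_jit_value(name: str) -> int:
    # pre-patch behaviour: the type information is lost, only the int remains
    return SCALAR_TYPE_INDEX[name]
```
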

Best regards

Thomas

This is super cool! I think in principle we’re interested in this (it fixes a long-standing hole in the type system), but I agree it may be tedious to merge. Can you detail the nature of the BC break? I don’t see that much user code modified in the PR.

Thank you for your interest and your reply!

So there are two and a half things (that I am aware of) that break:

  • post-patch PyTorch cannot load dtype-using scripts serialized pre-patch (loading will fail with schema mismatches for functions that stopped pretending to take ints). This could be mitigated (for scripts that only use the dtype within a given function, as opposed to passing it around) by adding conversions on load whenever we find a non-matching schema that would match if dtype were int.
  • code that passes an int as a dtype, or expects to get an int back from the JIT where it now returns a dtype,
  • user-defined functions that were written to take int arguments because they actually want dtypes.

The first one is the one I’d be most concerned about, and it could be mitigated by post-processing on load. I would not expect dtypes crossing the “JIT<->Python boundary”, as in the latter two cases, to be as common.
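The load-time mitigation could look roughly like the following toy model (the helper and type spellings are illustrative, not actual JIT API; the int-to-dtype table is a small excerpt of the c10::ScalarType numbering):

```python
# Sketch of the mitigation: when a serialized call fails to match the
# current schema, check whether the only mismatches are ints recorded
# where the schema now expects a dtype, and if so re-tag those arguments
# instead of rejecting the archive.

# illustrative excerpt of the c10::ScalarType numbering
INT_TO_DTYPE = {0: "torch.uint8", 4: "torch.int64",
                6: "torch.float32", 7: "torch.float64"}

def try_upgrade_args(expected_types, stored_types, stored_args):
    """Return upgraded args if every mismatch is int-where-dtype-expected,
    else None (a genuine schema mismatch)."""
    if len(expected_types) != len(stored_types):
        return None
    upgraded = []
    for exp, got, arg in zip(expected_types, stored_types, stored_args):
        if exp == got:
            upgraded.append(arg)
        elif exp == "dtype" and got == "int":
            upgraded.append(INT_TO_DTYPE[arg])  # re-tag the int as a dtype
        else:
            return None
    return upgraded

# e.g. an old archive recorded zeros(size, dtype) with the dtype stored as 6
print(try_upgrade_args(["int[]", "dtype"], ["int[]", "int"], [[2, 3], 6]))
# -> [[2, 3], 'torch.float32']
```
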

Best regards

Thomas