Improve the extension with `PrivateUse1` for custom device

Hi, I am working on improving the extension mechanism based on PrivateUse1 so that PyTorch can support custom devices better.
I have already submitted some PRs for this, such as 97420, 93920, and 96188, along with some small fixes to support custom devices, and I plan to submit more in the future.
While working on this, I found some code that can be optimized and improved:

  1. For AMP, the GradScaler-related code lives in torch/cuda/amp, but I think the strategy for updating the scale should be basically the same across devices (cuda, xpu, etc.).
    So it would be better to move this code to torch/amp and use the op _amp_foreach_non_finite_check_and_unscale_ to determine whether there is an overflow. If a device needs special handling, it can inherit from the GradScaler defined in torch/amp.
  2. For AMP in C++, there is a separate DispatchKey for every device type: AutocastCUDA for CUDA, AutocastXPU for XPU, and so on. I think the DispatchKey only indicates which autocast strategy each operator uses; it is just a strategy and has nothing to do with the device type. So we could use a unified key, AUTOCAST, to configure some basic operator strategies. A custom device could then override these operator strategies according to its own characteristics.
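
To make point 1 concrete, here is a minimal pure-Python sketch of the idea: the scale-update strategy lives in a generic base class, and a backend only overrides the overflow check. The class names `GenericGradScaler` and `MyDeviceGradScaler` and the hook `_unscale_and_check` are hypothetical and are not the actual torch.amp API; plain floats stand in for tensors, and `math.isfinite` stands in for the fused non-finite check op.

```python
import math


class GenericGradScaler:
    """Toy device-agnostic scaler (illustrative only, not the torch.amp API)."""

    def __init__(self, init_scale=65536.0, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self._scale = init_scale
        self._growth_factor = growth_factor
        self._backoff_factor = backoff_factor
        self._growth_interval = growth_interval
        self._growth_tracker = 0  # consecutive steps without overflow

    def get_scale(self):
        return self._scale

    def scale(self, loss):
        return loss * self._scale

    def _unscale_and_check(self, grads):
        # Device hook: a backend subclass would call its own kernel here.
        inv_scale = 1.0 / self._scale
        unscaled = [g * inv_scale for g in grads]
        found_inf = any(not math.isfinite(g) for g in unscaled)
        return unscaled, found_inf

    def _update_scale(self, found_inf):
        # Generic strategy: back off on overflow, grow after a clean streak.
        if found_inf:
            self._scale *= self._backoff_factor
            self._growth_tracker = 0
        else:
            self._growth_tracker += 1
            if self._growth_tracker == self._growth_interval:
                self._scale *= self._growth_factor
                self._growth_tracker = 0

    def step_and_update(self, grads):
        # Returns unscaled grads, or None when the step must be skipped.
        unscaled, found_inf = self._unscale_and_check(grads)
        self._update_scale(found_inf)
        return None if found_inf else unscaled


class MyDeviceGradScaler(GenericGradScaler):
    """Hypothetical backend subclass that overrides only the device hook."""

    def _unscale_and_check(self, grads):
        # This toy version just counts calls and delegates to the base class.
        self.device_hook_calls = getattr(self, "device_hook_calls", 0) + 1
        return super()._unscale_and_check(grads)
```

With this split, everything except `_unscale_and_check` is shared, which is roughly the shape of the refactor proposed above.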

Does this sound OK to you?

Thanks for sending these details.

For 1. I think we can consolidate this for sure. cc @janeyx99 @crcrpar who have been looking at GradScaler recently.
We will need to be careful about backward compatibility when moving public APIs around but it should be do-able in this case.

For 2. I’m not super familiar with how the current cpu/xpu/cuda strategies differ beyond being targeted to different dtypes. We can definitely have an Autocast alias key to be able to more easily register strategies for all devices and then we can use Autocast* regular keys to override it.
Here again we would need to be careful not to change the current behavior for current devices while changing this implementation (beyond adding support for new devices).
Also note that using torch.library, it would be quite easy to register all the strategies directly from Python.
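
The lookup order described here (an alias key supplies defaults, a per-backend Autocast* key overrides them) can be sketched with a toy registry. This is only a model of the idea in plain Python: it does not use the real dispatcher or torch.library, and `AutocastPolicyRegistry`, the `"Autocast"` alias name, and the policy strings are assumptions made for illustration.

```python
# Toy model of "alias-key default + per-backend override".
# Nothing here touches the real PyTorch dispatcher.

ALIAS_KEY = "Autocast"  # hypothetical alias key shared by all devices


class AutocastPolicyRegistry:
    def __init__(self):
        # Maps (op_name, dispatch_key) -> policy name.
        self._table = {}

    def register(self, op_name, policy, key=ALIAS_KEY):
        self._table[(op_name, key)] = policy

    def lookup(self, op_name, backend_key):
        # A per-backend registration wins; otherwise use the alias default.
        if (op_name, backend_key) in self._table:
            return self._table[(op_name, backend_key)]
        # Ops with no registration at all are left untouched.
        return self._table.get((op_name, ALIAS_KEY), "fallthrough")


registry = AutocastPolicyRegistry()

# Defaults registered once for every device.
registry.register("mm", "lower_precision_fp")
registry.register("softmax", "fp32")

# A custom backend overrides one op according to its own characteristics.
registry.register("softmax", "lower_precision_fp", key="AutocastPrivateUse1")

print(registry.lookup("softmax", "AutocastCUDA"))
print(registry.lookup("softmax", "AutocastPrivateUse1"))
```

Existing devices keep seeing the defaults, so only the backend that registers an override changes behavior, which is the backward-compatibility property mentioned above.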

Some more device-specific things to consider for the GradScaler, in addition to the _amp_foreach_non_finite_check_and_unscale_ op:

  • There is a MultiDeviceReplicator that lazily serves copies of found_inf and grad_scale to support the foreach op. (This is also only XLA/CUDA)
  • The other amp op, _amp_update_scale, is CUDA only, though if scale is allowed to remain on cuda, this support may be optional
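
As a rough illustration of the first bullet, the "lazily serve copies" behavior amounts to a per-device cache around one master value. The sketch below is a simplified stand-in written in plain Python: `LazyDeviceReplicator`, `fake_copy`, and the device strings are hypothetical, and a real implementation would copy tensors rather than build tuples.

```python
class LazyDeviceReplicator:
    """Toy stand-in: hands out one cached copy of a value per device."""

    def __init__(self, master_value, copy_to_device):
        self._master = master_value
        # Hypothetical callable: (value, device) -> copy living on that device.
        self._copy_to_device = copy_to_device
        self._per_device = {}

    def get(self, device):
        # Copy only on first request for a device, then reuse the cached copy.
        if device not in self._per_device:
            self._per_device[device] = self._copy_to_device(self._master, device)
        return self._per_device[device]


copies_made = []


def fake_copy(value, device):
    # Records each copy so the laziness is observable.
    copies_made.append(device)
    return (device, value)


inv_scale = LazyDeviceReplicator(0.5, fake_copy)
first = inv_scale.get("dev:0")
second = inv_scale.get("dev:0")  # served from the cache, no new copy
third = inv_scale.get("dev:1")
```

A device-generic GradScaler would need this helper, or an equivalent, to work for whichever backend owns the gradients.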

Yeah, thanks for your reply.
For 1, GradScaler contains some device-related code. I have analyzed the code in torch/cuda/amp: most of it is generic, such as the functions named get_**/set_**, and the functions named scale** have only several lines of device-related code, so these functions can be moved to torch/amp. For XLA/CUDA, we can inherit and override to handle the device-specific parts.
I also found that there is a PR to support CPU, 98926, which already refactors some of this. We can improve on it in the future: keep only the general operations in torch/amp, and put the device-related functions, through inheritance and overriding, in torch/cuda/amp or torch/cpu/amp.

For 2, I remember that @bdhirsh has done some work on this, according to your comment.
I would like to continue that work. Does that sound OK to you?

Also, a small request: could you give me permission to join the PyTorch Slack? I applied a few days ago but have not received a reply. My email is endswith @albanD

Oh, it is not per-backend today? :o
After checking, it does have an AutocastPrivateUse1 key, so the proposal above still works.
Thinking about this a bit more, cleaning up the fact that it is not a per-backend functionality will be a bit tricky and of little benefit for your current goal (even though it would simplify the PT codebase). So we should definitely do this, but it is independent from your ask.

For your problem at hand, the best case scenario in my mind would be for us to make it an opt-in feature just like the methods: you can opt in for your backend to have the cuda autocast behavior (and then you can tweak it). But given how the registration works, it might be very tricky to make it opt-in. If so, having it as the default may be ok.

For the Python side of things, I am all on board with moving this to be device generic in torch/amp.
So anything that will not break backward compatibility sounds like a good plan to me :smiley:

If you’re interested, my initial attempt at making autocast a generic per-backend-functionality key was here: I ran into some interesting CI failures and abandoned the PR.

I agree with Alban - making autocast a per-backend-functionality key would be a nice bit of cleanup (if you want to take it on I'm happy to review!). But it definitely isn’t necessary for our goals with PrivateUse1. There’s an AutocastPrivateUse1 dispatch key today that we can continue to work off of.

Yeah, I will continue to analyze the AutocastPrivateUse1-related part in depth. For the Python side, thanks for your support; I will work on optimizing it based on 98926.

Thanks for your reply. I will carefully analyze your work and code; it may take a long time.

I have now submitted two PRs to refactor AMP: "Refactor gradscaler" by heidongxianhua (pytorch/pytorch Pull Request #99301) and "improve macro with AMP" by heidongxianhua (pytorch/pytorch Pull Request #99285).
For 99301, I refactored the Python API as I described above.
For 99285, I refactored the macro so that it can be used for custom devices.

@albanD @bdhirsh Could you have a look? :slightly_smiling_face: