DTensor random RNG state support for non-CUDA backends

DTensor checks whether a backend supports its random ops on a given mesh with is_rng_supported_mesh, which essentially tests hasattr(device_handle, "set_rng_state").
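
Roughly, the check boils down to something like the sketch below (a simplified approximation of the logic in torch/distributed/_tensor/random.py, not the exact source; the getattr lookup here stands in for the internal device-handle helper):

```python
import torch
from torch.distributed.device_mesh import DeviceMesh


def is_rng_supported_mesh_sketch(device_mesh: DeviceMesh) -> bool:
    # Resolve the backend module for the mesh's device type,
    # e.g. torch.cuda for "cuda". This getattr is a simplification
    # of the internal device-handle lookup.
    device_handle = getattr(torch, device_mesh.device_type, None)
    # The check only asks "does the backend expose set_rng_state?",
    # which is why a backend that uses a CPU-style RNG state but
    # implements set_rng_state still passes.
    return device_handle is not None and hasattr(device_handle, "set_rng_state")
```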

For CPU RNG state, this check seems to return False. However, if a backend uses CPU RNG state and implements its own set_rng_state(), is_rng_supported_mesh becomes True and the DTensor random mechanism tries to use OffsetBasedRNGTracker. Its seed- and offset-based APIs assume a CUDA-like, offset-based RNG state, so a CPU RNG state does not work and fails as shown below -

[rank0]:   File ".../.local/lib/python3.10/site-packages/torch/distributed/_tensor/random.py", line 176, in _distribute_region
[rank0]:     old_offset = self.get_offset("parallel-rng")
[rank0]:   File ".../.local/lib/python3.10/site-packages/torch/distributed/_tensor/random.py", line 195, in get_offset
[rank0]:     return int(offset_tensor.item())
[rank0]: RuntimeError: a Tensor with 631 elements cannot be converted to Scalar
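
The mismatch is easy to reproduce outside DTensor. The snippet below is an approximation of what OffsetBasedRNGTracker.get_offset does (the real code is in torch/distributed/_tensor/random.py): it treats the RNG state tensor as an 8-byte seed followed by an 8-byte offset, which holds for a CUDA Philox state but not for the multi-kilobyte CPU Mersenne Twister state.

```python
import torch

# CPU RNG state: a large uint8 blob holding the Mersenne Twister state
# (several thousand bytes), not a (seed, offset) pair.
cpu_state = torch.get_rng_state()
print(cpu_state.dtype, cpu_state.numel())

# Approximation of OffsetBasedRNGTracker.get_offset: skip the 8-byte seed
# and read the remainder as a single int64 offset. For CUDA state the slice
# has exactly one element; for CPU state it has hundreds, so .item() fails.
try:
    offset_tensor = cpu_state[8:].view(torch.int64)
    offset = int(offset_tensor.item())
except RuntimeError as e:
    print(e)  # e.g. "a Tensor with N elements cannot be converted to Scalar"
```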

Is there a plan to enhance DTensor to support a CPU-like RNG state, or is it possible to mark is_rng_supported_mesh as False for backends that use CPU RNG state rather than an offset-based RNG state?


Hey!

cc @wanchaol in case he’s seeing this.
This kind of request might be best as an issue on GitHub though, as that is where we usually track features and improvements!

Thanks @albanD. I have filed DTensor RNG state for non CUDA backends · Issue #138329 · pytorch/pytorch · GitHub for this.