DTensor random ops mesh support for a backend is checked with is_rng_supported_mesh, which simply tests hasattr(device_handle, "set_rng_state").
For a plain CPU RNG state this seems to return False. However, if a backend uses a CPU-style RNG state but does implement set_rng_state(), is_rng_supported_mesh returns True and the DTensor random mechanism tries to use OffsetBasedRNGTracker. The seed- and offset-based APIs assume the RNG state is a CUDA-like seed+offset blob, so a CPU-style RNG state does not work and fails as shown below -
[rank0]: File ".../.local/lib/python3.10/site-packages/torch/distributed/_tensor/random.py", line 176, in _distribute_region
[rank0]: old_offset = self.get_offset("parallel-rng")
[rank0]: File ".../.local/lib/python3.10/site-packages/torch/distributed/_tensor/random.py", line 195, in get_offset
[rank0]: return int(offset_tensor.item())
[rank0]: RuntimeError: a Tensor with 631 elements cannot be converted to Scalar
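For context, the 631 in the error is consistent with the tracker interpreting everything after the first 8 bytes of the state as an int64 offset, which only makes sense for a small CUDA-style seed+offset state. Below is a minimal sketch of what goes wrong with a CPU-style state; the state size and the exact slicing inside OffsetBasedRNGTracker are assumptions inferred from the error above, not verified against the source:

```python
import torch

# CPU RNG state is the full generator state (thousands of bytes),
# not an 8-byte seed followed by an 8-byte offset.
cpu_state = torch.get_rng_state()
print(cpu_state.shape)  # e.g. torch.Size([5056]) on a typical build

# Reading it the way an offset-based tracker reads a CUDA-style state
# (everything past the first 8 bytes viewed as int64) yields many elements,
# so .item() cannot convert it to a scalar.
offset_view = cpu_state[8:].view(dtype=torch.int64)
print(offset_view.numel())  # 631 when the state is 5056 bytes
# offset_view.item()  # would raise the RuntimeError shown above
```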
Is there a plan to enhance DTensor to support a CPU-like RNG state, or is it possible to make is_rng_supported_mesh return False for backends that use a CPU-style RNG state rather than an offset-based one?
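For illustration, a stricter check along the lines of the second option might look like the sketch below. This is hypothetical, not existing code: it assumes an offset-based state can be recognized by its small fixed seed+offset size (16 bytes), which may not hold for every backend, and it mirrors the getattr(torch, device_type) lookup only as an assumption about how the device handle is obtained.

```python
import torch
from torch.distributed.device_mesh import DeviceMesh

def is_offset_based_rng_mesh(device_mesh: DeviceMesh) -> bool:
    """Hypothetical stricter variant of is_rng_supported_mesh: only report True
    when the backend's RNG state actually looks like a seed+offset blob."""
    device_handle = getattr(torch, device_mesh.device_type, None)
    if device_handle is None or not hasattr(device_handle, "set_rng_state"):
        return False
    state = device_handle.get_rng_state()
    # A CUDA-style state is an 8-byte seed followed by an 8-byte offset;
    # a CPU-style state is the full generator state (thousands of bytes).
    return state.numel() == 16
```

With a check like this, backends that expose set_rng_state() but keep a CPU-style state would fall outside the offset-based RNG path instead of failing inside OffsetBasedRNGTracker.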