Decomposition slows down lazy tensor tracing

Decomposition slows down lazy tensor tracing when running with TorchXLA. Here is the timeline with stack info:

Removing the decomposition-related code resolves the issue:

This is the timeline after removing the decomposition-related code:

What confuses me is that the native PyTorch job has no such issue; there is only one cudaLaunchKernel op below aten::lerp_ (it uses _single_tensor_adamw, which is similar to TorchXLA's AdamW implementation):

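For reference, a minimal sketch of forcing the single-tensor path in native PyTorch for comparison (foreach=False is my assumption for selecting _single_tensor_adamw; the model and shapes are placeholders):

import torch
from torch import nn, optim

model = nn.Linear(1024, 1024).cuda()
# foreach=False should route AdamW through _single_tensor_adamw, issuing one
# elementwise op (lerp_, addcdiv_, ...) per parameter instead of the
# multi-tensor kernels.
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4, foreach=False)

loss = model(torch.randn(8, 1024, device="cuda")).sum()
loss.backward()
optimizer.step()
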
My question is: Is there any flag to control the decomposition?


Hi @AngWang, we’ve seen some reports of increased tracing time, but I’m not sure whether they’re related, since this one seems to be strictly due to decomposition. PyTorch/XLA does have a flag that controls functionalization (which may introduce more decomposition), so you can try disabling functionalization to see whether it affects your code. The flag is XLA_DISABLE_FUNCTIONALIZATION; you can set this env var to true, e.g. XLA_DISABLE_FUNCTIONALIZATION=1 python your_code.py.
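
For example, a minimal sketch of setting the flag from Python instead of the shell, assuming the environment variable is read when torch_xla initializes (setting it on the command line as above is the safer option):

import os

# Assumption: the flag must be set before torch_xla is imported/initialized.
os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "1"

import torch_xla.core.xla_model as xm

device = xm.xla_device()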

Also, do you have a reproducible snippet of what’s happening in your code so I can also take a look? Thanks!

Thanks @wonjoolee.

Setting XLA_DISABLE_FUNCTIONALIZATION=1 makes the tracing normal again, but it breaks the FSDP job with flatten_parameters=True. I’ll look into it.
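
For context, a rough sketch of the FSDP wrapping involved (the model is a placeholder; the real job wraps with the XLA FSDP class and flatten_parameters=True):

import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

device = xm.xla_device()
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 10)).to(device)

# flatten_parameters=True is the configuration that breaks once
# XLA_DISABLE_FUNCTIONALIZATION=1 is set; flatten_parameters=False might be
# worth testing as a workaround (assumption).
model = FSDP(model, flatten_parameters=True)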

Apply the patch below to reproduce the slow tracing issue. Here is the running command:
XLA_DISABLE_FUNCTIONALIZATION=0 PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc_per_node 1 test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1

diff --git a/test/test_train_mp_imagenet.py b/test/test_train_mp_imagenet.py
index 43c4c96..7af3bdc 100644
--- a/test/test_train_mp_imagenet.py
+++ b/test/test_train_mp_imagenet.py
@@ -266,10 +266,9 @@ def train_imagenet():
   writer = None
   if xm.is_master_ordinal():
     writer = test_utils.get_summary_writer(FLAGS.logdir)
-  optimizer = optim.SGD(
+  optimizer = optim.AdamW(
       model.parameters(),
       lr=FLAGS.lr,
-      momentum=FLAGS.momentum,
       weight_decay=1e-4)
   num_training_steps_per_epoch = train_dataset_len // (
       FLAGS.batch_size * xm.xrt_world_size())
@@ -289,6 +288,11 @@ def train_imagenet():
   def train_loop_fn(loader, epoch):
     tracker = xm.RateTracker()
     model.train()
+    prof = torch.profiler.profile(
+          activities=[torch.profiler.ProfilerActivity.CUDA, torch.profiler.ProfilerActivity.CPU],
+          schedule=torch.profiler.schedule(wait=2, warmup=2, active=3),
+          with_stack=True,
+          on_trace_ready=torch.profiler.tensorboard_trace_handler("./profile"))
     for step, (data, target) in enumerate(loader):
       with xp.StepTrace('train_imagenet'):
         with xp.Trace('build_graph'):
@@ -306,6 +310,8 @@ def train_imagenet():
         if step % FLAGS.log_steps == 0:
           xm.add_step_closure(
               _train_update, args=(device, step, loss, tracker, epoch, writer))
+      xm.mark_step()
+      prof.step()

   def test_loop_fn(loader, epoch):
     total_samples, correct = 0, 0
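
For reference, the step-based torch.profiler pattern the patch relies on looks roughly like this when written standalone (the explicit start()/stop() calls and the dummy loop are my additions):

import torch

prof = torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA],
    schedule=torch.profiler.schedule(wait=2, warmup=2, active=3),
    with_stack=True,
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./profile"))

prof.start()
for step in range(8):
    # ... one training step: forward, backward, optimizer.step(), xm.mark_step() ...
    prof.step()  # advances the wait/warmup/active schedule
prof.stop()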

Thanks @AngWang for the reproducible code. We’ll also look into it. And just to confirm, which versions of PyTorch and PyTorch/XLA were you using to run this?

This is the version info:

>>> torch_xla.version.__xla_gitrev__
'efa6fcfdac5368330a0770e9019649eba08b5f56'
>>> torch_xla.version.__torch_gitrev__
'f6dfbffb3bb46ada6fe66b5da4f989f9d4d69b3c'

Maybe this issue is related to the MaybeWrapTensorToFunctional function (xla/torch_xla/csrc/torch_util.cpp at 9f1afbd3bb1baab0049178da1e625a886ff27b30 · pytorch/xla · GitHub).

After the tensor is wrapped into a functional tensor, I think torch dispatches the op back to the Python level in torch::impl::dispatch::PythonKernelHolder, which is likely the cause of the slow tracing.
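
To narrow this down, a minimal sketch that times only the lazy tracing of optimizer.step() on an XLA device, so the two settings of XLA_DISABLE_FUNCTIONALIZATION can be compared directly (the parameter count and shapes are arbitrary assumptions):

import os
import time

# Toggle to "1" to disable functionalization; assumed to take effect only if
# set before torch_xla is imported.
os.environ.setdefault("XLA_DISABLE_FUNCTIONALIZATION", "0")

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
params = [torch.randn(1024, 1024, device=device, requires_grad=True) for _ in range(50)]
optimizer = torch.optim.AdamW(params, lr=1e-3, weight_decay=1e-4)

loss = sum((p * p).sum() for p in params)
loss.backward()
xm.mark_step()  # flush the forward/backward graph so only step() is measured

start = time.perf_counter()
optimizer.step()  # with lazy tensors this only traces; nothing executes yet
print(f"optimizer.step() tracing took {time.perf_counter() - start:.3f}s")
xm.mark_step()  # compile and execute outside the timed region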