Decomposition slows down the lazy tensor tracing

Thanks @wonjoolee.

By setting XLA_DISABLE_FUNCTIONALIZATION=1, the tracing becomes normal. But it will break the FSDP job with flatten_parameters=True. I’ll look into it.

Apply the patch below to reproduce the slow tracing issue. Here is the running command:
XLA_DISABLE_FUNCTIONALIZATION=0 PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc_per_node 1 test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1

diff --git a/test/test_train_mp_imagenet.py b/test/test_train_mp_imagenet.py
index 43c4c96..7af3bdc 100644
--- a/test/test_train_mp_imagenet.py
+++ b/test/test_train_mp_imagenet.py
@@ -266,10 +266,9 @@ def train_imagenet():
   writer = None
   if xm.is_master_ordinal():
     writer = test_utils.get_summary_writer(FLAGS.logdir)
-  optimizer = optim.SGD(
+  optimizer = optim.AdamW(
       model.parameters(),
       lr=FLAGS.lr,
-      momentum=FLAGS.momentum,
       weight_decay=1e-4)
   num_training_steps_per_epoch = train_dataset_len // (
       FLAGS.batch_size * xm.xrt_world_size())
@@ -289,6 +288,11 @@ def train_imagenet():
   def train_loop_fn(loader, epoch):
     tracker = xm.RateTracker()
     model.train()
+    prof = torch.profiler.profile(
+          activities=[torch.profiler.ProfilerActivity.CUDA, torch.profiler.ProfilerActivity.CPU],
+          schedule=torch.profiler.schedule(wait=2, warmup=2, active=3),
+          with_stack=True,
+          on_trace_ready=torch.profiler.tensorboard_trace_handler("./profile"))
     for step, (data, target) in enumerate(loader):
       with xp.StepTrace('train_imagenet'):
         with xp.Trace('build_graph'):
@@ -306,6 +310,8 @@ def train_imagenet():
         if step % FLAGS.log_steps == 0:
           xm.add_step_closure(
               _train_update, args=(device, step, loss, tracker, epoch, writer))
+      xm.mark_step()
+      prof.step()

   def test_loop_fn(loader, epoch):
     total_samples, correct = 0, 0