Thanks @wonjoolee. Setting `XLA_DISABLE_FUNCTIONALIZATION=1` brings tracing back to normal speed, but it breaks the FSDP job with `flatten_parameters=True`. I'll look into that.
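For context, a minimal single-process sketch of the kind of FSDP setup that breaks when functionalization is disabled. The model, sizes, and training step here are placeholders, not the actual job, and this sketch is not verified to reproduce the failure as-is:

```python
import os
# Assumption: the flag must be set before torch_xla is imported so it is
# picked up when the runtime initializes.
os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "1"

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

device = xm.xla_device()
# flatten_parameters=True is the configuration that breaks with the flag above.
model = FSDP(nn.Linear(1024, 1024).to(device), flatten_parameters=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

data = torch.randn(8, 1024, device=device)
loss = model(data).sum()
loss.backward()
optimizer.step()
xm.mark_step()
```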
Apply the patch below to reproduce the slow tracing issue. Here is the command I run:

```
XLA_DISABLE_FUNCTIONALIZATION=0 PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc_per_node 1 test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1
```
```diff
diff --git a/test/test_train_mp_imagenet.py b/test/test_train_mp_imagenet.py
index 43c4c96..7af3bdc 100644
--- a/test/test_train_mp_imagenet.py
+++ b/test/test_train_mp_imagenet.py
@@ -266,10 +266,9 @@ def train_imagenet():
   writer = None
   if xm.is_master_ordinal():
     writer = test_utils.get_summary_writer(FLAGS.logdir)
-  optimizer = optim.SGD(
+  optimizer = optim.AdamW(
       model.parameters(),
       lr=FLAGS.lr,
-      momentum=FLAGS.momentum,
       weight_decay=1e-4)
   num_training_steps_per_epoch = train_dataset_len // (
       FLAGS.batch_size * xm.xrt_world_size())
@@ -289,6 +288,11 @@ def train_imagenet():
   def train_loop_fn(loader, epoch):
     tracker = xm.RateTracker()
     model.train()
+    prof = torch.profiler.profile(
+        activities=[torch.profiler.ProfilerActivity.CUDA, torch.profiler.ProfilerActivity.CPU],
+        schedule=torch.profiler.schedule(wait=2, warmup=2, active=3),
+        with_stack=True,
+        on_trace_ready=torch.profiler.tensorboard_trace_handler("./profile"))
     for step, (data, target) in enumerate(loader):
       with xp.StepTrace('train_imagenet'):
         with xp.Trace('build_graph'):
@@ -306,6 +310,8 @@ def train_imagenet():
       if step % FLAGS.log_steps == 0:
         xm.add_step_closure(
             _train_update, args=(device, step, loss, tracker, epoch, writer))
+      xm.mark_step()
+      prof.step()
 
   def test_loop_fn(loader, epoch):
     total_samples, correct = 0, 0
```
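For reference, here is how the `torch.profiler` pieces in the patch fit together, as a self-contained CPU-only sketch. The tiny model and `train_one_step` below are placeholders; the real patch profiles the XLA ImageNet loop and also records CUDA activity. The schedule skips 2 steps, warms up for 2, then records 3, and `prof.step()` advances it once per training step; the standard usage enters the profiler as a context manager (or calls `prof.start()`/`prof.stop()`):

```python
import torch
import torch.nn as nn

# Placeholder model and step function, not the ImageNet script's code.
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

def train_one_step():
  data = torch.randn(32, 128)
  target = torch.randint(0, 10, (32,))
  optimizer.zero_grad()
  loss = loss_fn(model(data), target)
  loss.backward()
  optimizer.step()

# Same schedule and trace handler as the patch, restricted to CPU activity
# so the sketch runs anywhere.
with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU],
    schedule=torch.profiler.schedule(wait=2, warmup=2, active=3),
    with_stack=True,
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./profile")) as prof:
  for step in range(10):  # 10 steps covers wait(2) + warmup(2) + active(3)
    train_one_step()
    prof.step()  # advance the profiler schedule once per training step
```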