Impact of multithreading and local caching on torch.compile

Hi, sometimes I need to load models onto different GPUs and run them from different threads. Here is an example:

import torch
import time
import threading
from torchvision import models
import numpy as np
from torchvision.transforms.functional import center_crop, resize, to_tensor, normalize
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights

def test_mobilenet2():
    model1 = models.mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)
    model1.eval()
    model1.to("cuda:0")

    model2 = models.mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)
    model2.eval()
    model2.to("cuda:1")

    fake_image_np = np.random.rand(3, 256, 256).astype(np.float32)
    fake_image_tensor = torch.tensor(fake_image_np)
    fake_image_tensor = resize(fake_image_tensor, [224, 224])
    fake_image_tensor = fake_image_tensor.unsqueeze(0)
    fake_image_tensor = normalize(fake_image_tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    
    fake_image_tensor1 = fake_image_tensor.to("cuda:0")
    fake_image_tensor2 = fake_image_tensor.to("cuda:1")

    model1 = torch.compile(model1)
    model2 = torch.compile(model2)


    def infer1(image):
        with torch.no_grad(): 
            outputs = model1(image)
            _, predicted = torch.max(outputs, 1) 

    def infer2(image):
        with torch.no_grad(): 
            outputs = model2(image)
            _, predicted = torch.max(outputs, 1) 

    t1 = time.time()
    job1 = threading.Thread(target=infer1, args=(fake_image_tensor1,))
    job2 = threading.Thread(target=infer2, args=(fake_image_tensor2,))
    job1.start()
    job2.start()

    job1.join()
    job2.join()

    # infer1(fake_image_tensor1)
    print("cost", time.time() - t1)

if __name__ == "__main__":
    test_mobilenet2()

Compilation takes much longer when I use 2 threads:

thread num    time/s
1             17
2             27

When I use this in my actual project (a larger model), the compile time becomes extremely long, about an hour. In addition, some recompilations are triggered at runtime due to changes in the model.

I have two questions:

  1. How can I run two torch.compile tasks in parallel? Using separate processes is also difficult, because I would need to set the start method to spawn and refactor the pipeline (the pipeline is very complex, like ComfyUI).
  2. Is torch.compile thread-safe? Sometimes exceptions are thrown.

First exception:

[1]    179237 segmentation fault (core dumped)

Second exception:

Exception in thread Thread-3 (infer2):
Traceback (most recent call last):
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/root/workspace/pythonProject/test_compile/test_compile_cache.py", line 96, in infer2
    outputs = model2(image)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 786, in _convert_frame
    result = inner_convert(
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
    return _compile(
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 703, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
    transformations(instructions, code_options)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn
    return fn(*args, **kwargs)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 500, in transform
    tracer.run()
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2149, in run
    super().run()
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run
    and self.step()
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step
    getattr(self, inst.opname)(inst)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 489, in wrapper
    return inner_fn(self, inst)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1219, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 674, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 335, in call_function
    return super().call_function(tx, args, kwargs)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 289, in call_function
    return super().call_function(tx, args, kwargs)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 680, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2285, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2399, in inline_call_
    tracer.run()
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run
    and self.step()
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step
    getattr(self, inst.opname)(inst)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 489, in wrapper
    return inner_fn(self, inst)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1219, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 674, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 272, in call_function
    tx.call_function(
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 674, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 336, in call_function
    return tx.inline_user_function_return(
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 680, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2285, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2399, in inline_call_
    tracer.run()
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run
    and self.step()
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step
    getattr(self, inst.opname)(inst)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 489, in wrapper
    return inner_fn(self, inst)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1260, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 674, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 335, in call_function
    return super().call_function(tx, args, kwargs)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 289, in call_function
    return super().call_function(tx, args, kwargs)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 680, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2285, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2399, in inline_call_
    tracer.run()
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run
    and self.step()
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step
    getattr(self, inst.opname)(inst)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 489, in wrapper
    return inner_fn(self, inst)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1219, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 674, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 272, in call_function
    tx.call_function(
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 674, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 272, in call_function
    tx.call_function(
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 674, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 309, in call_function
    return wrap_fx_proxy(
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1330, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1471, in wrap_fx_proxy_cls
    return target_cls(proxy, **options)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 123, in __call__
    obj = type.__call__(cls, *args, **kwargs)
torch._dynamo.exc.InternalTorchDynamoError: TensorVariable.__init__() missing 8 required keyword-only arguments: 'dtype', 'device', 'layout', 'ndim', 'requires_grad', 'is_quantized', 'is_sparse', and 'class_type'

from user code:
   File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torchvision/models/mobilenetv2.py", line 174, in forward
    return self._forward_impl(x)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torchvision/models/mobilenetv2.py", line 166, in _forward_impl
    x = self.features(x)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torchvision/models/mobilenetv2.py", line 62, in forward
    return x + self.conv(x)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

cost 17.446489334106445
Exception ignored in atexit callback: <function dump_compile_times at 0x7f76b2d73250>
Traceback (most recent call last):
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 320, in dump_compile_times
    log.info(compile_times(repr="str", aggregate=True))
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 307, in compile_times
    out += tabulate(rows, headers=("Function", "Runtimes (s)"))
  File "/root/env/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 121, in tabulate
    import tabulate
  File "<frozen importlib._bootstrap>", line 1024, in _find_and_load
  File "<frozen importlib._bootstrap>", line 174, in __exit__
  File "<frozen importlib._bootstrap>", line 127, in release
AttributeError: '_ModuleLock' object has no attribute 'lock'

There is a lock around the compiler, so two threads can't be compiling at the same time; they each have to wait their turn. Running the already-compiled code, however, should happen in parallel.

That segfault/error looks like a bug that we should fix; I created an issue to track it here:

As a workaround, you could do a warmup run on the main thread so that both models are already compiled before you start the threads.
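A minimal sketch of that warmup, reusing the names from the snippet above: compile both models once on the main thread, then time only the threaded inference.

# Warm up both compiled models on the main thread first; this triggers the
# two compilations once, serially, so the threads only run compiled code.
with torch.no_grad():
    model1(fake_image_tensor1)   # compiles model1
    model2(fake_image_tensor2)   # compiles model2
torch.cuda.synchronize("cuda:0")
torch.cuda.synchronize("cuda:1")

t1 = time.time()
job1 = threading.Thread(target=infer1, args=(fake_image_tensor1,))
job2 = threading.Thread(target=infer2, args=(fake_image_tensor2,))
job1.start()
job2.start()
job1.join()
job2.join()
print("cost after warmup", time.time() - t1)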

Another possible workaround is to use multiprocessing, which is often faster than threading in Python because of the GIL.
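A rough sketch of the multiprocessing route (a simplified standalone script, not your pipeline): each worker process builds and compiles its own copy of the model on its own GPU, so the two compilations run in completely separate interpreters.

import torch
import torch.multiprocessing as mp
from torchvision import models
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights

def worker(device):
    # Each process creates, compiles, and runs its own model on its own GPU,
    # so compilation in one process cannot block the other.
    model = models.mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)
    model.eval().to(device)
    model = torch.compile(model)
    image = torch.rand(1, 3, 224, 224, device=device)
    with torch.no_grad():
        model(image)

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)  # spawn is required for CUDA in child processes
    procs = [mp.Process(target=worker, args=(d,)) for d in ("cuda:0", "cuda:1")]
    for p in procs:
        p.start()
    for p in procs:
        p.join()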


Is it possible to avoid the warmup process by saving compiled results to a cache dir?

export TORCHINDUCTOR_FX_GRAPH_CACHE=1
export TORCHINDUCTOR_CACHE_DIR=compile_cache_dir_$config

From my tests, it seems to greatly speed up the compilation. What is the reason for this?

Enabling that cache will speed up compilation a lot, but it should not affect thread safety.
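For completeness, a minimal sketch of enabling the same cache from inside the script rather than the shell; the assumption here is that setting the variables before importing torch has the same effect as the exports above, since Inductor picks them up when it initializes its config.

import os
# Assumed equivalent to the shell exports above: enable the FX graph cache
# and point it at a persistent directory so later runs can reuse compiled graphs.
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "compile_cache_dir"

import torch
# ... build, compile, and run the models as before; a second run of the
# script should hit the cache instead of recompiling from scratch.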