The first issue I’ve noticed is that there’s no support for in-place ops. For example, this graph, which contains no in-place ops:
graph(%0 : Long(10000, 13000, strides=[13000, 1], requires_grad=0, device=cuda:0),
%1 : Long(10000, 13000, strides=[13000, 1], requires_grad=0, device=cuda:0),
%2 : Int(10000, 13000, strides=[13000, 1], requires_grad=0, device=cuda:0)):
%11 : int = prim::Constant[value=4]()
%8 : int = prim::Constant[value=1]()
%6 : bool = prim::Constant[value=0]()
%5 : Device = prim::Constant[value="cuda:0"]()
%4 : NoneType = prim::Constant()
%3 : int = prim::Constant[value=3]()
%7 : Tensor = aten::_to_copy(%0, %3, %4, %5, %4, %6, %4)
%9 : Tensor = aten::add(%7, %1, %8)
%10 : Tensor = aten::mul(%9, %2)
%14 : Tensor = aten::to(%10, %11, %6, %6, %4)
return (%14)
Gets fused into the following CUDA kernel:
[DEBUG cuda_codegen.cpp:1032] extern "C" __global__
[DEBUG cuda_codegen.cpp:1032] void fused_add_mul(int* tv_, int* tv__, long long* tv___, long long* aten_mul) {
[DEBUG cuda_codegen.cpp:1032] {
[DEBUG cuda_codegen.cpp:1032] if ((long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)<130000000ll ? 1 : 0) {
[DEBUG cuda_codegen.cpp:1032] int tv__1 = __ldg(tv_ + (long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x));
[DEBUG cuda_codegen.cpp:1032] long long v = (long long)(__ldg(tv__ + (long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)));
[DEBUG cuda_codegen.cpp:1032] long long v_1 = __ldg(tv___ + (long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x));
[DEBUG cuda_codegen.cpp:1032] aten_mul[(long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)] = (long long)(tv__1) * v + (long long)(tv__1) * v_1;
[DEBUG cuda_codegen.cpp:1032] }}
[DEBUG cuda_codegen.cpp:1032] }
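For reference, here is a minimal sketch of the kind of scripted function I would expect to produce a graph like the one above (the function name, shapes, and random data are my own illustration, not the exact repro):

```python
import torch

@torch.jit.script
def repro(a, b, c):
    # aten::_to_copy to int32, aten::add, aten::mul, then aten::to back to int64
    return ((a.int() + b) * c).long()

# a, b: int64; c: int32 -- matching the Long/Long/Int inputs in the graph above
a = torch.randint(0, 100, (10000, 13000), dtype=torch.int64, device="cuda")
b = torch.randint(0, 100, (10000, 13000), dtype=torch.int64, device="cuda")
c = torch.randint(0, 100, (10000, 13000), dtype=torch.int32, device="cuda")

# A few warm-up runs so the profiling executor specializes and the fuser kicks in.
for _ in range(3):
    out = repro(a, b, c)
```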
But if you change the mul to mul_, no fusion happens. Note that changing add → add_ and mul → mul_ yields an equivalent program. (As a side question: why doesn’t TorchScript optimize non-in-place ops into in-place ones?)
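For concreteness, the mul → mul_ change I mean is the source-level difference below (again just a sketch on top of the hypothetical repro above):

```python
@torch.jit.script
def repro_inplace(a, b, c):
    tmp = a.int() + b   # out-of-place aten::add, as before
    tmp.mul_(c)         # in-place aten::mul_ -- this is the only change
    return tmp.long()
```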
I also wonder why it changes the expression (a + b) * c into (a * c) + (b * c); is that actually better?
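(My own reasoning, for context: the two forms are arithmetically identical for these integer tensors, since multiplication distributes over addition even with two’s-complement wraparound, but the distributed form costs one extra multiply per element.)

$$(a + b)\cdot c \;\equiv\; a\cdot c + b\cdot c \pmod{2^{64}}$$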
I’ve also noticed that sometimes similar programs don’t get fully fused: only the mul + to are fused, and I haven’t yet managed to figure out why the add is sometimes left out.
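In case it helps with reproducing/debugging, this is roughly how I check which ops land in the fusion group, building on the hypothetical repro sketch above (assuming the TensorExpr/NNC fuser; API and node names are to the best of my knowledge):

```python
# After warm-up, dump the optimized graph for these inputs; fused ops show up
# inside a prim::TensorExprGroup node.
print(repro.graph_for(a, b, c))

# The generated CUDA kernel itself can be dumped via the JIT logging env var,
# which is presumably where the cuda_codegen.cpp DEBUG lines above come from:
#   PYTORCH_JIT_LOG_LEVEL=">>cuda_codegen" python repro.py
```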