PyTorch Runtime Error with Compiled Autograd

I’m encountering a runtime error in PyTorch while running the backward pass of my model with compiled autograd enabled and world_size == 2. The error appears to be caused by indexing a 0-dimensional tensor. Below is a summary of the issue:

Example of Enabling Compiled Autograd:

import functools

import torch
from torch._dynamo import compiled_autograd


def enable_compiled_autograd(**kwargs):
    def compiler_fn(gm):
        return torch.compile(gm, backend="my_backend", options={"inference": False}, **kwargs)

    # Register compiled autograd so the backward graph is compiled with the custom backend.
    torch._C._dynamo.compiled_autograd.set_autograd_compiler(
        functools.partial(compiled_autograd.AutogradCompilerInstance, compiler_fn)
    )

    torch._dynamo.reset()
    # Replace DDP's C++ reducer with Python-level hooks that compiled autograd can trace.
    torch._dynamo.config.optimize_ddp = "python_reducer"

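For reference, this is roughly how I drive it in the training script (a minimal sketch, not my real setup; the model, device, and data below are placeholders):

import torch
import torch.nn as nn

enable_compiled_autograd()

device = "my_device:0"  # placeholder for the real accelerator
model = torch.compile(nn.Linear(16, 4).to(device), backend="my_backend")

x = torch.randn(8, 16, device=device)
loss = model(x).sum()
loss.backward()  # the backward graph is captured and compiled by compiled autograd
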
Error Details:

TorchRuntimeError: Failed running call_function <built-in function getitem>(*(Parameter(FakeTensor(..., device='my_device:0', size=(), requires_grad=True)), 0), **{}):
invalid index of a 0-dim tensor. Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number.

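The message itself is the standard eager-mode error for indexing a 0-dim tensor; a minimal reproduction outside of compile:

import torch

t = torch.tensor(1.5)  # 0-dim, like the Parameter(FakeTensor(..., size=())) above
print(t.item())        # 1.5 -- the conversion the error message suggests
print(t[0])            # raises: invalid index of a 0-dim tensor
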
Stack Trace:

File "torch/_tensor.py", line 535, in backward
    torch.autograd.backward(
File "torch/autograd/__init__.py", line 267, in backward
    _engine_run_backward(
File "torch/autograd/graph.py", line 744, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
File "torch/nn/modules/module.py", line 1535, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
File "torch/nn/modules/module.py", line 1544, in _call_impl
    return forward_call(*args, **kwargs)
File "torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
File "torch/fx/graph_module.py", line 737, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
File "torch/fx/graph_module.py", line 317, in __call__
    raise e
File "torch/fx/graph_module.py", line 304, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
File "torch/nn/modules/module.py", line 1535, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
File "torch/nn/modules/module.py", line 1544, in _call_impl
    return forward_call(*args, **kwargs)
File "torch/_dynamo/convert_frame.py", line 921, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
File "torch/_dynamo/convert_frame.py", line 786, in _convert_frame
    result = inner_convert(
File "torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
    return _compile(
File "contextlib.py", line 79, in inner
    return func(*args, **kwds)
File "torch/_dynamo/convert_frame.py", line 676, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
File "torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
File "torch/_dynamo/convert_frame.py", line 535, in compile_inner
    out_code = transform_code_object(code, transform)
File "torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
    transformations(instructions, code_options)
File "torch/_dynamo/convert_frame.py", line 165, in _fn
    return fn(*args, **kwargs)
File "torch/_dynamo/convert_frame.py", line 500, in transform
    tracer.run()
File "torch/_dynamo/symbolic_convert.py", line 2099, in run
    super().run()
File "torch/_dynamo/symbolic_convert.py", line 815, in run
    and self.step()
File "torch/_dynamo/symbolic_convert.py", line 778, in step
    getattr(self, inst.opname)(inst)
File "torch/_dynamo/symbolic_convert.py", line 494, in wrapper
    return inner_fn(self, inst)
File "torch/_dynamo/symbolic_convert.py", line 247, in impl
    self.push(fn_var.call_function(self, self.popn(nargs), {}))
File "torch/_dynamo/variables/builtin.py", line 935, in call_function
    return handler(tx, args, kwargs)
File "torch/_dynamo/variables/builtin.py", line 914, in _handle_insert_op_in_graph
    return wrap_fx_proxy(tx, proxy)
File "torch/_dynamo/variables/builder.py", line 1330, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
File "torch/_dynamo/variables/builder.py", line 1415, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
File "torch/_dynamo/utils.py", line 1722, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
File "torch/_dynamo/utils.py", line 1664, in get_fake_value
    ret_val = wrap_fake_exception(
File "torch/_dynamo/utils.py", line 1198, in wrap_fake_exception
    return fn()
File "torch/_dynamo/utils.py", line 1665, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
File "torch/_dynamo/utils.py", line 1790, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
File "torch/_dynamo/utils.py", line 1772, in run_node
    return node.target(*args, **kwargs)

Compiled Autograd Logs:
With world_size == 2 (fails on the scalar parameter):

.... 
[__compiled_autograd]         getitem_2004 = hooks[397]
[__compiled_autograd]         call_hook_396 = torch__dynamo_external_utils_call_hook(getitem_2004, getitem_2);  getitem_2004 = getitem_2 = None
[__compiled_autograd]         getitem_2005: "f32[768]" = call_hook_396[0];  call_hook_396 = getitem_2005 = None
[__compiled_autograd]         accumulate_grad__397 = torch.ops.inductor.accumulate_grad_.default(getitem_1052, embedding_dense_backward_3);  embedding_dense_backward_3 = accumulate_grad__397 = None
[__compiled_autograd]         getitem_2006 = hooks[398]
[__compiled_autograd]         call_hook_397 = torch__dynamo_external_utils_call_hook(getitem_2006, getitem_1052);  getitem_2006 = getitem_1052 = None
[__compiled_autograd]         getitem_2007: "f32[50, 768]" = call_hook_397[0];  call_hook_397 = getitem_2007 = None
[__compiled_autograd]         accumulate_grad__398 = torch.ops.inductor.accumulate_grad_.default(getitem_1053, _to_copy_521);  _to_copy_521 = accumulate_grad__398 = None
[__compiled_autograd]         getitem_2008 = hooks[399]
[__compiled_autograd]         call_hook_398 = torch__dynamo_external_utils_call_hook(getitem_2008, getitem_1053);  getitem_2008 = getitem_1053 = None
[__compiled_autograd]         getitem_2009: "f32[768, 3, 32, 32]" = call_hook_398[0];  call_hook_398 = getitem_2009 = None
[__compiled_autograd]         accumulate_grad__399 = torch.ops.inductor.accumulate_grad_.default(getitem_1054, _to_copy_7);  _to_copy_7 = accumulate_grad__399 = None
[__compiled_autograd]         getitem_2010 = hooks[400]
[__compiled_autograd]         call_hook_399 = torch__dynamo_external_utils_call_hook(getitem_2010, getitem_1054);  getitem_2010 = getitem_1054 = None
[__compiled_autograd]         getitem_2011: "f32[]" = call_hook_399[0];  call_hook_399 = getitem_2011 = None
[__compiled_autograd]         accumulate_grad__400 = torch.ops.inductor.accumulate_grad_.default(getitem_1055, view_553);  view_553 = accumulate_grad__400 = None
[__compiled_autograd]         getitem_2012 = hooks[401];  hooks = None
[__compiled_autograd]         call_hook_400 = torch__dynamo_external_utils_call_hook(getitem_2012, getitem_1055);  getitem_2012 = getitem_1055 = None
[__compiled_autograd]         getitem_2013: "f32[768]" = call_hook_400[0];  call_hook_400 = getitem_2013 = None
[__compiled_autograd]         return []

With world_size == 1 (no failure on the scalar parameter; everything works fine):

.... 
[__compiled_autograd]         accumulate_grad__393 = torch.ops.inductor.accumulate_grad_.default(getitem_5, _to_copy_516);  getitem_5 = _to_copy_516 = accumulate_grad__393 = None
[__compiled_autograd]         accumulate_grad__394 = torch.ops.inductor.accumulate_grad_.default(getitem_4, _to_copy_515);  getitem_4 = _to_copy_515 = accumulate_grad__394 = None
[__compiled_autograd]         accumulate_grad__395 = torch.ops.inductor.accumulate_grad_.default(getitem_3, _to_copy_518);  getitem_3 = _to_copy_518 = accumulate_grad__395 = None
[__compiled_autograd]         accumulate_grad__396 = torch.ops.inductor.accumulate_grad_.default(getitem_2, _to_copy_517);  getitem_2 = _to_copy_517 = accumulate_grad__396 = None
[__compiled_autograd]         accumulate_grad__397 = torch.ops.inductor.accumulate_grad_.default(getitem_1052, embedding_dense_backward_3);  getitem_1052 = embedding_dense_backward_3 = accumulate_grad__397 = None
[__compiled_autograd]         accumulate_grad__398 = torch.ops.inductor.accumulate_grad_.default(getitem_1053, _to_copy_521);  getitem_1053 = _to_copy_521 = accumulate_grad__398 = None
[__compiled_autograd]         accumulate_grad__399 = torch.ops.inductor.accumulate_grad_.default(getitem_1054, _to_copy_7);  getitem_1054 = _to_copy_7 = accumulate_grad__399 = None
[__compiled_autograd]         accumulate_grad__400 = torch.ops.inductor.accumulate_grad_.default(getitem_1055, view_553);  getitem_1055 = view_553 = accumulate_grad__400 = None
[__compiled_autograd]         return []

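To narrow this down on my side, I first want to confirm which parameter is the 0-dim one, since only a single f32[] entry shows up in the failing graph (model below stands for my actual module):

# List every 0-dim parameter that would go through the per-parameter hook path.
for name, p in model.named_parameters():
    if p.dim() == 0:
        print(name, tuple(p.shape), p.requires_grad)
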
Are there any additional logs or debugging tips that could help me understand why getitem is being applied to a scalar tensor, and/or how to isolate the problem?

Hi, I would double-check how you are wrapping your DDP model with torch.compile:
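For example, the two common orderings give Dynamo different things to trace (rough sketch only; the model, device, and process-group setup are placeholders, not your code, and assume init_process_group has already been called):

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

device = "cuda:0"                      # placeholder
model = nn.Linear(16, 4).to(device)    # placeholder

ddp_model = torch.compile(DDP(model))    # option A: compile the DDP-wrapped module
# ddp_model = DDP(torch.compile(model))  # option B: wrap an already-compiled module

With torch._dynamo.config.optimize_ddp = "python_reducer", which of these you use changes what ends up inside the traced region, so that is the first thing I would confirm.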

Happy to take a look if you have a repro: file a GitHub issue and tag me @xmfan.
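In the meantime, the logging artifacts usually give more detail on where the getitem comes from (exact artifact names can vary between PyTorch releases):

import torch

# Roughly equivalent to running with TORCH_LOGS="compiled_autograd,graph_code":
# prints the captured backward graph and the traced graphs as readable Python code.
torch._logging.set_logs(compiled_autograd=True, graph_code=True)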
