Conditional Decompositions

Hi,

I am trying to implement a conditional decomposition for the torch.ops.aten.mul operator.

In most cases the operator should stay unchanged, i.e. remain aten.mul.
However, in certain cases I want to decompose the mul into other ops that my flow supports.

I tried to implement this as a custom decomposition, but I cannot find a way to leave the operator untouched when my condition does not apply:

```python
from typing import Union

import torch
from torch._decomp import register_decomposition
from torch._prims_common.wrappers import out_wrapper

aten = torch.ops.aten  # needed for aten.mul below

TensorOrScalar = Union[torch.Tensor, bool, int, float, complex]


def my_condition(a, b) -> bool:  # hypothetical placeholder for my real check
    ...


def my_custom_decomposition(a, b) -> torch.Tensor:  # hypothetical placeholder
    ...


@register_decomposition(aten_op=aten.mul)
@out_wrapper()
def replace_mul(input_a: TensorOrScalar, input_b: TensorOrScalar) -> torch.Tensor:
    if my_condition(input_a, input_b):
        return my_custom_decomposition(input_a, input_b)

    # Otherwise we just want to keep the standard mul operator.
    ...
    # return None                      # fails with an AssertionError
    return torch.mul(input_a, input_b)  # recurses endlessly into this decomposition
```

I tried returning torch.mul(input_a, input_b), but that dispatches to aten.mul again, which calls the decomposition again, so it recurses endlessly.
I also tried returning None, but this fails with an AssertionError.
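To make the two failure modes concrete, here is a toy model in plain Python. It is not PyTorch's dispatcher: the table, the op name "mul", the hypothetical condition (b == 2 rewritten as a + a), and the use of NotImplemented as a "leave the op untouched" signal are all assumptions made for illustration. It shows why re-calling the op from inside its own decomposition loops forever, and what a fall-through signal would have to look like. Whether PyTorch's real decomposition path honors such a sentinel (some versions of make_fx/Inductor appear to check for NotImplemented) should be verified against the installed version.

```python
# Toy model only: NOT PyTorch's dispatcher. The op name, the table, and the
# NotImplemented convention are assumptions made for illustration.

NATIVE_KERNELS = {"mul": lambda a, b: a * b}  # stands in for the built-in kernel
DECOMP_TABLE = {}  # op name -> Python decomposition


def dispatch(op, *args):
    """Run the registered decomposition if any, else the native kernel."""
    decomp = DECOMP_TABLE.get(op)
    if decomp is not None:
        out = decomp(*args)
        # In this toy, NotImplemented means "leave the op untouched".
        if out is not NotImplemented:
            return out
    return NATIVE_KERNELS[op](*args)


def mul_reenters(a, b):
    # Hypothetical condition: rewrite x * 2 as x + x.
    if b == 2:
        return a + a
    # Calling the op again goes back through dispatch(), which calls this
    # function again, and so on until Python's recursion limit is hit.
    return dispatch("mul", a, b)


def mul_falls_through(a, b):
    if b == 2:
        return a + a
    # Signal "no decomposition for these inputs"; dispatch() runs the native kernel.
    return NotImplemented


def try_mul(decomp, a, b):
    """Register `decomp` for "mul" and run it, reporting a RecursionError by name."""
    DECOMP_TABLE["mul"] = decomp
    try:
        return dispatch("mul", a, b)
    except RecursionError:
        return "RecursionError"


if __name__ == "__main__":
    print(try_mul(mul_reenters, 5, 2))       # condition applies: 10
    print(try_mul(mul_reenters, 5, 3))       # re-entry: RecursionError
    print(try_mul(mul_falls_through, 5, 3))  # falls through to native: 15
```

The only difference between the two decompositions is the final return: re-dispatching the same op re-enters the table, while a dedicated sentinel lets the caller run the original kernel instead.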

Is there a mechanism for conditional decompositions, or an alternative approach that solves this issue?