jansel
March 8, 2023, 11:14pm
4
You can call this function to register a backend:
class CompiledFn(Protocol):
    """Callable produced by a backend compiler.

    Accepts the compiled graph's flat tensor inputs and returns its
    tensor outputs as a tuple.
    """

    # Structural (duck-typed) interface only: any callable with this
    # signature satisfies the protocol; no registration required.
    def __call__(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        ...
# A backend compiler: given an FX graph module and its example tensor
# inputs, it returns a compiled callable (see CompiledFn).
CompilerFn = Callable[[fx.GraphModule, List[torch.Tensor]], CompiledFn]

# Registry of named backends, keyed by the string accepted as
# torch.compile(..., backend="name").
_BACKENDS: Dict[str, CompilerFn] = {}
def register_backend(
compiler_fn: Optional[CompilerFn] = None,
name: Optional[str] = None,
tags: Sequence[str] = (),
):
"""
Decorator to add a given compiler to the registry to allow calling
`torch.compile` with string shorthand. Note: for projects not
imported by default, it might be easier to pass a function directly
as a backend and not use a string.
This allows you to use a string alias for the backend=... arg to torch.compile(..., backend="some_name"). You can also just pass your backend directly to torch.compile(..., backend=my_backend); there is no need to register it.
To get the aten/prim graph use:
import logging

import torch
from torch._dynamo import eval_frame
from torch._dynamo.utils import counters
from torch._functorch.aot_autograd import aot_module_simplified
from torch._subclasses import FakeTensor
from torch.utils._python_dispatch import _disable_current_modes
log = logging.getLogger(__name__)
def aot_autograd(**kwargs):
def compiler_fn(gm: torch.fx.GraphModule, example_inputs):
import functorch.compile
# Hack to get around circular import problems with aot_eager_decomp_partition
if callable(kwargs.get("decompositions")):
kwargs["decompositions"] = kwargs["decompositions"]()
# TODO: stop monkeypatching here (without even cleaning up, UGH!)
functorch.compile.config.use_functionalize = True
functorch.compile.config.use_fake_tensor = True
1 Like