Take `BatchNorm2d` as an example. My backend is this:
```python
import unittest.mock
from typing import Any, Sequence

import torch
# NOTE: private APIs; import paths can shift between PyTorch versions.
from torch._dynamo.utils import detect_fake_mode
from torch._functorch.aot_autograd import aot_export_joint_simple

# apply_pre_aot_passes, apply_post_aot_passes, get_decompositions,
# convert_module, options and logger are project-local helpers.


def xtorch_compile_backend(
    gm: torch.fx.GraphModule, example_inputs: Sequence[Any], **kwargs
) -> torch.nn.Module:
    try:
        fake_mode = detect_fake_mode(example_inputs)
        # Trace under the detected fake mode while allowing real inputs.
        with unittest.mock.patch.object(
            fake_mode, "allow_non_fake_inputs", True
        ), fake_mode:
            gm = apply_pre_aot_passes(gm)
            logger.debug(f"After pre-AOT passes: \n{gm.graph}")
            # Lower to ATen IR; `decompositions` picks which ops get split up.
            gm = aot_export_joint_simple(
                gm,
                example_inputs,
                trace_joint=False,
                decompositions=get_decompositions(),
            )
            logger.debug(f"After aot export: \n{gm.graph}")
            gm = apply_post_aot_passes(gm)
            logger.debug(f"After post-AOT and remove_sym_nodes passes: \n{gm.graph}")
            converted_module = convert_module(gm, example_inputs, options)
            return converted_module
    except Exception:
        # (except body truncated in the original snippet)
        raise
```
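For reference, a minimal way to exercise the backend (the toy model here is purely for illustration):

```python
import torch

# Toy model containing the BatchNorm2d that triggers the lowering below.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.BatchNorm2d(8),
).eval()

compiled = torch.compile(model, backend=xtorch_compile_backend)
out = compiled(torch.randn(1, 3, 32, 32))
```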
I need to convert placeholders to `get_attr` nodes, which is why I use `aot_export_joint_simple`.
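To see how graph inputs surface after export, I just walk the node list with plain FX APIs (`gm` here is the module returned by `aot_export_joint_simple`):

```python
# placeholders are runtime inputs; get_attr nodes are parameters/buffers
# still attached to the module.
for node in gm.graph.nodes:
    if node.op in ("placeholder", "get_attr"):
        print(node.op, node.target)
```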
`torch.nn.functional.batch_norm` gets converted to `torch.ops.aten._native_batch_norm_legit.no_stats`, and the same happens for many other ops. How can I prevent this?
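The only knob I can see is the `decompositions` table passed to `aot_export_joint_simple` (a dict mapping ATen overloads to Python decompositions). Below is a sketch of dropping the batch-norm entries from the default core-ATen table, assuming `get_decompositions()` builds on `torch._decomp`, and assuming the rewrite actually comes from this table rather than from a CompositeImplicitAutograd kernel (which decomposes during tracing regardless of the table):

```python
import torch
from torch._decomp import core_aten_decompositions

aten = torch.ops.aten

def get_decompositions():
    # Hypothetical variant of my real get_decompositions(): copy the default
    # core-ATen table, then drop the batch-norm entries so those ops are not
    # decomposed further. The exact overload list may vary per torch version.
    table = dict(core_aten_decompositions())
    for op in (
        aten.native_batch_norm.default,
        aten._native_batch_norm_legit.default,
        aten._native_batch_norm_legit.no_stats,
        aten._native_batch_norm_legit_no_training.default,
    ):
        table.pop(op, None)
    return table
```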