[RFC] Adding a Batch Norm Backend: Revisiting Dispatch Stack Issues

Summary

PyTorch’s batch norm dispatch has two parallel stacks: the old stack (_batch_norm_impl_index) that is active today, and a new stack (_batch_norm_with_update / _batch_norm_no_update) that was added in #116092 (March 2024) but never activated. The new stack is dead code. Its entry point in Normalization.cpp has been commented out for two years with a TODO that reads “switch to the new stack after the 2 week FC window.”

This post is a request for discussion on whether and how to complete this migration. We (AMD ROCm team) are currently integrating a new batch norm backend (hipDNN) and have run directly into the architectural consequences of the half-migrated state. We’d like to help move this forward, and we think the benefits extend well beyond our use case.

Background: The Two Stacks

Old stack (active):

batch_norm()
  → _batch_norm_impl_index()          # CompositeImplicitAutograd, in libtorch_cpu.so
    → _select_batch_norm_backend()    # returns Cudnn | Miopen | Native
    → at::cudnn_batch_norm(...)       # CUDA dispatch → cudnn lib
    or at::miopen_batch_norm(...)     # CUDA dispatch → miopen lib
    or at::native_batch_norm(...)     # CUDA dispatch → native kernels

Each backend has its own op in native_functions.yaml, its own autograd formula in derivatives.yaml, its own decomposition in decompositions.py, and its own vmap batching rule in BatchRulesNorm.cpp. Adding a new backend requires touching all four.
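For concreteness, the per-backend surface area looks roughly like this (a sketch with abbreviated signatures, not the exact schemas; the real entries carry full argument lists):

```yaml
# native_functions.yaml: one entry per backend (signature abbreviated)
- func: miopen_batch_norm(...)
  dispatch:
    CUDA: miopen_batch_norm

# derivatives.yaml: one backward formula per backend (abbreviated)
- name: miopen_batch_norm(...)
  input, weight, bias: miopen_batch_norm_backward(...)
```

Plus a decomposition in decompositions.py and a batching rule in BatchRulesNorm.cpp, and the whole set is repeated for cudnn and native.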

New stack (dead code):

batch_norm()
  → _batch_norm_with_update()         # dispatches to CUDA
    → _batch_norm_with_update_cuda()  # calls _select_batch_norm_backend() internally
      → cudnn / miopen / native       # backend selection happens GPU-side

The key difference: backend selection moves inside the CUDA dispatch. The frontend sees one op (_batch_norm_with_update), not one-per-backend. Autograd, decompositions, and vmap only need to handle the single op. A new backend can be added by modifying _batch_norm_with_update_cuda(). No yaml changes, no autograd formulas, no decompositions, no vmap rules.
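The structural difference can be modeled in a few lines of plain Python. This is a toy model, not PyTorch code: the op names mirror the ops above, but everything else is simplified.

```python
# Toy model of the two dispatch shapes. Not PyTorch code.

# Old stack: the frontend picks a backend, and each backend is a
# separate op, so autograd / vmap / decomps must know about all of them.
OLD_STACK_OPS = {"cudnn_batch_norm", "miopen_batch_norm", "native_batch_norm"}

def old_stack_dispatch(backend: str) -> str:
    # _batch_norm_impl_index(): selection happens on the frontend,
    # and the chosen backend leaks into the op name.
    return f"{backend}_batch_norm"

# New stack: the frontend sees exactly one op; the backend choice is
# an implementation detail inside that op's CUDA kernel.
def new_stack_dispatch(backend: str) -> str:
    # _batch_norm_with_update_cuda() calls _select_batch_norm_backend()
    # internally; the frontend-visible op never changes.
    assert backend in {"cudnn", "miopen", "native"}
    return "_batch_norm_with_update"
```

Adding a backend in the toy model changes only the body of `new_stack_dispatch`, while the old model grows a new frontend-visible op, which is exactly the boilerplate described above.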

Why This Matters Now

We are in the process of adding hipDNN batch norm support to PyTorch ([ROCm] Add hipDNN backend support for batch norm, #177534). Because the new stack is not active, we had to follow the old-stack pattern: register hipdnn_batch_norm / hipdnn_batch_norm_backward as dispatch ops, add autograd formulas, add decompositions, add vmap batching rules, and add FC allowlist entries.

This works, and we’re shipping it. But it highlights a scaling problem: every new backend library that wants to integrate batch norm (or any operation with this pattern) must add roughly seven files’ worth of frontend boilerplate. The new stack was designed to eliminate exactly this boilerplate.

What’s Blocking the Migration

Based on our analysis of the code and history, three things need to happen before the switch can be flipped:

1. The eval path needs backend selection

_batch_norm_no_update is registered as CompositeExplicitAutograd:

- func: _batch_norm_no_update(...)
  dispatch:
    CompositeExplicitAutograd: _batch_norm_no_update

Its implementation just calls at::native_batch_norm(...) directly, with no backend selection at all. If the switch were flipped today, cuDNN and MIOpen would stop being used in eval mode. This is a performance regression that would affect every user.

Fix: Add a CUDA dispatch key to _batch_norm_no_update with a _batch_norm_no_update_cuda implementation that calls _select_batch_norm_backend(), matching the pattern already used by _batch_norm_with_update_cuda.
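Sketched as a native_functions.yaml change (hypothetical; the signature is abbreviated and the registration details are illustrative):

```yaml
- func: _batch_norm_no_update(...)
  dispatch:
    CompositeExplicitAutograd: _batch_norm_no_update
    CUDA: _batch_norm_no_update_cuda   # new: runs _select_batch_norm_backend()
```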

2. The cudnn_enabled parameter is not threaded through

_batch_norm_impl_index accepts a cudnn_enabled parameter (passed down from F.batch_norm via torch.batch_norm). In practice, _select_batch_norm_backend reads this flag from globalContext() rather than from the parameter; the parameter is vestigial, and there’s a TODO in Normalization.cpp to remove it. The new stack ops don’t accept the parameter at all. This is probably fine since the global context is authoritative, but we should verify that dropping the parameter doesn’t change behavior in any edge case.

3. Two years of untested interaction

The new stack was added to the codebase in March 2024. Since then:

  • MIOpen batch norm has gained BF16 NCHW mixed precision (#154611), channels-last 3D support (#160529), and output format fixes (#162112), all on the old stack only.
  • The new stack’s _batch_norm_with_update_cuda still references _select_batch_norm_backend, but any conditions added to that function since March 2024 have never been tested through the new stack’s code path.
  • No tests exercise the new stack through batch_norm() (tests that exist call the new-stack ops directly, which doesn’t validate the end-to-end path).

Fix: Run the full batch norm test suite with the switch flipped and fix any failures. We’re willing to help with this on the ROCm side.

Proposal

We’d like to understand whether upstream is interested in completing this migration. Specifically:

  1. Is the migration still desired? The original PR’s intent was clear, but priorities may have changed. If the answer is “no, we’re keeping the old stack,” we’ll continue with per-backend dispatch ops and adapt accordingly.

  2. If yes, what’s the activation plan? We would propose:

    • Fix _batch_norm_no_update to add CUDA dispatch with backend selection
    • Verify cudnn_enabled handling
    • Run comprehensive tests with the switch flipped (we can contribute ROCm-side testing)
    • Land behind a runtime flag first (e.g., an environment variable) for gradual rollout
    • Remove the old stack after the BC window

  3. Can we help? We have looked closely at the current dispatch architecture and are actively working in this code. We’re happy to contribute patches for the eval-path fix and testing.
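As a sketch of the rollout gate in step 2, the switch could initially be guarded by an environment variable. This is a toy Python model of the idea; the variable name is hypothetical, not an existing PyTorch flag, and the real gate would live in the C++ entry point.

```python
import os

# Hypothetical gate: "TORCH_USE_NEW_BATCH_NORM_STACK" is illustrative,
# not an existing PyTorch flag. Default stays on the old stack.
def use_new_batch_norm_stack() -> bool:
    return os.environ.get("TORCH_USE_NEW_BATCH_NORM_STACK", "0") == "1"

def batch_norm_entry() -> str:
    # Route to the new stack only when explicitly opted in, so the
    # rollout can be reverted without a rebuild.
    if use_new_batch_norm_stack():
        return "_batch_norm_with_update"
    return "_batch_norm_impl_index"
```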

Why This Benefits Everyone

The current architecture has costs beyond backend integration:

  • Dead code maintenance: The new stack ops, their decompositions, their backward formulas, and their test coverage all exist in the codebase but are never exercised through production paths. This is maintenance burden with no benefit until the switch is flipped.

  • Fragile dispatch: The old stack encodes backend decisions as opaque integers (impl_index) in the autograd tape. The new stack uses proper dispatch keys. The integer encoding is a JIT-era design that predates modern dispatch.
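The fragility of the integer encoding can be seen in miniature (a toy illustration, not the actual PyTorch encoding):

```python
# Toy illustration of why an opaque impl_index is fragile: the meaning
# of each integer is a hidden convention shared between forward and
# backward, not something the dispatcher can check or see.
BACKENDS = ["native", "cudnn", "miopen"]  # the order is load-bearing!

def encode(backend: str) -> int:
    # Forward stashes this integer in the autograd tape.
    return BACKENDS.index(backend)

def decode(impl_index: int) -> str:
    # Backward must reverse it with the exact same table; reordering
    # or inserting a backend silently changes the meaning of old indices.
    return BACKENDS[impl_index]
```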

Context

  • Open stalled PR for migration: #119496
  • Original consolidation PR: #116092
  • Numerics divergence issue: #111384
  • hipDNN integration PR: #177534
  • hipDNN (ROCm libraries): rocm-libraries/projects/hipdnn

There is one more important reason this has probably stalled, which I should have mentioned above:

Optional inplace tensors (track_running_stats=False)

Batch norm’s running_mean and running_var are optional (they are None when track_running_stats=False) but are also mutated in place during training when present. The dispatcher schema cannot express Tensor?(a!), i.e. a tensor that is both optional and mutated in place. You must choose one:

  • Tensor? (optional, no mutation annotation): the dispatcher doesn’t track the mutation, breaking functionalization, FX tracing, and any transform that relies on schema accuracy. This is the known-broken native_batch_norm schema.

  • Tensor(a!) (non-optional, mutation tracked): correct for mutation tracking, but the tensor cannot be None.

The new stack chose Tensor(a!) for _batch_norm_with_update:


- func: _batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias,
        Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps)
        -> (Tensor, Tensor, Tensor, Tensor)

The intent was for _batch_norm_no_update to handle cases without running stats. But the commented-out switch code in Normalization.cpp doesn’t do this correctly. It routes all training cases to _batch_norm_with_update, including track_running_stats=False where running stats are undefined:


// if (training) {
//   …
//   return std::get<0>(at::_batch_norm_with_update(
//       input, weight, bias,
//       const_cast<Tensor&>(running_mean),  // may be undefined!
//       const_cast<Tensor&>(running_var),   // may be undefined!
//       momentum, eps));
// }

The const_cast<Tensor&> passes an undefined tensor as a non-optional Tensor(a!) argument. This is the same fundamental issue that was recognized when creating the _native_batch_norm_legit family, which solved it by splitting into three ops: _native_batch_norm_legit (training with stats), _native_batch_norm_legit.no_stats (training without stats), and _native_batch_norm_legit_no_training (eval). The new stack only has two ops (_with_update and _no_update) and no .no_stats variant.

We’d need to close this gap when completing the migration. I’d suggest either adding a _batch_norm_with_update.no_stats variant (matching the _legit family’s approach) or fixing the switch logic to route training-without-stats through _batch_norm_no_update.
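If we go the new-op route, the schema could look something like this (hypothetical; modeled on the existing _native_batch_norm_legit.no_stats overload, with an illustrative kernel name):

```yaml
# Hypothetical: a stats-free overload of the training op
- func: _batch_norm_with_update.no_stats(Tensor input, Tensor? weight, Tensor? bias,
        float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)
  dispatch:
    CUDA: _batch_norm_with_update_no_stats_cuda
```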