Detect shared weights across subgraphs in torch.compile

I have a model that shares the same weight tensor across multiple layers. After compiling the model, it’s split into three subgraphs, and that shared weight appears in both the first and third subgraphs. In my custom backend, I need to:

  1. Identify which layers (or ops) reference the same weight tensor.
  2. Retrieve the raw data pointer of that shared weight.

How can I find, at backend compilation time, which subgraph nodes use the same weight tensor, and then extract its underlying data pointer?

Below is an example of a model with shared weights:

import torch

def custom_backend(graph_module, example_inputs):
    print("We are in custom backend")
    
    # Use the inductor backend for actual compilation
    from torch._inductor.compile_fx import compile_fx
    return compile_fx(graph_module, example_inputs)

class Custom_Model_Shared_Weights(torch.nn.Module):
    def __init__(self, vocab=64, d_model=16):
        super().__init__()
        # one Parameter object shared by both layers
        shared = torch.nn.Parameter(torch.randn(vocab, d_model).to(torch.bfloat16))

        self.embed = torch.nn.Embedding(vocab, d_model, _weight=shared)
        self.proj = torch.nn.Linear(d_model, vocab, bias=False)
        self.proj.weight = shared  # <- tie

    def forward(self, tokens):
        x = self.embed(tokens)
        hid = x.mean(dim=1)
        # ── graph-break op ──
        _ = torch.nonzero(hid)  # data-dependent output shape -> graph break

        # ── subgraph ② (will compile) ──
        return self.proj(hid)  # uses the *same* weight again

model = Custom_Model_Shared_Weights()
model.eval()
inp = torch.randint(0, 64, (4, 10), dtype=torch.long)

with torch.no_grad():
    model = torch.compile(model, backend=custom_backend)
    ref_out = model(inp)

print("Output:", ref_out.shape)

There is no built-in tracking for this, but you could build a dict mapping weight tensors to their usage in prior graphs.
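A minimal sketch of that bookkeeping, assuming it runs in the Dynamo-level backend where example_inputs are usually real tensors rather than FakeTensors; seen_weights and record_shared_weights are hypothetical names, not a torch.compile API:

import torch

# Module-level registry: storage pointer -> list of (graph_id, reference) records
seen_weights = {}

def record_shared_weights(graph_module, example_inputs, graph_id):
    # Parameters kept as attributes of the GraphModule are visible through
    # named_parameters() and referenced by get_attr nodes in the FX graph.
    for name, param in graph_module.named_parameters(remove_duplicate=False):
        ptr = param.untyped_storage().data_ptr()
        seen_weights.setdefault(ptr, []).append((graph_id, f"get_attr:{name}"))

    # Parameters lifted to graph inputs (depending on PyTorch version/config)
    # arrive through example_inputs and line up with the placeholder nodes.
    placeholders = [n for n in graph_module.graph.nodes if n.op == "placeholder"]
    for node, inp in zip(placeholders, example_inputs):
        if isinstance(inp, torch.nn.Parameter):
            ptr = inp.untyped_storage().data_ptr()
            seen_weights.setdefault(ptr, []).append((graph_id, f"placeholder:{node.name}"))

If custom_backend calls record_shared_weights(graph_module, example_inputs, graph_id) with an incrementing graph_id, then once all subgraphs have been compiled, any pointer whose entries span more than one graph_id is a weight shared across subgraphs, and the key itself is the raw data pointer you are after.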

@jansel
How do I identify the same weight across subgraphs? param_dict/state_dict are not available in the custom backend, and the data pointers are not available in the FX graph passes either. How do I confirm that an arg is the same across graphs?
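One way to work around the missing state_dict is to close over the original, uncompiled model when constructing the backend and match graph inputs to parameters by storage pointer. A hedged sketch (make_backend is a hypothetical helper, and it assumes the Dynamo-level backend sees real tensors in example_inputs; on versions where parameters stay as get_attr attributes, the named_parameters() approach above applies instead):

import torch
from torch._inductor.compile_fx import compile_fx

def make_backend(eager_model):
    # Tied weights share one storage, so both qualified names end up under
    # the same pointer key (remove_duplicate=False keeps every alias).
    ptr_to_names = {}
    for name, p in eager_model.named_parameters(remove_duplicate=False):
        ptr_to_names.setdefault(p.untyped_storage().data_ptr(), []).append(name)

    def backend(graph_module, example_inputs):
        placeholders = [n for n in graph_module.graph.nodes if n.op == "placeholder"]
        for node, inp in zip(placeholders, example_inputs):
            if isinstance(inp, torch.Tensor):
                names = ptr_to_names.get(inp.untyped_storage().data_ptr())
                if names:
                    print(f"{node.name} aliases parameter(s) {names}")
        return compile_fx(graph_module, example_inputs)

    return backend

# usage: model = torch.compile(model, backend=make_backend(model))

Note that inside the later AOTAutograd/Inductor passes the tensors are FakeTensors, whose data_ptr() is meaningless, which is why the comparison has to happen at this outer backend level (or against pointers captured from the eager model beforehand).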