Detect shared weights across subgraphs in torch.compile

I have a model that shares the same weight tensor across multiple layers. After compiling it with torch.compile, the model is split into three subgraphs, and the shared weight shows up in both the first and the third subgraph. In my custom backend, I need to:

  1. Identify which layers (or ops) reference the same weight tensor.
  2. Retrieve the raw data pointer of that shared weight.

How can I find out, at backend compilation time, which subgraph nodes are using the same weight tensor, and how can I then extract its underlying data pointer?
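
For reference, at the plain nn.Module level the tie is easy to detect by grouping parameters by their raw data pointer. Here is a minimal sketch (group_params_by_storage is just a helper name I made up; remove_duplicate=False keeps tied parameters visible under every qualified name):

import torch
from collections import defaultdict

def group_params_by_storage(model: torch.nn.Module):
    # Map raw data pointer -> fully-qualified names of the parameters that alias it.
    groups = defaultdict(list)
    for name, p in model.named_parameters(remove_duplicate=False):
        groups[p.data_ptr()].append(name)
    # Keep only pointers referenced under more than one name, i.e. shared weights.
    return {ptr: names for ptr, names in groups.items() if len(names) > 1}

For the tied model in the example further down, this should report a single entry mapping the shared pointer to ['embed.weight', 'proj.weight']. What I can't figure out is how to do the equivalent inside the backend, per subgraph.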

Below is a minimal example of a model with shared weights:

import torch

def custom_backend(graph_module, example_inputs):
    # Dynamo calls this once for every captured subgraph
    print("We are in custom backend")

    # Use the Inductor backend for the actual compilation
    from torch._inductor.compile_fx import compile_fx
    return compile_fx(graph_module, example_inputs)

class Custom_Model_Shared_Weights(torch.nn.Module):
    def __init__(self, vocab=64, d_model=16):
        super().__init__()
        # weight tensor tied between both layers (nn.Embedding wraps _weight in a
        # new Parameter, but the underlying storage, and hence data_ptr, is shared)
        shared = torch.nn.Parameter(torch.randn(vocab, d_model).to(torch.bfloat16))

        self.embed = torch.nn.Embedding(vocab, d_model, _weight=shared)
        self.proj = torch.nn.Linear(d_model, vocab, bias=False)
        self.proj.weight = shared  # <- tie

    def forward(self, tokens):
        x = self.embed(tokens)
        hid = x.mean(dim=1)
        # ── graph-break op ──
        _ = torch.nonzero(hid)  # data-dependent output shape -> Dynamo graph break

        # ── compiled again after the break ──
        return self.proj(hid)  # uses the *same* weight again

model = Custom_Model_Shared_Weights()
model.eval()
inp = torch.randint(0, 64, (4, 10), dtype=torch.long)

with torch.no_grad():
    model = torch.compile(model, backend=custom_backend)
    ref_out = model(inp)

print("Output:", ref_out.shape)
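
And here is roughly the bookkeeping I'm picturing inside the backend. This is only a sketch: inspecting_backend, ptr_registry and the subgraph counter are names I invented, and I'm not sure where the weights actually surface (as parameters/buffers on the GraphModule behind get_attr / call_module nodes, or lifted into example_inputs as placeholders), since that seems to depend on the PyTorch version and config, so the sketch looks in all of those places:

import operator
from collections import defaultdict

import torch

# Shared across backend invocations: data pointer -> list of (subgraph id, where it was seen).
# The backend runs once per subgraph, so a pointer recorded under more than one
# subgraph id is a weight shared across subgraphs.
ptr_registry = defaultdict(list)
subgraph_counter = 0

def inspecting_backend(graph_module, example_inputs):
    global subgraph_counter
    subgraph_counter += 1
    sg = subgraph_counter

    # 1) Weights kept as attributes on the GraphModule itself.
    for name, param in graph_module.named_parameters(remove_duplicate=False):
        ptr_registry[param.data_ptr()].append((sg, f"param:{name}"))
    for name, buf in graph_module.named_buffers(remove_duplicate=False):
        ptr_registry[buf.data_ptr()].append((sg, f"buffer:{name}"))

    # 2) Map FX nodes back to those tensors so I know which ops touch them.
    for node in graph_module.graph.nodes:
        if node.op == "get_attr":
            attr = operator.attrgetter(node.target)(graph_module)
            if isinstance(attr, torch.Tensor):
                ptr_registry[attr.data_ptr()].append((sg, f"node:{node.name}"))
        elif node.op == "call_module":
            submod = graph_module.get_submodule(node.target)
            for pname, param in submod.named_parameters(remove_duplicate=False):
                ptr_registry[param.data_ptr()].append((sg, f"node:{node.name}.{pname}"))

    # 3) Weights lifted into the graph inputs (guarded: fake tensors have no real pointer).
    for i, inp in enumerate(example_inputs):
        if isinstance(inp, torch.Tensor):
            try:
                ptr_registry[inp.data_ptr()].append((sg, f"input:{i}"))
            except Exception:
                pass

    shared = {ptr: uses for ptr, uses in ptr_registry.items()
              if len({subgraph for subgraph, _ in uses}) > 1}
    print(f"subgraph {sg}: weights seen in more than one subgraph -> {shared}")

    from torch._inductor.compile_fx import compile_fx
    return compile_fx(graph_module, example_inputs)

The printout keyed by data_ptr() is the result I'd like to end up with: one entry per shared weight, listing the subgraphs and nodes that touch it. Is relying on data_ptr() like this the right approach?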