NNsight, PyTorch, and large-model remote inference

I am writing about a library built on PyTorch that my team has been building, called NNsight (https://nnsight.net/). It is designed to enable remote science on very large neural networks and scalable multi-tenant customization. It defines "tracing contexts" that build computation graphs (using fake/meta tensors), which can be used to define interventions and customized remote execution of a large model.
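To make the idiom concrete, here is a minimal sketch (the wrapped model and the layer path follow the fuller Llama example in my reply below; treat the exact attribute names as illustrative):

with model.trace("The capital of France is") as tracer:
    # Reads inside the context are deferred: `hidden` is a proxy.
    hidden = model.model.layers[10].output[0].save()
# Outside the context, the saved proxy holds a concrete tensor.
print(hidden.value.shape)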

In JAX, Google DeepMind has just released Penzai (https://github.com/google-deepmind/penzai), a research toolkit for building, editing, and visualizing neural networks, which attempts to cover similar use cases.

For NNsight, we also considered a similar API style, but we chose Python context managers instead because we think they are easier to use: they allow complex customizations to be written as readable, ordinary PyTorch code. In addition, we think the remote-execution use case is important, and we have made sure it is part of the core.

As we have developed and used the system, we have come to believe that this kind of infrastructure may be a pretty important technology: it could be a key part of a solution that enables scientists and developers to work with upcoming 500b-class open models, which will be exorbitantly expensive and complex for most people to run on machines of their own. If built correctly, NNsight (or a similar system) could be strategically important to Meta in supporting very large open-model use, and it could improve transparency and standardization in the AI ecosystem as a whole.

Right now the team building NNsight is small, but we are energized to get it right. I would love to have some direct collaboration with the PyTorch (or Llama 3) teams at Meta, to make sure we are building it well.

Is there a good way to collaborate with the PyTorch team, to get advice, feedback, or even contributions? Who should we be talking to?

Do you think PyTorch needs something like Penzai or NNsight?


Here is an example of real-world usage of NNsight. A couple of days ago a researcher sent me this small piece of NNsight code expressing a small experiment that reaches into a model to measure the effect of forcing "on" a set of eight neurons inside specific MLPs in Llama 2 7B. (Spoiler: he believes these neurons play a special role in the model.)

from nnsight import LanguageModel
from rich import print as rprint
from rich.table import Table

# NNsight can be used to wrap any PyTorch model.
# In this example we are using a convenience wrapper
# that can reference and load autoregressive LMs by their HF ids.
model = LanguageModel('meta-llama/Llama-2-7b-hf', device_map='cuda', dispatch=True)

sentences = ['Q: cheap\nA:', 'Q: big\nA:', 'Q: fast\nA:']
grouped_neurons = {12: [7171, 2519, 1786],
                   13: [9469, 7737],
                   14: [3260, 7737, 8894],
                   15: [3528],
                   16: [10788]}

# An NNsight-wrapped module can be used as a context via "trace"
# rather than just as a callable. Code inside the context can
# read and write submodule inputs and outputs, and it will be
# executed at the right moment while the module is run.

with model.trace(sentences + sentences) as tracer:

    for layer_num, neurons in grouped_neurons.items():
        # Force the chosen MLP neurons "on" (to 5.0) at the last token
        # position, but only for the second copy of the sentences.
        x = model.model.layers[layer_num].mlp.down_proj.input
        x[0][0][3:, -1, neurons] = 5.0

    # When we want to see some data outside the context, we call
    # `.save()`. That crosses the boundary between virtual (fake)
    # tensors inside the context, which define a computation plan, and
    # concrete tensors outside, which receive the data you save().
    output = model.output.save()

results_table = Table("query", "clean response (topk)", "intervention response (topk)")

def get_top_tokens(logits, idx, k=3):
    # Decode the top-k next-token predictions for batch row `idx`.
    return [(model.tokenizer.decode(y.item()), round(x.item(), 3))
            for x, y in zip(*logits[idx, -1].softmax(-1).topk(k))]

for i in range(len(sentences)):
    # Rows 0-2 of the batch are the clean run; rows 3-5 are the
    # duplicates that received the neuron intervention.
    clean_res = get_top_tokens(output.value.logits, i)
    interv_res = get_top_tokens(output.value.logits, i + 3, 5)
    results_table.add_row(repr(sentences[i]), repr(clean_res), repr(interv_res))

rprint(results_table)

One of the reasons to use NNsight is that, because it defines a deferred execution model, it lets you work with huge models remotely. For example, if you just add remote=True to the tracing context above, exactly the same code will run the Llama 2 model on a remote server, including the customizations, and NNsight will deal with shuttling the input data, the customized computation graph, and the outputs back and forth. This is especially useful if you want to switch to a larger model such as Llama 2 70B and you do not have a machine large enough to run it locally.
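For instance, the intervention above becomes remote with a one-argument change (a sketch; it assumes access to the remote execution service has been configured, as described at https://nnsight.net/):

with model.trace(sentences + sentences, remote=True) as tracer:
    for layer_num, neurons in grouped_neurons.items():
        x = model.model.layers[layer_num].mlp.down_proj.input
        x[0][0][3:, -1, neurons] = 5.0
    # Only tensors that are save()'d are serialized and returned.
    output = model.output.save()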

With this idiom you can define arbitrary inference customizations deep inside the model, such as adding LoRA adapters, patching activations between different runs of a model, or applying various controllable decoding techniques. Since arbitrary computation graphs can be patched into a model, there is really no limit to the flexibility. We plan to add functionality to the remotable computation graph so that you can define fine-tuning runs (e.g., to train a LoRA or other model) or do other more complex things fully remotely.
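As a concrete illustration of patching activations between runs, here is a hedged sketch using the multi-prompt invoke idiom (the layer index and the prompts are invented for illustration):

with model.trace() as tracer:
    with tracer.invoke("The Eiffel Tower is in the city of"):
        # Save the last-token hidden state from the first ("clean") run.
        clean_hidden = model.model.layers[12].output[0][:, -1, :]
    with tracer.invoke("The Colosseum is in the city of"):
        # Patch it into the same position in a second run.
        model.model.layers[12].output[0][:, -1, :] = clean_hidden
        patched_logits = model.output.logits.save()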

We think this type of API is probably the right interface for providing access to much larger (e.g., 500b-class or larger) open models, going beyond traditional input-output inference API access. Such models will be impractical for most people to run locally, but the transparent flexibility of NNsight-style access would allow the same kind of innovation we see with traditional local PyTorch use, while the big model stays hosted remotely.

More examples are at https://nnsight.net/. Also linked from there is a fairly active Discord, where you can see what users are doing and discussing with the API.


A follow-up.

The NSF has just funded a project to develop a service based on NNsight (https://ndif.us/), so we will be hiring engineers to implement and scale this approach.

A pair of Twitter threads explaining the vision and the needs:

https://x.com/davidbau/status/1787597395741835564

https://x.com/davidbau/status/1785991520988152316