Proposal: torch.js, a JavaScript frontend for PyTorch

Hi all! I’m writing this post to hopefully start some discussion about formalizing an API for PyTorch in JavaScript.

Why JavaScript?

For one, JavaScript is the most commonly used programming language on the planet. (source)

Second, JavaScript enables zero-download demonstrations that will allow researchers to share their models and results more easily with a large audience. Further, JavaScript runs client-side, which will reduce the cost associated with hosting model-running servers.

To motivate the idea, here’s a quick in-browser demo (https://mnist-demo.glitch.me) that would be implemented as follows:

export class Net extends nn.Module {
  constructor() {
    super();
    this.convs = nn.Sequential(
      nn.Conv2d(3, 16, 5, {stride: 1, padding: 2}),
      nn.MaxPool2d(2),
      nn.ReLU(),
      nn.Conv2d(16, 32, 5, {stride: 1, padding: 2}),
      nn.MaxPool2d(2),
      nn.ReLU()
    );
    // 32 channels over a 7x7 spatial extent after two 2x2 pools on 28x28 input
    this.mlp = nn.Linear(32 * 7 * 7, 10);
  }

  forward(input) {
    let x = input;
    x = this.convs(x);
    x = x.view([x.size(0), -1]); // flatten all but the batch dimension
    return this.mlp(x);
  }
}

const model = new Net();
await model.load_state_dict("model.ptjs");

All of the features involved are possible today, with loop_tool as a kernel generator and built-in fetch APIs for downloading weights. Below is an in-depth look at the API proposal.

Language Specification

Generally, the proposal is to copy the Python API as closely as possible: use the same
operator names, arguments, and argument names (for keyword arguments). Still, a JavaScript API will entail some relatively small departures from the conventional “Pythonic” use of PyTorch.

Async programming

JavaScript has embraced an asynchronous programming model, and this API should do the same. This opens the door for many types of backend implementations (including those that JIT-compile or download operator implementations on the fly).

Prefer

const x = torch.randn(128);
const data = await x.data(); // Float32Array; must be awaited inside an async function

Over

const x = torch.randn(128);
const data = x.data();
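To make the contract concrete, here is a self-contained toy (not the proposed implementation) in which `data()` returns a Promise, as it would if the values had to be copied back from a device or decoded from a fetched buffer:

```javascript
// Toy stand-in for the proposed Tensor: data() is async by contract,
// so a backend is free to resolve it only after a device->host copy.
class ToyTensor {
  constructor(values) {
    this.values = Float32Array.from(values);
  }
  data() {
    // A real backend might await a GPU readback or a fetch here.
    return Promise.resolve(this.values);
  }
}

async function main() {
  const x = new ToyTensor([1, 2, 3]);
  const data = await x.data(); // Float32Array [1, 2, 3]
  return data;
}
```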

In-place operations

The JavaScript ecosystem typically favors a more functional style, so it makes sense to deprioritize in-place operations, though they are not fundamentally incompatible with the language.

Prefer

const x = torch.randn(128);
const y = x.relu();

Over

const x = torch.randn(128);
x.relu_();

Keyword Arguments

There are no keyword arguments in JavaScript. Instead, we can use objects as arguments.

Impossible:

torch.nn.functional.conv2d(input, weight, stride=3);

Instead

torch.nn.functional.conv2d(input, weight, {stride: 3});
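Object destructuring with defaults makes these option objects cheap to implement. A minimal sketch (the body just echoes the resolved options; a real conv2d would compute):

```javascript
// Sketch: Python-style keyword defaults via object destructuring.
// Only stride/padding/dilation are shown; the real signature has more.
function conv2d(input, weight, {stride = 1, padding = 0, dilation = 1} = {}) {
  return {stride, padding, dilation};
}

conv2d(null, null, {stride: 3}); // -> {stride: 3, padding: 0, dilation: 1}
conv2d(null, null);              // all defaults, like Python's signature
```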

Operator overloads

There is no operator overloading in JavaScript. Default Tensor.* operations can be used instead.

Impossible:

const x = torch.randn(128);
const y = torch.randn(128);
const z = x + y;

Instead

const x = torch.randn(128);
const y = torch.randn(128);
const z = x.add(y);

There exist Babel plugins to enable this functionality for users with a transpilation step in their workflow.
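Without overloading or a transpiler, chained Tensor.* methods keep expressions readable. A toy scalar class (purely illustrative, not the proposed Tensor) shows the style:

```javascript
// Toy chainable "tensor" holding a single number, to illustrate the
// x.add(y).relu() style that replaces operator overloading.
class Scalar {
  constructor(v) { this.v = v; }
  add(other) { return new Scalar(this.v + other.v); }
  mul(other) { return new Scalar(this.v * other.v); }
  relu() { return new Scalar(Math.max(0, this.v)); }
}

const z = new Scalar(2).mul(new Scalar(-3)).relu(); // z.v === 0
```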

Imports

JavaScript now supports ES modules, whose import syntax is closer to Python’s. This should be used.

import * as torch from "./torch.mjs";
const nn = torch.nn;

torch.nn.Modules

This API can be adopted almost exactly modulo language syntax:

export class Net extends nn.Module {
  constructor() {
    super();
    // …
  }
  forward(x) {
  }
}

Checkpoints

The checkpoint format may need to be converted to a JS specific format offline, but the state_dict semantics should be preserved.

await model.load_state_dict("model.ptjs");

Or

const state_dict = await fetch("model.ptjs").then((response) =>
  response.json()
);

model.load_state_dict(state_dict);
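The JSON layout above is left unspecified. One plausible encoding (purely hypothetical, not part of this proposal) stores each tensor as a shape plus base64-encoded raw float32 bytes, which decodes in a few lines:

```javascript
// Hypothetical .ptjs entry: { shape: [...], data: "<base64 of f32 bytes>" }.
// Decodes one entry's payload into a typed array for load_state_dict.
// atob is available in browsers and Node 16+.
function decodeTensor({shape, data}) {
  const bytes = Uint8Array.from(atob(data), (c) => c.charCodeAt(0));
  return {shape, values: new Float32Array(bytes.buffer)};
}
```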

Operator coverage

PyTorch has many operators, not all of which will be needed. TensorFlow.js, a successful product in the same space, has implemented only ~160 ops, with varying levels of coverage per backend.

Implementation Notes

Even with a JS frontend, implementing compute-heavy operators in pure JavaScript is unlikely to be fast enough for general usability. There are two good[1] target execution engines in browsers, with varying levels of support coverage.

  1. WebAssembly

    1. Nearly 100% coverage (all major browsers support it)
    2. SIMD (4x speedup), 80% coverage
    3. FMA (2x speedup), 0% coverage
  2. WebGPU

    1. Much faster than wasm
    2. 0% coverage

[1] “Good” - semantically capable of expressing common machine learning workloads. WebGL is possible to use, but would likely be more effort than it’s worth.
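Feature detection would decide between the two engines at runtime. A sketch, assuming a hypothetical torch.setBackend() API (WebGPU exposes itself to pages as navigator.gpu):

```javascript
// Pick the faster available engine: WebGPU when exposed, else WASM.
function selectBackendName() {
  if (globalThis.navigator && globalThis.navigator.gpu) {
    return "webgpu"; // much faster, but ~0% browser coverage today
  }
  return "wasm"; // near-universal coverage
}

// torch.setBackend() is hypothetical, not part of the current spec.
async function initBackend(torch) {
  await torch.setBackend(selectBackendName());
}
```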


Hey, I think this is very cool! Having the ability to make quick demos is definitely very desirable.

Two questions I have:

  1. What does the interoperability story with Python look like (or is it even a goal)? I imagine most researchers would want to author/train in Python and export their model to be loaded in JavaScript for demos.
  2. Relatedly, it looks like in this proposal all kernels are newly generated. Do you foresee issues like lower coverage than ATen ops, or slightly different semantics such that a Python-trained model won’t have the same accuracy when running in JS?

By interop, do you mean weights being shared, or the ability to take models that have been written in Python and automatically convert them to JavaScript? I think the former is straightforward to support and should be a listed goal, but the latter would require some ability to transpile. It’s not a listed goal, but if there is demand, I believe torch.fx would be useful here!

Re: kernels - much like TF.js, all effort should be made to maintain full compatibility. The implementations will certainly need to be rewritten for performance, but the testing could probably be automated. I don’t foresee intrinsic compatibility issues, but it would certainly mean a new surface to maintain. I think starting with a small subset of operators would be a good way to gauge the burden of maintaining that surface.


Bram, how about binding JS to C++? Or are such bindings not possible in all JS engines?

Also is autograd supported automatically?

JS <-> C++ interop can only be achieved in the browser by compiling to WebAssembly (e.g. via Emscripten). Directly using all of PyTorch’s C++ code will prove more difficult, as the compiled WebAssembly would likely be too large to comfortably send over the wire.

This is an API spec proposal, and I believe support for autograd should be included in the spec, yes! (i.e. tensor.backward() should be available)
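To illustrate the backward() semantics the spec would pin down, here is a toy reverse-mode autograd on scalars (purely illustrative, not the proposed implementation):

```javascript
// Toy scalar autograd: records a graph during mul() and sweeps gradients
// backwards in topological order, so d(x*x)/dx evaluates to 2x.
class Value {
  constructor(v, parents = []) {
    this.v = v;
    this.grad = 0;
    this.parents = parents;
    this.backprop = () => {};
  }
  mul(other) {
    const out = new Value(this.v * other.v, [this, other]);
    out.backprop = () => {
      // Product rule; += accumulates when a value is reused (as in x.mul(x)).
      this.grad += other.v * out.grad;
      other.grad += this.v * out.grad;
    };
    return out;
  }
  backward() {
    // Topologically order the graph, then apply backprop in reverse.
    const topo = [];
    const seen = new Set();
    const build = (node) => {
      if (seen.has(node)) return;
      seen.add(node);
      node.parents.forEach(build);
      topo.push(node);
    };
    build(this);
    this.grad = 1;
    topo.reverse().forEach((node) => node.backprop());
  }
}

const x = new Value(3);
const y = x.mul(x);
y.backward(); // x.grad === 6, i.e. 2x at x = 3
```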

Could you split it into separate modules and lazily load just the code that’s required?

Otherwise, some languages that can cross-compile to WebAssembly do “tree shaking” (removing unreachable code) as part of compilation, sometimes including libraries too. For example Microsoft’s Blazor (which compiles C# to WebAssembly) does this both for the app code and for imported Microsoft and third-party libraries. Of course, that requires a compilation step of some sort, which removes one of the benefits of using JS (fast prototyping).

Could you split it into separate modules and lazily load just the code that’s required?

This is a good point, and for modularized kernel/operator level implementations it seems straightforward.

More complex “dynamic” tree-shaking may be required for general core runtime (e.g. splitting out autograd and only loading it if folks require gradients). There are likely some automated ways to do this for arbitrarily compiled WebAssembly, which I’ve explored a bit here.

I’m very much looking forward to torch.js! Thank you for doing and posting this.

How does the problem domain here compare to those that the mobile efforts faced? Learning from them could be helpful, but to my mind even more importantly, having some consistency between the various “run somewhere else” options would likely improve the experience for users a lot.

Best regards

Thomas

This is definitely an important consideration, thanks for highlighting!

The largest overlap with mobile is probably device compute constraints: maybe ~1 TFLOPS of compute. I suspect this shouldn’t have much bearing on API design.

From what I understand, the mobile effort has largely been export based. This proposal is predominantly about cloning the PyTorch API over to a new language (much like the libtorch/PyTorch C++ API). Ideally this would empower a new set of users, ones that do not use Python to build their products.

That being said - there is a latent question of “what about Python users who want to export and run models on the web?” I believe solutions for that can be built in parallel (perhaps using the same backend :) ).


I’d like the proposal to be fleshed out further wrt the non-technical parts.

  1. Who will build it?
  2. Who will maintain it?
  3. What are the hypothetical timelines?
  4. How does this converge or diverge with the PyTorch Python frontend in the long horizon?

I think it’s important to consider this as an API specification proposal.

The web ecosystem typically functions with many (heavily-used) implementations of the same specification. We see that with browser features (three major browsers, each with their own implementations and support across the web)[0] and WebAssembly engines (which have been implemented many times)[1].

Adopting this common “specification first” model allows flexibility and speed, since it removes the dependence on a single implementation. It’s also a great way to find customers before investing significant time in driving compatibility and performance.

Re 1./2. - I’d expect building and maintenance of the specification to be done in the PyTorch open source org in a way similar to WebAssembly’s spec page[2]. Implementation is not an immediate requirement.

Re 3.

That being said, I think a “reference implementation” would be straightforward for 1-2 engineers to build over ~4-6 months (depending on how big the spec ends up getting). Given the current state of WASM kernels (TF.js has a bunch, and loop_tool can do a good job on the less exotic ones), I imagine it will be reasonably efficient.

Re 4.

It’s pretty much a certainty that a web standard for eager-mode machine learning will spring up when WebGPU comes online. This proposal is to get ahead of that and build out an ecosystem around an API that doesn’t diverge significantly from the PyTorch API. Full semantic convergence is a non-goal (the languages are different), but operator coverage will likely converge as adoption increases.

Does that seem reasonable?

[0] Standards - W3C, https://caniuse.com
[1] Roadmap - WebAssembly
[2] GitHub - WebAssembly/spec: WebAssembly specification, reference interpreter, and test suite.

(Note: there is a W3C proposal for an NN API, the Web Neural Network API - it’s based on graph execution.)