torch.js Proposal
Hi all! I’m writing this post to hopefully start some discussion about formalizing an API for PyTorch in JavaScript.
Why JavaScript?
For one, JavaScript is the most commonly used programming language on the planet. (source)
Second, JavaScript enables zero-download demonstrations that will allow researchers to share their models and results more easily with a large audience. Further, JavaScript runs client-side, which will reduce the cost associated with hosting model-running servers.
To motivate the idea, here’s a quick in-browser demo https://mnist-demo.glitch.me that would be implemented as follows:
export class Net extends nn.Module {
  constructor() {
    super();
    // MNIST inputs are 1x28x28; two conv/pool stages halve the
    // spatial size twice: 28 -> 14 -> 7.
    this.convs = nn.Sequential(
      nn.Conv2d(1, 16, 5, {stride: 1, padding: 2}),
      nn.MaxPool2d(2),
      nn.ReLU(),
      nn.Conv2d(16, 32, 5, {stride: 1, padding: 2}),
      nn.MaxPool2d(2),
      nn.ReLU()
    );
    this.mlp = nn.Linear(7 * 7 * 32, 10);
  }
  forward(input) {
    let x = input;
    x = this.convs(x);
    x = x.view([x.size(0), -1]); // flatten to [batch, 7 * 7 * 32]
    return this.mlp(x);
  }
}
const model = new Net();
await model.load_state_dict("model.ptjs");
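For context, running inference on the demo page could then look something like the sketch below. The pixels array (a Float32Array of length 28 * 28 pulled from the demo's canvas) and the exact torch.tensor signature are assumptions here, mirroring the Python API:

// Hypothetical usage sketch: classify a 28x28 grayscale drawing.
const input = torch.tensor(pixels).view([1, 1, 28, 28]);
const logits = model.forward(input);
const probs = logits.softmax(1);
console.log(await probs.data()); // ten class probabilities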
All of the features involved are possible today, with loop_tool as a kernel generator and built-in fetch APIs for downloading weights. Below is an in-depth look at the API proposal.
Language Specification
Generally, the proposal is to copy the Python API as closely as possible: use the same operator names, arguments, and argument names (for keyword arguments). That said, a JavaScript API entails some relatively small departures from conventional "Pythonic" PyTorch usage, detailed below.
Async programming
JavaScript has embraced an asynchronous programming model, and this API should do the same. Doing so opens the door to many types of backend implementations, including ones that JIT compile kernels or download operator implementations on the fly.
Prefer
const x = torch.randn(128);
const data = await x.data(); // (Float32Array), must be run in an async function
Over
const x = torch.randn(128);
const data = x.data();
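To see why, consider a WebGPU-backed tensor: device memory cannot be read synchronously without stalling the page. A minimal backend sketch (the flush and read helpers are assumed names, not a settled API):

class Tensor {
  async data() {
    await this.backend.flush();     // wait for any pending kernels to finish
    return this.backend.read(this); // resolves to a Float32Array
  }
}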
In-place operations
JavaScript ecosystems typically favor a more functional style, so it makes sense to deprioritize implementing in-place operations; they are not fundamentally incompatible with the language, though.
Prefer
const x = torch.randn(128);
const y = x.relu();
Over
const x = torch.randn(128);
x.relu_();
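If in-place variants are added later, they can be layered on the functional core. A sketch, assuming a PyTorch-style copy_ primitive exists as the one fundamental in-place op:

// Hypothetical: derive the in-place variant from the out-of-place op.
Tensor.prototype.relu_ = function () {
  this.copy_(this.relu()); // compute out-of-place, write back into this storage
  return this;
};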
Keyword Arguments
There are no keyword arguments in JavaScript. Instead, we can use objects as a final options argument.
Impossible:
torch.nn.functional.conv2d(input, weight, stride=3);
Instead
torch.nn.functional.conv2d(input, weight, {stride: 3});
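Object arguments also map cleanly onto Python's keyword defaults via destructuring. A sketch of how such a signature could be declared (the defaults mirror Python's torch.nn.functional.conv2d):

// Destructure the options object, falling back to Python's defaults.
function conv2d(input, weight, {stride = 1, padding = 0, dilation = 1, groups = 1} = {}) {
  // ... dispatch to the backend kernel ...
}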
Operator overloads
There is no operator overloading in JavaScript. Default Tensor.* methods can be used instead.
Impossible:
const x = torch.randn(128);
const y = torch.randn(128);
const z = x + y;
Instead
const x = torch.randn(128);
const y = torch.randn(128);
const z = x.add(y);
There exist Babel plugins to enable this functionality for users with a transpilation step in their workflow.
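For users without a build step, method chaining keeps larger expressions readable:

// Python: z = (x + y) * 2 - w
const z = x.add(y).mul(2).sub(w);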
Imports
JavaScript now has support for ES modules, whose import syntax is closer to Python's. This should be used.
import * as torch from "./torch.mjs";
const nn = torch.nn;
torch.nn.Module
This API can be adopted almost exactly, modulo language syntax:
export class Net extends nn.Module {
  constructor() {
    super();
    // …
  }
  forward(x) {
    // …
  }
}
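As a minimal sketch of what the base class could provide (hypothetical; a Parameter wrapper class is assumed, and real registration would also need buffers and container modules):

class Module {
  // Walk own properties to find parameters, recursing into submodules.
  // This emulates the registration Python does via __setattr__.
  *parameters() {
    for (const value of Object.values(this)) {
      if (value instanceof Parameter) {
        yield value;
      } else if (value instanceof Module) {
        yield* value.parameters();
      }
    }
  }
}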
Checkpoints
The checkpoint format may need to be converted offline to a JS-specific format, but the state_dict semantics should be preserved.
await model.load_state_dict("model.ptjs");
Or
const state_dict = await fetch("model.ptjs").then((response) =>
  response.json()
);
model.load_state_dict(state_dict);
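To make the example concrete, here is one possible shape for the JSON payload. This layout is an assumption, not a settled spec, but the keys preserve the Python state_dict names (shapes match the Net above):

// Hypothetical .ptjs layout: a flat name -> {shape, data} mapping.
const state_dict = {
  "convs.0.weight": {shape: [16, 1, 5, 5], data: [/* flat floats */]},
  "convs.0.bias": {shape: [16], data: [/* ... */]},
  "convs.3.weight": {shape: [32, 16, 5, 5], data: [/* ... */]},
  "convs.3.bias": {shape: [32], data: [/* ... */]},
  "mlp.weight": {shape: [10, 1568], data: [/* ... */]},
  "mlp.bias": {shape: [10], data: [/* ... */]}
};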
Operator coverage
PyTorch has many operators, not all of which will be needed. TensorFlow.js, a successful product in the same space, has implemented only ~160 ops, with varying levels of coverage per backend.
Implementation Notes
Even with a JS frontend, implementing compute-heavy operators in pure JavaScript is unlikely to be fast enough for general usability. There are two good[1] target execution engines in browsers, with varying levels of support coverage (see the backend-selection sketch after the footnote):
- WebAssembly
  - Nearly 100% coverage (all major browsers support it)
  - SIMD (4x speedup), 80% coverage
  - FMA (2x speedup), 0% coverage
- WebGPU
  - Much faster than wasm
  - 0% coverage
[1] "Good": semantically capable of expressing common machine learning workloads. WebGL is possible to use, but would likely be more effort than it's worth.
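Finally, whichever engine is available has to be detected at load time. A hypothetical backend-selection sketch, using the wasm-feature-detect library for the SIMD probe (the backend names are made up for illustration):

import { simd } from "wasm-feature-detect";

async function pickBackend() {
  if (navigator.gpu) {
    return "webgpu"; // fastest path, but ~0% coverage today
  }
  if (typeof WebAssembly !== "undefined") {
    return (await simd()) ? "wasm-simd" : "wasm";
  }
  return "js"; // pure-JS fallback
}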