To summarize: earlier this half, I wrote a wrapper tensor subclass for a differential privacy application. This post discusses:
- Some context on how PySyft implements differential privacy
- A short introduction on the variants of tensor subclassing
- Discussion of what went well: functionalization, subclassing, autograd, decompositions
- Discussion of what was tricky: API differences with numpy, e.g. type promotion (this applies to backend implementations in general)
Some context on the application
This work was done in collaboration with OpenMined, an organization building tools to enable differentially private learning. These tools allow researchers to safely leverage data that was previously inaccessible due to privacy protections, for example medical records.
One such tool is PySyft, a stack that enables training on protected data. It roughly functions as follows:
- Data is kept on a trusted server that clients do not have direct access to
- Clients work with the data via a remote handle, which I will call an “RPC tensor”. Operations performed on the RPC tensor are communicated to the data owner’s server, where they are computed remotely.
- Since clients don’t access the data, only the shape of the tensor is known to the client by default. Whenever computations are done, the client only sees shape propagation.
- To see the actual values of the output tensor, clients must “publish” their tensor handles by spending what is called a “privacy budget”, of which each client gets a fixed amount. More trustworthy clients have larger privacy budgets.
- When the result is viewed, a certain amount of noise is added. Clients can request to use more of their privacy budget to reduce the amount of added noise. For the same amount of privacy budget consumed, the amount of noise added depends on how revealing the computed statistic is about any given row of data. (A purely illustrative sketch of this workflow follows this list.)
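To make the flow above concrete, here is a purely illustrative sketch of the client-side workflow. The class and method names are made up and are not PySyft’s actual API; they just mirror the steps described in the list.

```python
import torch

class FakeRemoteHandle:
    """Hypothetical stand-in for an RPC tensor: the client sees only shape/dtype."""

    def __init__(self, shape, dtype=torch.float32):
        self.shape, self.dtype = torch.Size(shape), dtype

    def mean(self):
        # Client-side, operations only propagate shape; the real computation
        # happens on the data owner's server.
        return FakeRemoteHandle((), self.dtype)

    def publish(self, sigma):
        # Spending privacy budget reveals a noised value; a smaller sigma
        # (less noise) costs more budget.
        true_value = 0.0  # on the real server this would be the computed statistic
        return true_value + torch.randn(()).item() * sigma

handle = FakeRemoteHandle((1000, 10))   # remote handle to protected data
stat = handle.mean()                    # client only learns the result is a scalar
noisy_value = stat.publish(sigma=0.1)   # reveal with noise, consuming budget
```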
The task
The task was to integrate PySyft’s RPC tensor with PyTorch so that it could better interface with PyTorch’s features, including autograd (prior to this, they were rolling their own). More concretely, we do this by wrapping the RPC tensor in a wrapper tensor subclass.
A little background on subclassing and why a wrapper subclass is the right choice:
- Why not use __torch_function__? It is important to interpose at the Python dispatch level because Python dispatch sits below autograd, so intercepting operations after the autograd layer has already been applied lets us use autograd as-is. (Unrelated: torch function can still be useful if you want to override autograd.)
- Why use a wrapper subclass (“has-a”) vs. inheritance (“is-a”)? The RPC tensor itself does not hold any data. A wrapper tensor subclass handles this case by setting the outer tensor to be a meta tensor. (A minimal sketch of such a wrapper subclass follows this list.)
- For a bit more context on how to decide between various subclassing options see Alban’s colab (from a few patches ago, but mostly up to date)
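To make the “wrapper subclass over a data-less outer tensor” idea concrete, here is a minimal sketch of the pattern. This is not PySyft’s actual implementation: RPCWrapperTensor is a made-up name and the inner “handle” is just a regular tensor standing in for the remote one. It also includes the explicit disabling of torch function mentioned under Challenges below.

```python
import torch
from torch.utils._pytree import tree_map

class RPCWrapperTensor(torch.Tensor):
    # Hypothetical "has-a" wrapper subclass. The outer tensor created by
    # _make_wrapper_subclass has no storage; it only records shape/dtype/device.
    @staticmethod
    def __new__(cls, elem):
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.shape, dtype=elem.dtype, device=elem.device
        )

    def __init__(self, elem):
        self.elem = elem  # the inner handle; a plain tensor here for illustration

    # Explicitly disable torch function so that only torch dispatch interposes.
    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def unwrap(t):
            return t.elem if isinstance(t, cls) else t

        def wrap(t):
            return cls(t) if isinstance(t, torch.Tensor) else t

        # In the real application this is where the op would be forwarded to the
        # data owner's server; here we just run it on the inner tensors.
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        return tree_map(wrap, out)

# Usage: every op dispatches through __torch_dispatch__, one aten op at a time.
x = RPCWrapperTensor(torch.randn(3, 4))
y = (x * 2).sum()
print(type(y), y.shape)  # still an RPCWrapperTensor, with the right shape
```

Because __torch_dispatch__ sits below autograd, gradient tracking happens on the outer wrapper just as it would for an ordinary tensor, which is what lets the subclass reuse autograd as-is.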
What went well
I found that most things went pretty smoothly. The different moving pieces (functionalization, subclassing, and autograd) all worked well together:
- PySyft’s backend does not support views, but functionalization worked well here. (Most people had only used functionalization with functorch up to this point; I ran into just one issue while training with it in core, which was quickly fixed by Ed. See the sketch after this list for what functionalization does.)
- Subclassing worked smoothly with autograd. There was a lot of prior work done here already (functorch uses C++ wrapper subclasses and there is a good set of examples in subclass zoo), so maybe it shouldn’t be surprising that it already works well, especially since we weren’t doing anything super exotic.
- It is simple to integrate decompositions with subclassing. Sherlock and others are working on improving the UX for decompositions in general. For subclassing specifically, here is my attempt at a reusable wrapper subclass that lets users apply decompositions; a simplified sketch of the pattern also follows this list.
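For the functionalization point above, here is a hedged, standalone illustration (not the exact setup used in this project) of what functionalization buys a backend without views: torch.func.functionalize rewrites a program with views and in-place mutations into an equivalent one without them.

```python
import torch
from torch.func import functionalize

def f(x):
    y = x.view(2, 2)   # a view op, which a view-free backend cannot implement directly
    y.add_(1)          # an in-place mutation through that view
    return y

# functionalize rewrites f into an equivalent program with no views or mutations
# (using *_copy variants of the view ops), which a view-free backend can handle.
f_no_views = functionalize(f, remove="mutations_and_views")
print(f_no_views(torch.zeros(4)))  # tensor of ones with shape (2, 2)
```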
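And for the decompositions point, here is a rough sketch of the pattern (my reading of it, not the actual reusable wrapper linked above): a decomposition table is consulted inside __torch_dispatch__, so ops the backend does not implement get rewritten into simpler ops, which re-enter dispatch and are handled by the subclass as usual. The DecomposingTensor name and the choice of aten.addmm are just for illustration.

```python
import torch
from torch._decomp import get_decompositions
from torch.utils._pytree import tree_map

# Ask PyTorch for decompositions of ops our hypothetical backend lacks.
decomposition_table = get_decompositions([torch.ops.aten.addmm])

class DecomposingTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.shape, dtype=elem.dtype, device=elem.device
        )

    def __init__(self, elem):
        self.elem = elem

    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in decomposition_table:
            # Run the decomposition on the still-wrapped tensors: the simpler ops
            # it emits re-enter __torch_dispatch__ and take the fallback path below.
            return decomposition_table[func](*args, **kwargs)

        unwrap = lambda t: t.elem if isinstance(t, cls) else t
        wrap = lambda t: cls(t) if isinstance(t, torch.Tensor) else t
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        return tree_map(wrap, out)

# aten.addmm never reaches the fallback path directly; it is decomposed first.
b, m1, m2 = (DecomposingTensor(torch.randn(2, 2)) for _ in range(3))
out = torch.addmm(b, m1, m2)
```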
Challenges
There were still a few challenges to overcome:
- The default behaviors were mostly the right ones, though there was one case where some magic was needed to get things to work:
- torch function needed to be explicitly disabled when using torch dispatch (this should be fixed now); this is the __torch_function__ line in the sketch above
- PySyft uses NumPy to compute ops in its backend, so differences between numpy and torch needed to be handled
- Because ops do not map cleanly between the two libraries (ops with the same name can behave differently, and the same behavior can go by different names, e.g. torch’s expand(size) corresponds to np.broadcast_to(a, desired_size)), additional clarification was needed when communicating to backend implementers about which ops needed to be implemented
- Type promotion differences: one discrepancy is that NumPy binary ops between int64 and float32 promote to float64, while torch prefers to stay in float32. An additional layer of logic needed to be added to handle this difference; see the sketch after this list. (People implementing prims won’t run into this issue though.)
- The Array API actually specifies that this mixed type promotion is implementation defined (so NumPy and PyTorch both technically conform). As a deep learning library, PyTorch prefers to default to float32 and won’t be following NumPy’s default.
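As a hedged sketch of the kind of extra logic involved (the helper name here is made up, not PySyft’s code), one option is to compute the output dtype using torch’s own promotion rules via torch.result_type, run the op in NumPy, and cast the result back:

```python
import numpy as np
import torch

def numpy_binary_op(np_func, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # torch's promotion rules: e.g. int64 + float32 -> float32.
    out_dtype = torch.result_type(a, b)
    # NumPy's rules may differ: e.g. int64 + float32 -> float64.
    result = np_func(a.numpy(), b.numpy())
    # Cast back so the backend matches what torch itself would have produced.
    return torch.from_numpy(result).to(out_dtype)

# Matches torch.add's float32 output even though NumPy computes in float64.
out = numpy_binary_op(np.add, torch.arange(3), torch.ones(3, dtype=torch.float32))
assert out.dtype == torch.float32
```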
Project status
We were able to develop a subclass wrapper for an RPC tensor capable of running a full training loop. Iteration remains slow due to communication overhead, as tensors need to be published every iteration, but this can be mitigated with larger batch sizes. More privacy budget needs to be spent for the forward and gradient results to match with high precision, so accuracy can be an issue. Code can be found here.
Conclusion
Overall, Python subclassing seemed to work smoothly for this application, and I would recommend that others building something similar also try it out.
Some things that would’ve been helpful:
- There are no plans to make PyTorch conform exactly to NumPy’s behavior, as that would be prohibitively BC-breaking. However, maybe some things could be done to soften the edge here. Could NumPy/PyTorch differences be documented in a central place? We could even go a step further and have a “from torch import np” namespace or a decomposition table specifically for converting NumPy ops.
- There are a lot of great examples in subclass zoo that illustrate the various ways one can use subclassing, but it may be hard for a newcomer to find the particular example they are looking for. An official “tour through the zoo” would be helpful.
Some additional notes:
- My experience in this post was mostly based on the state of subclassing from a couple patches back. The state of tensor subclassing has developed since.
- If one wants to interpose at the Python dispatch level, subclassing is not the only option; for many use cases it can be better to use modes.
Thanks to Ishan Mishra, Rasswanth S., and Andrew Trask from OpenMined for onboarding me to their project and for their excellent support during our collaboration.
Thanks to Alban for connecting me with folks from OpenMined and for guidance on the project.