I’ve written a lot of unions for PyTorch this year. I’ve learned a few things along the way, and I thought a post highlighting this useful-but-little-used C++ construct might be in order. (If you know how to use unions, you may still be interested in the more advanced content toward the middle and end of the post!)
What is a union?
A union is like a struct, but instead of holding ALL of its members, it holds ANY single one of them. The main reason to use a union is to save on memory, which might also improve performance. Here’s how you might define one:
union MyUnion {
  int x;
  void* p;
};
Here, MyUnion can hold either an int or a void*, but not both at the same time. Its size is the size of its largest member (rounded up to satisfy alignment, if necessary), which on a typical platform is p.
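A quick way to check the size claim yourself. MyStruct and demo are hypothetical names added for comparison, and the sizes assume a typical platform where int is no larger than void*. The demo also shows the legal way to switch members: write to the one you want, then read only that one.

```cpp
#include <cassert>

union MyUnion {
  int x;
  void* p;
};

// For comparison, a struct stores both members side by side.
struct MyStruct {
  int x;
  void* p;
};

void demo() {
  MyUnion u;
  u.x = 42;  // x is now the active member
  assert(u.x == 42);
  u.p = nullptr;  // writing p ends x's lifetime; p is now active
  assert(u.p == nullptr);
}
```

On a typical 64-bit platform, sizeof(MyUnion) is 8 while sizeof(MyStruct) is 16, because the struct also needs padding after x to keep p aligned.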
C++ unions are wildly unsafe: unlike in C, it is undefined behavior to read from a member of a union that isn’t the one that was most recently written, also known as the “active member”. In other words, this cute code is not allowed (though your compiler may let you do it):
int ptrToInt(void* p) {
  MyUnion u = {.p = p};
  return u.x;
}
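If all you want is the integer value of a pointer, a well-defined alternative needs no union at all. The function names below are hypothetical, and std::uintptr_t is technically optional in the standard, although mainstream platforms provide it. Unlike int, which is only 4 bytes on typical 64-bit platforms, it is wide enough to hold a pointer.

```cpp
#include <cassert>
#include <cstdint>

// Convert a pointer to an integer wide enough to hold it.
std::uintptr_t ptrToInt(void* p) {
  return reinterpret_cast<std::uintptr_t>(p);
}

// Converting back yields the original pointer.
void* intToPtr(std::uintptr_t i) {
  return reinterpret_cast<void*>(i);
}
```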
You can put class types into unions, but it requires a lot of work. For example, this code:
union IntOrString {
  int x;
  std::string s;
};
void f() {
  IntOrString x = {.s = "hello"};
}
gives us the following discouraging compiler error:
<source>:9:17: error: attempt to use a deleted function
IntOrString x = {.s = "hello"};
^
<source>:5:17: note: destructor of 'IntOrString' is implicitly deleted because variant field 's' has a non-trivial destructor
std::string s;
clang is telling us that we have to implement the destructor for our union manually. Likewise, we would have to implement the rest of the Rule of Five (the copy/move constructors and copy/move assignment operators) too. The reason is that the compiler has no way to know which union member is active and thus which one it needs to destroy (or copy, or move).
The usual way to track the active member of a union is to add a “tag”, producing a tagged union:
#include <cassert>
#include <new>
#include <string>

class IntOrString {
  enum class Tag {
    Int,
    String
  };
  Tag tag;
  union {
    int x;
    std::string s;
  };
  void destroy() {
    if (tag == Tag::String) {
      // Explicit destructor call!
      // Recall that `std::string` is, roughly, a typedef for `std::basic_string<char>`.
      s.~basic_string();
    }
  }
 public:
  // Initialize the tag as well as the member it describes.
  IntOrString() : tag(Tag::Int), x(0) {}
  ~IntOrString() {
    destroy();
  }
  IntOrString(const IntOrString& rhs) : tag(rhs.tag) {
    if (tag == Tag::Int) {
      x = rhs.x;
    } else {
      // Placement new (https://en.cppreference.com/w/cpp/language/new#Placement_new) -- explicitly construct a string in `s`
      new (&s) std::string(rhs.s);
    }
  }
  IntOrString& operator=(const IntOrString& rhs) {
    if (tag == Tag::String) {
      if (rhs.tag == Tag::String) {
        s = rhs.s;
      } else {
        s.~basic_string();
        x = rhs.x;
      }
    } else {
      if (rhs.tag == Tag::String) {
        new (&s) std::string(rhs.s);
      } else {
        x = rhs.x;
      }
    }
    // The active member may have changed, so update the tag to match.
    tag = rhs.tag;
    return *this;
  }
  // Move ctor/assignment omitted, but very similar.
  bool isInt() const {
    return tag == Tag::Int;
  }
  int asInt() const {
    assert(isInt());
    return x;
  }
  void setInt(int newVal) {
    destroy();
    tag = Tag::Int;
    x = newVal;
  }
  /* and similarly for String... */
};
What a pain! We’ll see in the next section how to avoid all this boilerplate for typical tagged unions, and we’ll see in following sections why we might want to write it ourselves anyway.
std::variant: a type-safe tagged union
C++17 added std::variant, which makes the process of defining tagged unions easier and safer. Here’s our IntOrString example rewritten to use std::variant:
#include <string>
#include <variant>

class IntOrString {
  std::variant<int, std::string> repr_;
 public:
  IntOrString() : repr_(0) {}
  bool isInt() const {
    return std::holds_alternative<int>(repr_);
  }
  int asInt() const {
    // Throws std::bad_variant_access if repr_ doesn't hold an int.
    return std::get<int>(repr_);
  }
  void setInt(int newVal) {
    repr_ = newVal;
  }
  /* and similarly for String... */
};
Much easier! std::variant handles construction, copying, assignment, and destruction for us.
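Here is a self-contained version you can compile and experiment with. The String half (isString, asString, setString) is my own fill-in for illustration, following the same pattern as the int half. It also shows that std::get reports a wrong-type access by throwing std::bad_variant_access rather than silently reading the wrong member.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <variant>

class IntOrString {
  std::variant<int, std::string> repr_;

 public:
  IntOrString() : repr_(0) {}

  bool isInt() const { return std::holds_alternative<int>(repr_); }
  // Throws std::bad_variant_access if repr_ doesn't hold an int.
  int asInt() const { return std::get<int>(repr_); }
  void setInt(int newVal) { repr_ = newVal; }

  bool isString() const { return std::holds_alternative<std::string>(repr_); }
  // Throws std::bad_variant_access if repr_ doesn't hold a string.
  const std::string& asString() const { return std::get<std::string>(repr_); }
  void setString(std::string newVal) { repr_ = std::move(newVal); }
};
```

Copying one IntOrString into another needs no code from us: std::variant copies whichever alternative is active.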
PyTorch is still stuck on C++14 as of this writing, but we have c10::variant, which is very similar to std::variant.
Improving on tagged unions
You might reasonably ask why I am bothering to talk about unions at all when std::variant exists. The reason is that sometimes we don’t want to spend up to 8 bytes on a tag (the index std::variant stores is usually small, but alignment padding can round it up to as much as 8 bytes), and in those cases, we need to write unions by hand.
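To see where those bytes go, here is a small compile-time check. Num is a hypothetical alias, and the exact numbers depend on your standard library. In practice, std::variant stores its index next to the largest alternative, and because sizeof is always a multiple of the type's alignment, padding rounds the total up.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <variant>

// Hypothetical example type: either an integer or a floating-point number.
using Num = std::variant<std::int64_t, double>;

// Bytes spent on the tag plus padding beyond the 8-byte payload.
// On typical 64-bit platforms sizeof(Num) is 16, so this is 8.
constexpr std::size_t kTagOverhead = sizeof(Num) - sizeof(std::int64_t);

static_assert(sizeof(Num) > sizeof(std::int64_t), "the tag needs extra space");
```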
For example, let’s look at ProcessedNodeInputs, which is a custom “small array” for the PyTorch static runtime. Each ProcessedNode (which represents a PyTorch operator in the static runtime’s graph IR) has an array of 2-byte indices that refer to its inputs in a global “values” array. Operators often have no more than 5 inputs, so ProcessedNodeInputs uses a union to pack up to 5 indices, the array length, and a tag into 12 bytes (5 × 2 bytes of indices, plus 1 byte each for the length and the tag), while also supporting a heap-allocated array if there are more than 5 indices.
The core of ProcessedNodeInputs’s representation looks like this:
union Repr {
  struct InlineRepr {
    uint8_t tag = 0x1;
    uint8_t size;
    uint16_t inputs[kMaxInlineInputs];
  };
  // Wrapper for a pointer to a heap-allocated fixed-size array; details
  // omitted for brevity.
  using OutlineRepr = HeapArrayPtr;
  InlineRepr inline_repr_{};
  OutlineRepr outline_repr_;
};
Notice that our tag byte is inside the union. The heap array that HeapArrayPtr points to will be aligned to at least a 16-byte boundary, so the low bits of that pointer are 0. On a little-endian platform, the first byte of the pointer is its least significant byte, so the least significant bit of the memory corresponding to the tag byte will be 0 if outline_repr_ is active and 1 if inline_repr_ is active. However, we’re not allowed to read inline_repr_.tag to determine whether inline_repr_ is the active union member unless inline_repr_ actually is the active union member, so how can this possibly work?
The memcpy loophole
We can take advantage of several “escape hatches” in the C++ rules to inspect the tag byte anyway:
- We are allowed to reinterpret_cast (or, equivalently, static_cast to and from void*) between any two object pointer types. Dereferencing the result is restricted by the aliasing rules, which brings us to the next point.
- The aliasing rules are complicated, but we are specifically allowed to dereference a pointer that we type-cast to char* or unsigned char*.
- Compilers know about memcpy and will optimize small constant-size memcpy calls into single load instructions, just as though we had done a simple read of a variable.
Using the memcpy loophole
With that in mind, here is Repr::is_inline():
bool is_inline() const {
  uint8_t tag;
  std::memcpy(&tag, reinterpret_cast<const uint8_t*>(this), 1);
  return (tag & 1) != 0;
}
In short, we can use reinterpret_cast and memcpy to read the raw bytes of our union no matter which member is active.
The rest of ProcessedNodeInputs is mostly boilerplate of the type we saw before with our IntOrString example, just using is_inline() instead of tag checks. It would be pretty awesome if we were able to generalize this approach to provide a “super_variant” template that just needed a way to tap into our custom is_inline() implementation instead of creating its own tag byte, but I’m not aware of any such template yet.
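To tie the pieces together, here is a compilable sketch of a small array with its tag inside the union. It is not ProcessedNodeInputs: the names (SmallIndexArray, InlineRepr, Repr, kMaxInline) and the heap layout (length stored in element 0) are my own simplifications, copying is simply disabled, and it assumes a little-endian platform so that the first byte of the heap pointer is its least significant byte.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

constexpr std::size_t kMaxInline = 5;

struct InlineRepr {
  std::uint8_t tag = 0x1;  // low bit 1 means "inline"
  std::uint8_t size = 0;
  std::uint16_t inputs[kMaxInline] = {};
};
static_assert(sizeof(InlineRepr) == 12, "1 + 1 + 5 * 2 bytes");

union Repr {
  InlineRepr inline_{};  // active by default
  std::uint16_t* heap_;  // heap_[0] is the length, heap_[1..length] the indices
};

class SmallIndexArray {
 public:
  explicit SmallIndexArray(std::uint16_t n) {
    if (n > kMaxInline) {
      // new[] returns memory aligned for uint16_t, so the pointer's low bit is 0.
      std::uint16_t* heap = new std::uint16_t[n + 1]();
      heap[0] = n;
      repr_.heap_ = heap;  // assigning makes heap_ the active member
    } else {
      repr_.inline_.size = static_cast<std::uint8_t>(n);
    }
  }
  ~SmallIndexArray() {
    if (!is_inline()) {
      delete[] repr_.heap_;
    }
  }
  SmallIndexArray(const SmallIndexArray&) = delete;
  SmallIndexArray& operator=(const SmallIndexArray&) = delete;

  // The memcpy loophole: read the first byte whichever member is active.
  bool is_inline() const {
    std::uint8_t tag;
    std::memcpy(&tag, &repr_, 1);
    return (tag & 1) != 0;
  }
  std::size_t size() const {
    return is_inline() ? repr_.inline_.size : repr_.heap_[0];
  }
  std::uint16_t& operator[](std::size_t i) {
    assert(i < size());
    return is_inline() ? repr_.inline_.inputs[i] : repr_.heap_[i + 1];
  }

 private:
  Repr repr_;
};
```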
Expert mode: including types we don’t control in tag-less unions
The memcpy loophole is quite powerful. If we don’t mind writing non-portable code, we can use it on types we don’t even control by making reasonable assumptions about the way they work. For example, we could reasonably assume that std::shared_ptr is represented as a pair of pointers, and that either both of those pointers are null or both are not null. As a result, we can write a union that represents either a shared_ptr or a non-owning raw pointer (see #69579):
template <typename T>
class SingletonOrSharedTypePtr {
  union Repr {
    std::shared_ptr<T> shared_;
    struct {
      T* singleton_;
      void* unused_;
    };
  };
};
Again, we can’t read the shared_ part if singleton_ is in use or vice versa, so we add some more magic to our private Repr union:
union Repr {
  /* ... */
  // Note that this is a type definition, not a union member!
  struct RawRepr {
    void* first;
    void* nullIfSingleton_;
  };
  RawRepr rawRepr() const {
    RawRepr repr;
    memcpy(&repr, reinterpret_cast<const char *>(this), sizeof(RawRepr));
    return repr;
  }
};
Now we can use rawRepr() to read the contents of our union whenever we like:
bool isSharedAndNonNull() const {
  return rawRepr().nullIfSingleton_ != nullptr;
}
void destroy() {
  if (isSharedAndNonNull()) {
    shared_.~shared_ptr();
  }
}
and the rest of the class is straightforward boilerplate on top of this representation.
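For the adventurous, here is a compilable and deliberately non-portable sketch of the same idea. SingletonOrSharedPtr is a hypothetical name, not PyTorch’s SingletonOrSharedTypePtr, and copy/move support is omitted. It assumes std::shared_ptr<T> is exactly two pointers, element pointer first and control-block pointer second. As far as I know, the mainstream standard libraries lay it out this way, but the standard does not guarantee it, and the static_assert only checks the size half of the assumption. It uses a named SingletonRepr struct because anonymous structs inside unions are a compiler extension, and it stores an empty shared_ptr as non-owning so that "control block is null" really does mean "singleton".

```cpp
#include <cassert>
#include <cstring>
#include <memory>
#include <new>
#include <utility>

template <typename T>
class SingletonOrSharedPtr {
 public:
  // Non-owning: `singleton` must outlive this object.
  explicit SingletonOrSharedPtr(T* singleton) {
    new (&repr_.singleton_) SingletonRepr{singleton, nullptr};
  }
  explicit SingletonOrSharedPtr(std::shared_ptr<T> p) {
    if (p.use_count() == 0) {
      // An empty shared_ptr owns nothing; keep only its raw pointer.
      new (&repr_.singleton_) SingletonRepr{p.get(), nullptr};
    } else {
      new (&repr_.shared_) std::shared_ptr<T>(std::move(p));
    }
  }
  ~SingletonOrSharedPtr() {
    if (isSharedAndNonNull()) {
      std::destroy_at(&repr_.shared_);  // explicit destructor call
    }
  }
  SingletonOrSharedPtr(const SingletonOrSharedPtr&) = delete;
  SingletonOrSharedPtr& operator=(const SingletonOrSharedPtr&) = delete;

  bool isSharedAndNonNull() const {
    return rawRepr().nullIfSingleton != nullptr;
  }
  T* get() const {
    return isSharedAndNonNull() ? repr_.shared_.get() : repr_.singleton_.singleton;
  }

 private:
  struct SingletonRepr {
    T* singleton;
    void* unused;  // always nullptr; overlaps the control-block pointer
  };
  struct RawRepr {
    void* first;
    void* nullIfSingleton;
  };
  static_assert(sizeof(std::shared_ptr<T>) == sizeof(RawRepr),
                "assumes shared_ptr is two pointers");

  union Repr {
    Repr() {}   // no member is active until a constructor above creates one
    ~Repr() {}  // the enclosing class destroys the active member
    std::shared_ptr<T> shared_;
    SingletonRepr singleton_;
  } repr_;

  // The memcpy loophole: read the raw pointers whichever member is active.
  RawRepr rawRepr() const {
    RawRepr raw;
    std::memcpy(&raw, &repr_, sizeof(RawRepr));
    return raw;
  }
};
```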
Omitting destructor calls
Finally, I want to cover one more useful quirk: it is permissible to skip calling the destructor of an object if your program “does not rely on the side effects of the destructor”. We can use this to replace the destructor for some type with a more efficient version in a special case:
// If we create a Tensor and don't share it, we can destroy it without paying the cost of a reference count decrement.
// WARNING: This example is simplified; see [ExclusivelyOwned.h](https://github.com/pytorch/pytorch/blob/master/c10/util/ExclusivelyOwned.h) for the real implementation.
class ExclusivelyOwnedTensor {
 public:
  ExclusivelyOwnedTensor(at::Tensor t)
      : t_(std::move(t)) {
    // Check t_, not t: t has already been moved from.
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t_.use_count() == 1);
  }
  ~ExclusivelyOwnedTensor() {
    delete t_.unsafeReleaseTensorImpl();
    // No destructor call for t_!
  }
 private:
  // Could also use [std::aligned_storage](https://en.cppreference.com/w/cpp/types/aligned_storage) instead of a union.
  union {
    char dummy_;
    at::Tensor t_;
  };
};
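The same trick works without any PyTorch types. In this sketch, Obj, ObjPtr, ExclusivelyOwnedObj, and g_liveObjs are hypothetical stand-ins for TensorImpl, at::Tensor, and ExclusivelyOwned. It illustrates the technique and is not the real ExclusivelyOwned implementation.

```cpp
#include <atomic>
#include <cassert>
#include <utility>

// Hypothetical counter so we can observe when objects are destroyed.
int g_liveObjs = 0;

// Hypothetical intrusively reference-counted object (stand-in for TensorImpl).
struct Obj {
  std::atomic<int> refcount{1};
  Obj() { ++g_liveObjs; }
  ~Obj() { --g_liveObjs; }
};

// Hypothetical owning handle (stand-in for at::Tensor). Its destructor
// pays for an atomic decrement.
class ObjPtr {
 public:
  explicit ObjPtr(Obj* o) : o_(o) {}
  ObjPtr(ObjPtr&& other) noexcept : o_(std::exchange(other.o_, nullptr)) {}
  ObjPtr(const ObjPtr&) = delete;
  ObjPtr& operator=(const ObjPtr&) = delete;
  ~ObjPtr() {
    if (o_ != nullptr && o_->refcount.fetch_sub(1) == 1) {
      delete o_;
    }
  }
  int use_count() const { return o_ != nullptr ? o_->refcount.load() : 0; }
  // Give up ownership without touching the reference count.
  Obj* release() { return std::exchange(o_, nullptr); }

 private:
  Obj* o_;
};

class ExclusivelyOwnedObj {
 public:
  explicit ExclusivelyOwnedObj(ObjPtr p) : p_(std::move(p)) {
    assert(p_.use_count() == 1);  // we must be the only owner
  }
  ~ExclusivelyOwnedObj() {
    // Sole owner: delete the object directly instead of letting ~ObjPtr do an
    // atomic decrement. p_'s destructor is deliberately never run.
    delete p_.release();
  }
  ExclusivelyOwnedObj(const ExclusivelyOwnedObj&) = delete;
  ExclusivelyOwnedObj& operator=(const ExclusivelyOwnedObj&) = delete;

 private:
  union {
    char dummy_;
    ObjPtr p_;
  };
};
```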
Conclusion
Unions are a powerful, sharp tool for saving memory and, potentially, improving performance. Now you can add them to your toolbox and maybe improve on generic library classes like shared_ptr and SmallVector the next time you work on performance-critical code.