Unionizing for Profit: How to Exploit the Power of Unions in C++

I’ve written a lot of unions for PyTorch this year. I’ve learned a few things along the way, and I thought a post highlighting this useful-but-little-used C++ construct might be in order. (If you know how to use unions, you may still be interested in the more advanced content toward the middle and end of the post!)

What is a union?

A union is like a struct, but instead of holding ALL of its members, it holds ANY single one of them. The main reason to use a union is to save on memory, which might also improve performance. Here’s how you might define one:

union MyUnion {
  int x;
  void* p;
};

Here, MyUnion can hold either an int or a void*, but not both. Its size is the size of its largest member (here, p on typical 64-bit platforms), rounded up to a multiple of its alignment if necessary.
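The memory savings are easiest to see by comparing the union against a struct with the same members. In the sketch below, MyStruct and roundTrip are illustrative names. The static_asserts compare sizes instead of hard-coding byte counts, because those counts depend on the platform. On a typical 64-bit target, MyUnion is 8 bytes and MyStruct is 16.

```cpp
#include <cassert>

union MyUnion {
  int x;
  void* p;
};

// For comparison: the same two members stored side by side.
struct MyStruct {
  int x;
  void* p;
};

// A union must have room for each of its members...
static_assert(sizeof(MyUnion) >= sizeof(int), "room for x");
static_assert(sizeof(MyUnion) >= sizeof(void*), "room for p");
// ...but, unlike a struct, it never needs room for all of them at once.
static_assert(sizeof(MyUnion) < sizeof(MyStruct), "members overlap");

// Writing a member makes it the active member; reading it back is fine.
int roundTrip(int v) {
  MyUnion u;
  u.x = v;
  return u.x;
}
```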

C++ unions are wildly unsafe: unlike in C, it is undefined behavior to read from a member of a union that isn’t the one that was most recently written, also known as the “active member”. In other words, this cute code is not allowed (though your compiler may let you do it):

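// Undefined behavior in C++: p is the active member, but we read x.
// (Designated initializers like {.p = p} are standard only since C++20.)
int ptrToInt(void* p) {
  MyUnion u = {.p = p};
  return u.x;
}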
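If you really want the integer value of a pointer, you don't need a union. The supported tool is reinterpret_cast to std::uintptr_t. A minimal sketch follows; intToPtr is a helper added here for illustration. Returning std::uintptr_t instead of int also avoids truncating 64-bit pointers.

```cpp
#include <cassert>
#include <cstdint>

// Well-defined pointer-to-integer conversion: no union needed.
// std::uintptr_t can round-trip any void*. It is technically optional,
// but every mainstream platform provides it.
std::uintptr_t ptrToInt(void* p) {
  return reinterpret_cast<std::uintptr_t>(p);
}

// Converting the integer back yields the original pointer.
void* intToPtr(std::uintptr_t i) {
  return reinterpret_cast<void*>(i);
}
```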

You can put class types into unions, but for types with nontrivial special member functions, such as std::string, it requires a lot of work. For example, this code:

union IntOrString {
  int x;
  std::string s;
};

void f() {
  IntOrString x = {.s = "hello"};
}

gives us the following discouraging compiler error:

<source>:9:17: error: attempt to use a deleted function
    IntOrString x = {.s = "hello"};
                ^
<source>:5:17: note: destructor of 'IntOrString' is implicitly deleted because variant field 's' has a non-trivial destructor
    std::string s;

clang is telling us that we have to implement the destructor for our union manually. Likewise, we would have to implement the other Rule of Five operators (copy/move constructor and assignment) too. The reason is that the compiler doesn’t have any way to know which union member is active and thus which one it needs to destroy (or copy, or move).
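Concretely, the least we can do to make a std::string union member compile is to provide a constructor and a destructor that deliberately leave s alone. That pushes all of the bookkeeping onto the code that uses the union. This is a sketch, and IntOrStringStorage and lengthViaUnion are hypothetical names:

```cpp
#include <cassert>
#include <cstddef>
#include <new>
#include <string>

// The union provides empty special members; code that uses it is
// responsible for knowing (and switching) the active member.
union IntOrStringStorage {
  int x;
  std::string s;

  IntOrStringStorage() : x(0) {}  // start with x active
  ~IntOrStringStorage() {}        // does NOT destroy s; the caller must
};

std::size_t lengthViaUnion(const char* text) {
  IntOrStringStorage u;           // x is active
  new (&u.s) std::string(text);   // placement new: s is now active
  std::size_t len = u.s.size();
  u.s.~basic_string();            // explicit destructor call, or we leak
  u.x = 0;                        // x is active again
  return len;
}
```

Nothing stops a caller from forgetting the explicit destructor call, or from reading s while x is active. That is exactly the problem a tag solves.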

The usual way to track the active member of a union is to add a “tag”, producing a tagged union:

#include <cassert>
#include <new>
#include <string>

class IntOrString {
  enum class Tag {
    Int,
    String
  };

  Tag tag;
  union {
    int x;
    std::string s;
  };

  void destroy() {
    if (tag == Tag::String) {
      // Explicit destructor call!
      // Recall that `std::string` is, roughly, a typedef for `std::basic_string<char>`.
      s.~basic_string();
    }
  }
 public:
  // The tag must be initialized too, or destroy() reads garbage.
  IntOrString() : tag(Tag::Int), x(0) {}

  ~IntOrString() {
    destroy();
  }

  IntOrString(const IntOrString& rhs) : tag(rhs.tag) {
    if (tag == Tag::Int) {
      x = rhs.x;
    } else {
      // Placement new (https://en.cppreference.com/w/cpp/language/new#Placement_new) -- explicitly construct a string in `s`
      new (&s) std::string(rhs.s);
    }
  }

  IntOrString& operator=(const IntOrString& rhs) {
    if (tag == Tag::String) {
      if (rhs.tag == Tag::String) {
        s = rhs.s;
      } else {
        s.~basic_string();
        x = rhs.x;
      }
    } else {
      if (rhs.tag == Tag::String) {
        new (&s) std::string(rhs.s);
      } else {
        x = rhs.x;
      }
    }
    // Keep the tag in sync with whichever member is now active.
    tag = rhs.tag;
    return *this;
  }

  // Move ctor/assignment omitted, but very similar.

  bool isInt() const {
    return tag == Tag::Int;
  }
  int asInt() const {
    assert(isInt());
    return x;
  }
  void setInt(int newVal) {
    destroy();
    x = newVal;
    tag = Tag::Int;
  }
  /* and similarly for String... */
};

What a pain! We’ll see in the next section how to avoid all this boilerplate for typical tagged unions, and we’ll see in following sections why we might want to write it ourselves anyway.

std::variant: a type-safe tagged union

C++17 added std::variant, which makes the process of defining tagged unions easier and safer. Here’s our IntOrString example rewritten to use std::variant:

#include <string>
#include <variant>

class IntOrString {
  std::variant<int, std::string> repr_;
 public:
  IntOrString() : repr_(0) {}

  bool isInt() const {
    return std::holds_alternative<int>(repr_);
  }
  int asInt() const {
    // Throws std::bad_variant_access if a string is active.
    return std::get<int>(repr_);
  }
  void setInt(int newVal) {
    repr_ = newVal;
  }
  /* and similarly for String... */
};

Much easier! std::variant handles construction, copying, assignment, and destruction for us.
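std::variant (C++17) gives us a couple more things for free. std::visit dispatches on the active alternative, and std::get throws std::bad_variant_access on a mismatch instead of silently reading the wrong member. The sketch below shows both; describe and wrongAccessThrows are illustrative helpers, not part of any library.

```cpp
#include <cassert>
#include <string>
#include <type_traits>
#include <variant>

using IntOrString = std::variant<int, std::string>;

// std::visit calls the lambda with whichever alternative is active.
std::string describe(const IntOrString& v) {
  return std::visit(
      [](const auto& value) -> std::string {
        using T = std::decay_t<decltype(value)>;
        if constexpr (std::is_same_v<T, int>) {
          return "int " + std::to_string(value);
        } else {
          return "string " + value;
        }
      },
      v);
}

// Asking for the wrong alternative throws instead of invoking undefined behavior.
bool wrongAccessThrows(const IntOrString& v) {
  try {
    (void)std::get<int>(v);
    return false;
  } catch (const std::bad_variant_access&) {
    return true;
  }
}
```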

PyTorch is still stuck on C++14 as of this writing, but we have c10::variant, which is very similar to std::variant.

Improving on tagged unions

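You might reasonably ask why I am bothering to talk about unions at all when std::variant exists. The reason is that std::variant stores its tag next to its payload. Because of alignment padding, that tag effectively costs as much as the largest alignment among the alternatives: 8 bytes when one alternative is a pointer on a 64-bit platform. Sometimes we can't afford that, and in those cases we need to write unions by hand.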
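You can measure the tag's cost by comparing a variant's size with the size of its payload. The sketch below asserts only what must hold everywhere. The 16-byte figure in the comment is what mainstream 64-bit standard libraries produce, not a guarantee, and the constant names are made up for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <variant>

// The variant stores an index next to its payload. Mainstream standard
// libraries use a one-byte index here, but the whole object is padded out
// to the payload's alignment, so on a typical 64-bit platform this is 16 bytes.
constexpr std::size_t kVariantSize = sizeof(std::variant<int, void*>);
constexpr std::size_t kPayloadSize = sizeof(void*);

// Bytes spent on the tag, including alignment padding.
constexpr std::size_t kTagOverhead = kVariantSize - kPayloadSize;

static_assert(kVariantSize > kPayloadSize, "the tag is not free");
static_assert(kVariantSize % alignof(void*) == 0,
              "size is a multiple of the payload's alignment");
```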

For example, let’s look at ProcessedNodeInputs, which is a custom “small array” for the PyTorch static runtime. Each ProcessedNode (which represents a PyTorch operator in the static runtime’s graph IR) has an array of 2-byte indices that refer to its inputs in a global “values” array. Operators often have no more than 5 inputs. ProcessedNodeInputs therefore uses a union to pack a 1-byte tag, a 1-byte array length, and up to five 2-byte indices into 12 bytes. It also supports a heap-allocated array when there are more than 5 indices.

The core of ProcessedNodeInputs's representation looks like this:

union Repr {
  struct InlineRepr {
    uint8_t tag = 0x1;
    uint8_t size;
    uint16_t inputs[kMaxInlineInputs];
  };

  // Wrapper for a pointer to a heap-allocated fixed-size array; details
  // omitted for brevity.
  using OutlineRepr = HeapArrayPtr;

  InlineRepr inline_repr_{};
  OutlineRepr outline_repr_;
};

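Notice that our tag byte is inside the union. The pointer held by HeapArrayPtr points to heap memory, which is aligned to at least a 2-byte boundary (heap allocations are typically 16-byte aligned on 64-bit platforms), so its least significant bit is always 0. On a little-endian platform, that bit lives in the first byte of the pointer, which is the same byte as tag. As a result, the least significant bit of that byte is 0 if outline_repr_ is active and 1 if inline_repr_ is active. However, we’re not allowed to read inline_repr_.tag to determine whether inline_repr_ is the active union member unless inline_repr_ actually is the active union member, so how can this possibly work?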
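The alignment half of this argument is easy to check directly. The sketch below assumes the usual flat mapping from pointers to std::uintptr_t, and addressLowBitClear is a made-up helper. Heap memory is suitably aligned for any fundamental type, so bit 0 of its address is always 0 and is free to serve as a tag. Which byte of the pointer holds that bit is a separate question of endianness.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>

// True if bit 0 of the pointer's address is clear, i.e. the address is even.
// Assumes the usual flat mapping from pointers to std::uintptr_t.
bool addressLowBitClear(const void* p) {
  return (reinterpret_cast<std::uintptr_t>(p) & 1u) == 0;
}

// malloc and operator new return memory aligned for any fundamental type,
// which on mainstream platforms is well above 2 bytes.
static_assert(alignof(std::max_align_t) >= 2,
              "heap allocations leave bit 0 free");
```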

The memcpy loophole

We can take advantage of several “escape hatches” in the C++ rules to inspect the tag byte anyway:

  1. We are allowed to reinterpret_cast (or, equivalently, static_cast to and from void*) between any two object pointer types.
  2. The aliasing rules are complicated, but we are specifically allowed to dereference a pointer that we type-cast to char * or unsigned char *.
  3. Compilers know about memcpy and will optimize small constant-size memcpy calls into single load instructions, just as though we had done a simple read of a variable.
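The three escape hatches appear side by side in the small sketch below; the function names are illustrative. Both readers return the same byte. Which byte that is depends on endianness, which is one reason tricks like the tag byte in ProcessedNodeInputs are tied to a particular byte order.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Escape hatches 1 and 2: cast to const unsigned char* and read through it.
// The aliasing rules allow reading any object's bytes via unsigned char.
unsigned char firstByteViaCast(const std::uint32_t& v) {
  return *reinterpret_cast<const unsigned char*>(&v);
}

// Escape hatch 3: memcpy the byte out. Compilers typically turn a 1-byte,
// constant-size memcpy into a single load.
unsigned char firstByteViaMemcpy(const std::uint32_t& v) {
  unsigned char b;
  std::memcpy(&b, &v, 1);
  return b;
}

// The first byte of 0x01020304 is 0x04 on little-endian platforms
// (x86-64, typical AArch64) and 0x01 on big-endian ones.
bool isLittleEndian() {
  const std::uint32_t probe = 0x01020304u;
  return firstByteViaMemcpy(probe) == 0x04;
}
```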

Using the memcpy loophole

With that in mind, here is Repr::is_inline():

bool is_inline() const {
  uint8_t tag;
  std::memcpy(&tag, reinterpret_cast<const uint8_t*>(this), 1);
  return (tag & 1) != 0;
}

In short, by combining reinterpret_cast and memcpy, we can read the raw bytes of an object regardless of which union member is active.

The rest of ProcessedNodeInputs is mostly boilerplate of the type we saw before with our IntOrString example, just using is_inline() instead of tag checks. It would be pretty awesome if we were able to generalize this approach to provide a “super_variant” template that just needed a way to tap into our custom is_inline() implementation instead of creating its own tag byte, but I’m not aware of any such template yet.
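Here is a much-simplified, self-contained sketch that puts these pieces together. It is not the real ProcessedNodeInputs. SmallIndexArray, its heap layout (size stored in element 0), and its API are all hypothetical. It assumes a little-endian platform and heap allocations that are at least 2-byte aligned, and it disables copying to stay short.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <initializer_list>

// Hypothetical small array whose tag lives inside a union.
// Assumptions: little-endian platform; heap pointers are at least 2-byte aligned.
class SmallIndexArray {
  static constexpr std::size_t kMaxInline = 5;

  union Repr {
    struct InlineRepr {
      std::uint8_t tag;  // low bit 1 => inline_ is active
      std::uint8_t size;
      std::uint16_t inputs[kMaxInline];
    } inline_;
    // Heap layout made up for this sketch: outline_[0] is the size,
    // outline_[1..size] are the indices.
    std::uint16_t* outline_;
  } repr_;

  // Read the first byte regardless of which member is active.
  bool is_inline() const {
    std::uint8_t tag;
    std::memcpy(&tag, reinterpret_cast<const unsigned char*>(&repr_), 1);
    return (tag & 1) != 0;
  }

 public:
  explicit SmallIndexArray(std::initializer_list<std::uint16_t> xs) {
    std::size_t i = 0;
    if (xs.size() <= kMaxInline) {
      repr_.inline_.tag = 1;
      repr_.inline_.size = static_cast<std::uint8_t>(xs.size());
      for (std::uint16_t x : xs) repr_.inline_.inputs[i++] = x;
    } else {
      repr_.outline_ = new std::uint16_t[xs.size() + 1];
      repr_.outline_[0] = static_cast<std::uint16_t>(xs.size());
      for (std::uint16_t x : xs) repr_.outline_[++i] = x;
    }
  }
  SmallIndexArray(const SmallIndexArray&) = delete;
  SmallIndexArray& operator=(const SmallIndexArray&) = delete;
  ~SmallIndexArray() {
    if (!is_inline()) delete[] repr_.outline_;
  }

  bool inlined() const { return is_inline(); }
  std::size_t size() const {
    return is_inline() ? repr_.inline_.size : repr_.outline_[0];
  }
  std::uint16_t operator[](std::size_t i) const {
    return is_inline() ? repr_.inline_.inputs[i] : repr_.outline_[i + 1];
  }
};
```

Every accessor branches on is_inline() instead of a separate tag field, which is exactly the boilerplate a generic "super_variant" would need to hide.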

Expert mode: including types we don’t control in tag-less unions

The memcpy loophole is quite powerful. If we don’t mind writing non-portable code, we can use it on types we don’t even control by making reasonable assumptions about the way they work. For example, we could reasonably assume that std::shared_ptr is represented as a pair of pointers, and that either both of those pointers are null or both are not null. As a result, we can write a union that represents either a shared_ptr or a non-owning raw pointer (see #69579):

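#include <memory>

template <typename T>
class SingletonOrSharedTypePtr {
  union Repr {
    std::shared_ptr<T> shared_;
    // Anonymous structs are a widely supported compiler extension, not standard C++.
    struct {
      T* singleton_;
      void* unused_;
    };
  };
};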

Again, we can’t read the shared_ part if singleton_ is in use or vice versa, so we add some more magic to our private Repr union:

union Repr {
  /* ... */
  
  // Note that this is a type definition, not a union member!
  struct RawRepr {
    void* first;
    void* nullIfSingleton_;
  };

  RawRepr rawRepr() const {
    RawRepr repr;
    memcpy(&repr, reinterpret_cast<const char *>(this), sizeof(RawRepr));
    return repr;
  }
};

Now we can use rawRepr() to read the contents of our union whenever we like:

bool isSharedAndNonNull() const {
  return rawRepr().nullIfSingleton_ != nullptr;
}

void destroy() {
  if (isSharedAndNonNull()) {
    shared_.~shared_ptr();
  }
}

and the rest of the class is straightforward boilerplate on top of this representation.
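For a complete picture, here is a hypothetical, simplified sketch in the spirit of SingletonOrSharedTypePtr; it is not the PyTorch code. It depends on non-portable assumptions. The first is that std::shared_ptr<T> is laid out as {stored pointer, control-block pointer}, which holds for libstdc++, libc++, and MSVC's STL. The second is that a non-null shared_ptr has a non-null control-block pointer. A static_assert guards the size assumption, and copy and move are disabled for brevity.

```cpp
#include <cassert>
#include <cstring>
#include <memory>
#include <new>
#include <utility>

template <typename T>
class SingletonOrShared {
  using SharedPtr = std::shared_ptr<T>;

  union Repr {
    SharedPtr shared_;
    struct SingletonRepr {
      T* ptr;
      void* null;  // always nullptr; overlaps the control-block pointer
    } singleton_;

    Repr() : singleton_{nullptr, nullptr} {}
    ~Repr() {}  // the enclosing class decides what to destroy
  } repr_;

  struct RawRepr {
    void* first;
    void* nullIfSingleton;
  };
  static_assert(sizeof(SharedPtr) == sizeof(RawRepr),
                "std::shared_ptr does not look like two pointers");

  // memcpy loophole: read the raw words whichever member is active.
  RawRepr rawRepr() const {
    RawRepr raw;
    std::memcpy(&raw, reinterpret_cast<const unsigned char*>(&repr_), sizeof(raw));
    return raw;
  }

 public:
  explicit SingletonOrShared(T* singleton) {
    repr_.singleton_ = {singleton, nullptr};
  }
  explicit SingletonOrShared(SharedPtr p) {
    new (&repr_.shared_) SharedPtr(std::move(p));
  }
  SingletonOrShared(const SingletonOrShared&) = delete;
  SingletonOrShared& operator=(const SingletonOrShared&) = delete;
  ~SingletonOrShared() {
    if (isSharedAndNonNull()) {
      repr_.shared_.~SharedPtr();
    }
  }

  bool isSharedAndNonNull() const {
    return rawRepr().nullIfSingleton != nullptr;
  }
  // Both representations keep the object pointer in the first word.
  T* get() const { return static_cast<T*>(rawRepr().first); }
};
```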

Omitting destructor calls

Finally, I want to cover one more useful quirk. The standard ([basic.life]) permits skipping the destructor call for an object, as long as the program does not depend on the side effects produced by that destructor. We can use this to replace the destructor for some type with a more efficient version in a special case:

// If we create a Tensor and don't share it, we can destroy it without paying the cost of a reference count decrement.
// WARNING: This example is simplified; see [ExclusivelyOwned.h](https://github.com/pytorch/pytorch/blob/master/c10/util/ExclusivelyOwned.h) for the real implementation.
class ExclusivelyOwnedTensor {
 public:
  ExclusivelyOwnedTensor(at::Tensor t)
    : t_(std::move(t)) {
    // Check the member, not the moved-from parameter `t`.
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t_.use_count() == 1);
  }

  ~ExclusivelyOwnedTensor() {
    // Take ownership of the TensorImpl and delete it directly, skipping the
    // refcount decrement that ~Tensor() would perform.
    delete t_.unsafeReleaseTensorImpl();
    // No destructor call for t_!
  }
 private:
  // Could also use [std::aligned_storage](https://en.cppreference.com/w/cpp/types/aligned_storage) instead of a union.
  union {
    char dummy_;
    at::Tensor t_;
  };
};
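The same trick works for any type, not just Tensor. This sketch wraps a value in an anonymous union so that its destructor never runs automatically; NoAutoDestroy is a hypothetical name. The owner remains responsible for either releasing resources some other way, as ExclusivelyOwnedTensor does, or being sure that nothing depends on the destructor. Otherwise this simply leaks.

```cpp
#include <cassert>
#include <utility>

// Holds a T in an anonymous union so that ~T() is NOT called automatically.
// Only sound if the program does not depend on ~T()'s side effects.
template <typename T>
class NoAutoDestroy {
 public:
  template <typename... Args>
  explicit NoAutoDestroy(Args&&... args) : value_(std::forward<Args>(args)...) {}
  ~NoAutoDestroy() {}  // intentionally does not call value_.~T()

  NoAutoDestroy(const NoAutoDestroy&) = delete;
  NoAutoDestroy& operator=(const NoAutoDestroy&) = delete;

  T& get() { return value_; }

 private:
  union {
    T value_;
  };
};
```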

Conclusion

Unions are a powerful, sharp tool for saving memory and, potentially, improving performance. Now you can add them to your toolbox and maybe improve on generic library classes like shared_ptr and SmallVector the next time you work on performance-critical code.


You need to be careful with the memcpy loophole. Relying on padding is playing with fire because different platforms have different ABIs and memory layouts.

Saving memory is usually great for performance, so I appreciate your work. But be careful! There are so many places in PyTorch where we could save data & code size… A lot of the generated code is quite bloated, for example.

  1. Do we support any platforms with ABIs other than ILP32 and LP64?
  2. I don’t think the memcpy loophole relies on padding, just that std::shared_ptr has the same representation as a struct containing two pointers to void. If some weird platform puts padding between the two pointers, I think that the example should still work. I probably should have noted that the real code has a static_assert to make sure we’re not building for a platform where things don’t seem to line up correctly: see PR #69579, “[PyTorch][JIT] Don't refcount Type singletons” by swolchok (pytorch/pytorch on GitHub).