How to trace torch.autograd.backward or torch.autograd.grad?

Thanks, shuokay! The default_partition function is just what I need.
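For anyone following along, here is a minimal sketch of how default_partition is usually plugged into AOTAutograd. It assumes PyTorch 2.x, where aot_function and default_partition are importable from functorch.compile. AOTAutograd traces the joint forward/backward graph, default_partition splits it into two, and each half is handed as an FX GraphModule to fw_compiler and bw_compiler respectively. The names fn, capture_forward, capture_backward, forward_graphs and backward_graphs are hypothetical, chosen only for illustration.

```python
import torch
from functorch.compile import aot_function, default_partition

# Hypothetical lists that collect the FX graphs AOTAutograd hands to each compiler.
forward_graphs = []
backward_graphs = []

def capture_forward(gm: torch.fx.GraphModule, example_inputs):
    # Receives the forward half produced by the partitioner.
    forward_graphs.append(gm)
    return gm  # run the traced graph unchanged

def capture_backward(gm: torch.fx.GraphModule, example_inputs):
    # Receives the backward half produced by the partitioner.
    backward_graphs.append(gm)
    return gm

def fn(x, y):
    return (x.sin() * y).sum()

compiled_fn = aot_function(
    fn,
    fw_compiler=capture_forward,
    bw_compiler=capture_backward,
    partition_fn=default_partition,  # also the default, passed explicitly here
)

x = torch.randn(4, requires_grad=True)
y = torch.randn(4, requires_grad=True)
out = compiled_fn(x, y)
out.backward()  # the backward graph may only be compiled at this point

print(backward_graphs[0].code)  # Python source of the traced backward graph
```

Note that the backward graph is only inspected after out.backward(), since some PyTorch versions compile the backward half lazily. This still traces a joint graph and partitions it; it does not trace the backward pass on its own.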

Btw, do you have any idea how to trace only the backward graph (instead of tracing a joint graph and then partitioning it)?