Per sample gradients using function transforms not working for RNN

This is the issue I raised in https://github.com/pytorch/tutorials/issues/2566. svekars suggested that I post it here.


Hello, I’m working on a optimization algorithm that requires computing the per sample gradients. Assuming the batch size is $N$ and the number of model parameters is $M$, I want to calculate $\partial \log p(x^{(i)};\theta)/\partial \theta_j$, which is an $N \times M$ matrix. I found the PER-SAMPLE-GRADIENTS tutorial and began my own experiments. As a proof of concept, I defined a generative model with a tractable likelihood, such as MADE (Masked Autoencoder for Distribution Estimation), PixelCNN, RNN, etc., and sepcified the log_prob and sample methods. I utilized the function transforms methods mentioned in the tutorial, but currently, it only works for MADE (I believed it works for NADE and PixelCNN too, since these models need only one forward pass to calculate the log likelihood). Below, I’ve provided my code snippets, and I’m interested in figuring out why it’s not working for RNN. Making it work for RNN would significantly reduce the number of parameters for my research purpose. Thank you!

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


class MADE(nn.Module):
    '''A simple one-layer MADE (Masked Autoencoder for Distribution Estimation)'''

    def __init__(self, n=10, device='cpu', *args, **kwargs):
        super().__init__()
        self.n = n
        self.device = device

        self.weight = nn.Parameter(torch.randn(self.n, self.n) / math.sqrt(self.n))
        self.bias = nn.Parameter(torch.zeros(self.n))
        mask = torch.tril(torch.ones(self.n, self.n), diagonal=-1)
        self.register_buffer('mask', mask)

    def pred_logits(self, x):
        return F.linear(x, self.mask * self.weight, self.bias)

    def forward(self, x):
        logits = self.pred_logits(x)
        log_probs = - F.binary_cross_entropy_with_logits(logits, x, reduction='none')
        return log_probs.sum(-1)

    @torch.no_grad()
    def sample(self, batch_size):
        x = torch.zeros(batch_size, self.n, dtype=torch.float, device=self.device)
        for i in range(self.n):
            logits = self.pred_logits(x)[:, i]
            x[:, i] = torch.bernoulli(torch.sigmoid(logits))
        return x


class GRUModel(nn.Module):
    '''GRU for density estimation'''

    def __init__(self, n=10, input_size=2, hidden_size=8, device='cpu'):
        super().__init__()
        self.n = n
        self.input_size = input_size  # input_size=2 when x is binary
        self.hidden_size = hidden_size
        self.device = device
        self.gru_cell = nn.GRUCell(self.input_size, self.hidden_size)
        self.fc_layer = nn.Linear(self.hidden_size, 1)

    def pred_logits(self, x, h=None):
        x = torch.stack([x, 1 - x], dim=1)  # 1 -> (1, 0), 0 -> (0, 1), (batch_size, 2)
        h_next = self.gru_cell(x, h)  # h_{i+1}
        logits = self.fc_layer(h_next).squeeze(1)
        return h_next, logits

    def forward(self, x):
        log_prob_list = []
        x = torch.cat([torch.zeros(x.shape[0], 1, dtype=torch.float, device=self.device), x], dim=1)  # cat x_0
        h = torch.zeros(x.shape[0], self.hidden_size, dtype=torch.float, device=self.device)  # h_0
        for i in range(self.n):
            h, logits = self.pred_logits(x[:, i], h)
            log_prob = - F.binary_cross_entropy_with_logits(logits, x[:, i + 1], reduction='none')
            log_prob_list.append(log_prob)
        return torch.stack(log_prob_list, dim=1).sum(dim=1)

    @torch.no_grad()
    def sample(self, batch_size):
        x = torch.zeros(batch_size, self.n + 1, dtype=torch.float, device=self.device)
        for i in range(self.n):
            h, logits = self.pred_logits(x[:, i], h=None if i == 0 else h)
            x[:, i + 1] = torch.bernoulli(torch.sigmoid(logits))
        return x[:, 1:]


if __name__ == '__main__':
    model = MADE()
    # model = GRUModel()

    # Sample from the generative model
    samples = model.sample(128)

    # Then I use the function transforms methods mentioned in the tutorial
    # to calculate the per sample mean
    from torch.func import functional_call, grad, vmap
    params = {k: v.detach() for k, v in model.named_parameters()}

    def loss_fn(log_probs):
        return log_probs.mean(0)

    def compute_loss(params, sample):
        batch = sample.unsqueeze(0)
        log_prob = functional_call(model, (params,), (batch,))
        loss = loss_fn(log_prob)
        return loss

    ft_compute_grad = grad(compute_loss)
    ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, 0))
    ft_per_sample_grads = ft_compute_sample_grad(params, samples)

    print(ft_per_sample_grads)

The above code works for MADE (I also check the values of gradients, they are correct!)
However, when I use model = GRUModel(), an error arises:

Traceback (most recent call last):
  File "per_sample_grads.py", line 100, in <module>
    ft_per_sample_grads = ft_compute_sample_grad(params, samples)
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 434, in wrapped
    return _flat_vmap(
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 39, in fn
    return f(*args, **kwargs)
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 619, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/eager_transforms.py", line 1380, in wrapper
    results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 39, in fn
    return f(*args, **kwargs)
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/eager_transforms.py", line 1245, in wrapper
    output = func(*args, **kwargs)
  File "per_sample_grads.py", line 94, in compute_loss
    log_prob = functional_call(model, (params,), (batch,))
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/functional_call.py", line 143, in functional_call
    return nn.utils.stateless._functional_call(
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/nn/utils/stateless.py", line 262, in _functional_call
    return module(*args, **kwargs)
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "per_sample_grads.py", line 63, in forward
    h, logits = self.pred_logits(x[:, i], h)
  File "per_sample_grads.py", line 54, in pred_logits
    h_next = self.gru_cell(x, h)  # h_{i+1}
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 1327, in forward
    ret = _VF.gru_cell(
RuntimeError: output with shape [1, 8] doesn't match the broadcast shape [128, 1, 8]