# Per sample gradients using function transforms not working for RNN

This is the issue I raised in https://github.com/pytorch/tutorials/issues/2566. svekars suggested that I post it here.

Hello, I’m working on an optimization algorithm that requires computing per-sample gradients. Assuming the batch size is $N$ and the number of model parameters is $M$, I want to calculate $\partial \log p(x^{(i)};\theta)/\partial \theta_j$ for every sample $i$ and parameter $j$, which gives an $N \times M$ matrix. I found the PER-SAMPLE-GRADIENTS tutorial and began my own experiments. As a proof of concept, I defined a generative model with a tractable likelihood, such as MADE (Masked Autoencoder for Distribution Estimation), PixelCNN, RNN, etc., and specified its log_prob and sample methods. I used the function-transform methods mentioned in the tutorial, but currently they only work for MADE (I believe they would work for NADE and PixelCNN too, since these models need only one forward pass to calculate the log-likelihood). Below, I’ve provided my code snippets, and I’m interested in figuring out why this does not work for the RNN. Making it work for the RNN would significantly reduce the number of parameters needed for my research. Thank you!

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# Seed torch's global RNG so the random weight initialisation (torch.randn)
# and the Bernoulli sampling further down are reproducible from run to run.
torch.manual_seed(0)

class MADE(nn.Module):
    """Single-layer MADE (Masked Autoencoder for Distribution Estimation).

    Models a distribution over binary vectors of length ``n``. A masked
    linear layer produces one Bernoulli logit per position, so the full
    log-likelihood of a batch is obtained in a single forward pass.

    NOTE(review): the class header and the definition of ``self.mask`` were
    absent from the posted snippet. The class name follows the surrounding
    text; the strictly lower-triangular mask is inferred from ``sample()``,
    which fills position ``i`` from ``logits[:, i]`` in increasing order and
    therefore needs logit ``i`` to depend only on positions ``< i``.
    """

    def __init__(self, n=10, device='cpu', *args, **kwargs):
        """Create the masked linear layer.

        Args:
            n: Length of the binary vectors being modelled.
            device: Device on which ``sample()`` allocates its output.
            *args, **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.n = n
        self.device = device

        # Scale by 1/sqrt(n) so the initial logits have roughly unit variance.
        self.weight = nn.Parameter(torch.randn(self.n, self.n) / math.sqrt(self.n))
        self.bias = nn.Parameter(torch.zeros(self.n))

        # mask[i, j] == 1 only for j < i. Registered as a buffer (not a
        # parameter) so it follows the module across devices but is excluded
        # from named_parameters() and receives no gradient.
        self.register_buffer(
            'mask', torch.tril(torch.ones(self.n, self.n), diagonal=-1)
        )

    def pred_logits(self, x):
        """Return the Bernoulli logits, shape ``(batch, n)``, for input ``x``."""
        return F.linear(x, self.mask * self.weight, self.bias)

    def forward(self, x):
        """Return the log-likelihood of each row of ``x``, shape ``(batch,)``."""
        logits = self.pred_logits(x)
        # Negative BCE-with-logits is the Bernoulli log-probability per position.
        log_probs = -F.binary_cross_entropy_with_logits(logits, x, reduction='none')
        return log_probs.sum(-1)

    def sample(self, batch_size):
        """Draw ``batch_size`` samples one position at a time.

        Returns:
            A ``(batch_size, n)`` float tensor of zeros and ones.
        """
        x = torch.zeros(batch_size, self.n, dtype=torch.float, device=self.device)
        for i in range(self.n):
            # Only column i of the logits is used; it depends on x[:, :i],
            # which has already been sampled.
            logits = self.pred_logits(x)[:, i]
            x[:, i] = torch.bernoulli(torch.sigmoid(logits))
        return x

class GRUModel(nn.Module):
    """GRU-based autoregressive density estimator for binary vectors.

    Each position is predicted from the previous position's value and the
    recurrent hidden state, so evaluating the log-likelihood takes ``n``
    sequential GRU-cell steps.
    """

    def __init__(self, n=10, input_size=2, hidden_size=8, device='cpu'):
        """Build the GRU cell and the output layer.

        Args:
            n: Length of the binary vectors being modelled.
            input_size: Width of the per-step input (2 when x is binary,
                because each bit is expanded to the pair ``(x, 1 - x)``).
            hidden_size: Size of the GRU hidden state.
            device: Device on which ``sample()`` allocates its output.
        """
        super().__init__()
        self.n = n
        self.input_size = input_size  # input_size=2 when x is binary
        self.hidden_size = hidden_size
        self.device = device
        self.gru_cell = nn.GRUCell(self.input_size, self.hidden_size)
        self.fc_layer = nn.Linear(self.hidden_size, 1)

    def pred_logits(self, x, h=None):
        """Advance the GRU by one step.

        Args:
            x: Previous position's values, shape ``(batch,)``.
            h: Current hidden state, or ``None`` to let the cell start from
                its default initial state.

        Returns:
            ``(h_next, logits)``: the new hidden state and the Bernoulli
            logits for the next position, the latter of shape ``(batch,)``.
        """
        x = torch.stack([x, 1 - x], dim=1)  # 1 -> (1, 0), 0 -> (0, 1), (batch_size, 2)
        h_next = self.gru_cell(x, h)  # h_{i+1}
        logits = self.fc_layer(h_next).squeeze(1)
        return h_next, logits

    def forward(self, x):
        """Return the log-likelihood of each row of ``x``, shape ``(batch,)``.

        The per-step Bernoulli log-probabilities are summed over the ``n``
        positions. The original snippet collected them but never returned,
        so the model produced ``None``.
        """
        # Allocate on the input's device rather than the stored device
        # string, which goes stale if the module is moved after construction.
        start = torch.zeros(x.shape[0], 1, dtype=torch.float, device=x.device)
        x = torch.cat([start, x], dim=1)  # prepend the dummy input x_0
        h = torch.zeros(x.shape[0], self.hidden_size, dtype=torch.float, device=x.device)  # h_0

        log_prob_list = []
        for i in range(self.n):
            h, logits = self.pred_logits(x[:, i], h)
            # Column i + 1 of the padded input is the target for step i.
            log_prob = -F.binary_cross_entropy_with_logits(logits, x[:, i + 1], reduction='none')
            log_prob_list.append(log_prob)
        return torch.stack(log_prob_list, dim=1).sum(dim=1)

    def sample(self, batch_size):
        """Draw ``batch_size`` samples one position at a time.

        Returns:
            A ``(batch_size, n)`` float tensor of zeros and ones.
        """
        # Column 0 is the dummy start input and is dropped on return.
        x = torch.zeros(batch_size, self.n + 1, dtype=torch.float, device=self.device)
        h = None  # the first step uses the cell's default initial state
        for i in range(self.n):
            h, logits = self.pred_logits(x[:, i], h=h)
            x[:, i + 1] = torch.bernoulli(torch.sigmoid(logits))
        return x[:, 1:]

if __name__ == '__main__':
    # NOTE(review): the posted snippet never bound `model` and never applied
    # the imported `grad` / `vmap` transforms. Both are restored here from the
    # surrounding description and the pasted traceback; confirm against the
    # author's original script.
    model = MADE()
    # model = GRUModel()  # the alternative model discussed in this report

    # Sample from the generative model
    samples = model.sample(128)

    # Then I use the function-transform methods mentioned in the tutorial
    # to calculate the per-sample gradients
    from torch.func import functional_call, grad, vmap

    # Detached copies, so the transforms differentiate with respect to this
    # dict rather than the module's own parameter tensors.
    params = {k: v.detach() for k, v in model.named_parameters()}

    def loss_fn(log_probs):
        """Average the log-probabilities over the batch dimension."""
        return log_probs.mean(0)

    def compute_loss(params, sample):
        """Return the loss of a single sample under the given parameters."""
        # vmap strips the batch dimension, so re-add a batch of size one
        # before calling the model.
        batch = sample.unsqueeze(0)
        log_prob = functional_call(model, (params,), (batch,))
        loss = loss_fn(log_prob)
        return loss

    # grad differentiates with respect to the first argument (params);
    # vmap then maps over dim 0 of the samples while sharing params.
    ft_compute_grad = grad(compute_loss)
    ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, 0))
    ft_per_sample_grads = ft_compute_sample_grad(params, samples)

The above code works for MADE (I also checked the values of the gradients, and they are correct!).
However, when I use model = GRUModel() instead, the following error is raised:

Traceback (most recent call last):
File "per_sample_grads.py", line 100, in <module>
File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 434, in wrapped
return _flat_vmap(
File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 39, in fn
return f(*args, **kwargs)
File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 619, in _flat_vmap
batched_outputs = func(*batched_inputs, **kwargs)
File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/eager_transforms.py", line 1380, in wrapper
results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 39, in fn
return f(*args, **kwargs)
File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/eager_transforms.py", line 1245, in wrapper
output = func(*args, **kwargs)
File "per_sample_grads.py", line 94, in compute_loss
log_prob = functional_call(model, (params,), (batch,))
File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/functional_call.py", line 143, in functional_call
return nn.utils.stateless._functional_call(
File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/nn/utils/stateless.py", line 262, in _functional_call
return module(*args, **kwargs)
File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "per_sample_grads.py", line 63, in forward
h, logits = self.pred_logits(x[:, i], h)
File "per_sample_grads.py", line 54, in pred_logits
h_next = self.gru_cell(x, h)  # h_{i+1}
File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 1327, in forward
ret = _VF.gru_cell(
RuntimeError: output with shape [1, 8] doesn't match the broadcast shape [128, 1, 8]