RNN cell
Previously, we described various language models where the conditional probability of token \(\boldsymbol{\mathsf{x}}_t\) depends on a fixed context \(\boldsymbol{\mathsf{x}}_{[t - \tau: t-1]}.\) If we want to incorporate the possible effect of tokens earlier than the given context, we need to increase the context size \(\tau\). For the n-gram model, this would increase the number of parameters exponentially in \(\tau\). Using embeddings, the number of parameters of the MLP network grows as \(O(\tau)\). Finally, using convolutions this decreases to \(O(\log \tau).\) Alternatively, instead of modeling the next token directly in terms of previous tokens, we can use a latent variable that, in principle, stores all previous information up to the previous time step:

\[
p(\boldsymbol{\mathsf{x}}_t \mid \boldsymbol{\mathsf{x}}_{[1: t-1]}) \approx p(\boldsymbol{\mathsf{x}}_t \mid \boldsymbol{\mathsf{h}}_{t-1})
\]
where \(\boldsymbol{\mathsf h}_{t-1}\) is a hidden state that stores information up to the time step \(t - 1.\) The hidden state is updated based on the current input and the previous state:

\[
\boldsymbol{\mathsf{h}}_{t} = f(\boldsymbol{\mathsf{x}}_{t}, \boldsymbol{\mathsf{h}}_{t-1})
\]
so that \(\boldsymbol{\mathsf h}_{t} = F(\boldsymbol{\mathsf x}_{1}, \ldots, \boldsymbol{\mathsf x}_{t}, \boldsymbol{\mathsf h}_{0})\) for some \(\boldsymbol{\mathsf h}_{0},\) where \(F\) involves recursively applying \(f\) (see Fig. 55). For a sufficiently complex function \(f\), the above latent variable model is not an approximation, since \(\boldsymbol{\mathsf h}_{t}\) can simply store all \(\boldsymbol{\mathsf x}_{1}, \ldots, \boldsymbol{\mathsf x}_{t}\) it has observed so far. In our case, we use fully-connected layers whose complexity can be tuned via their width.
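For concreteness, unrolling the recurrence for three steps (a worked expansion using the notation above) gives

\[
\boldsymbol{\mathsf{h}}_{3} = f(\boldsymbol{\mathsf{x}}_{3}, \boldsymbol{\mathsf{h}}_{2}) = f(\boldsymbol{\mathsf{x}}_{3}, f(\boldsymbol{\mathsf{x}}_{2}, f(\boldsymbol{\mathsf{x}}_{1}, \boldsymbol{\mathsf{h}}_{0}))) = F(\boldsymbol{\mathsf{x}}_{1}, \boldsymbol{\mathsf{x}}_{2}, \boldsymbol{\mathsf{x}}_{3}, \boldsymbol{\mathsf{h}}_{0}).
\]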
Fig. 55 RNN unit (a) cyclic, and (b) unrolled RNN (essentially a deep MLP with shared weights).
RNN cell. Let each token be represented by a vector \(\boldsymbol{\mathsf{x}}_t \in \mathbb{R}^{d}\) and let \(\boldsymbol{\mathsf{h}}_0 = \boldsymbol{0}.\) Then,

\[
\boldsymbol{\mathsf{h}}_{t} = \tanh\left(\boldsymbol{\mathsf{x}}_{t} \boldsymbol{\mathsf{U}} + \boldsymbol{\mathsf{h}}_{t-1} \boldsymbol{\mathsf{W}} + \boldsymbol{\mathsf{b}}\right)
\]
where \(\boldsymbol{\mathsf{U}} \in \mathbb{R}^{d \times h}\), \(\boldsymbol{\mathsf{W}} \in \mathbb{R}^{h \times h}\), and \(\boldsymbol{\mathsf{b}} \in \mathbb{R}^{h}.\) Here \(h\) is the dimensionality of the hidden state. For character-level inputs, \(\boldsymbol{\mathsf{x}}_t\) can be a one-hot vector of length \(|\mathcal{V}|\) so that \(\boldsymbol{\mathsf{U}}\) is \(|\mathcal{V}| \times h\), also acting as the embedding matrix for the tokens[1]. Finally, the state vector is also the output at each step, used by downstream layers of the network. The computation is illustrated in Fig. 56.
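To make the one-hot remark concrete, here is a small sketch (the vocabulary size and token index are made up for illustration) showing that multiplying a one-hot vector with \(\boldsymbol{\mathsf{U}}\) simply selects a row of \(\boldsymbol{\mathsf{U}}\), i.e. an embedding lookup:

```python
import torch

V, h = 5, 3                # toy vocabulary size and hidden size (illustrative)
U = torch.randn(V, h)      # weight matrix, doubling as an embedding matrix

x_t = torch.zeros(V)       # one-hot encoding of the token with index 2
x_t[2] = 1.0

# The matrix product with a one-hot vector picks out the corresponding row of U.
assert torch.allclose(x_t @ U, U[2])
```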
Fig. 56 Computational graph of an unrolled simple RNN. Source
Remark. RNNs use the same parameters at each time step, i.e. it is assumed that the dynamics is stationary. Practically, this means that the parameter count does not grow as the sequence length increases, and that the parameters have no time index.
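As a quick sanity check of this remark (a sketch using PyTorch's built-in nn.RNN with arbitrary sizes), the same fixed-size parameter set processes sequences of any length:

```python
import torch
import torch.nn as nn

d, h = 30, 5                    # arbitrary input and hidden sizes
rnn = nn.RNN(d, h)              # sequence-first input by default

# The same parameters are reused at every time step, for any sequence length.
outs_short, _ = rnn(torch.randn(5, 1, d))
outs_long, _ = rnn(torch.randn(500, 1, d))

# 185 = d*h + h*h + 2*h (nn.RNN keeps two bias vectors); independent of T.
print(sum(p.numel() for p in rnn.parameters()))
```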
Code implementation
First, we implement the recurrent layer. To support batch computation, an input \(\boldsymbol{\mathsf{X}}\) has shape \((T, B, d).\) That is, a batch of \(B\) sequences of length \(T\), consisting of vectors in \(\mathbb{R}^{d}.\) Elements of a batch are computed independently, ideally in parallel. For example, \(\boldsymbol{\mathsf{X}}_{[0, :, :]}\) consists of a batch of all vectors at \(t = 0.\) Similarly, \(\boldsymbol{\mathsf{X}}_{[:, 0, :]}\) is one instance of a sequence of vectors. At each step, the layer returns the state vector of shape \((B, h).\) These are stacked to get a tensor of shape \((T, B, h)\) consistent with the input. The RNN cell computation can be written as:
outs = []
for t in range(T):
    h = torch.tanh(x[t] @ self.U + h @ self.W + self.b)
    outs.append(h)
Here the state `h` is updated at each step, and the output vector at each step is also set to `h`. The initialization of the state vector is not shown, but we typically set it to zero when not specified.
This leads us to the required methods in the base class below. But first, let us define the expected “shape” of an RNN unit.
Base RNN
A recurrent unit is any function that iteratively updates a state based on new sequence input. It may have other layers for downstream processing at each time step, so we also return an output tensor. We set the following guidelines:
- A recurrent unit must have `(inputs_dim, hidden_dim, **kwargs)` as constructor arguments.
- Its forward signature is `(x, state=None)`, where `x` is “sequence first”, i.e. \((T, B, d)\).
- Its forward return format is `outs, state`, where `state` has the same format expected as input to the forward function, and `outs` has shape \((T, B, h)\).
The parameters `(d, h)` of an RNN can be read off as transforming each sequence element from \(\mathbb{R}^d\) to \(\mathbb{R}^h.\) Next, an implementation with expected inputs \((T, B, d)\) is already linear-layer friendly. This can be interpreted as processing an entire batch at each time step, instead of entire sequences per batch. Hence, we choose this over the more intuitive “batch first” shape \((B, T, d)\).
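A short sketch of why the sequence-first layout is convenient (shapes are illustrative): indexing the first axis gives the whole batch at one time step, already in the shape a linear layer expects, and a batch-first tensor can be converted by swapping the first two axes:

```python
import torch
import torch.nn as nn

T, B, d, h = 10, 32, 30, 5
x = torch.randn(T, B, d)                      # sequence-first layout

linear = nn.Linear(d, h)
out_t = linear(x[0])                          # x[0] is the whole batch at t = 0, shape (B, d)
assert out_t.shape == (B, h)

x_batch_first = torch.randn(B, T, d)          # the more intuitive layout
x_seq_first = x_batch_first.transpose(0, 1)   # swap axes to get (T, B, d)
assert x_seq_first.shape == (T, B, d)
```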
Finally, while the output has to be of shape \((T, B, h)\), the `state` is more arbitrary. For example, it can be a tuple of tensors `(h, c)`. As such, we can call the unit with either `(x)` or `(x, state=(h0, c0))`. The latter is useful for setting up a warmup state, or for continuing inference with another input sequence. The only constraint is that the states are consistently formatted in all parts of a specific implementation. For example, if `state` is the current output state, then the unit can be called next with `(x, state=state)` without errors.
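As an illustration of the warmup pattern (a sketch using PyTorch's nn.RNN, which follows the same `outs, state` convention; sizes are arbitrary), splitting a sequence and threading the state through gives the same outputs as a single pass:

```python
import torch
import torch.nn as nn

T, B, d, h = 10, 4, 30, 5
rnn = nn.RNN(d, h)                 # sequence-first input by default
x = torch.randn(T, B, d)

outs_full, _ = rnn(x)              # one pass over the whole sequence

outs_a, state = rnn(x[:6])         # first chunk; its final state serves as warmup
outs_b, _ = rnn(x[6:], state)      # continue inference from that state

assert torch.allclose(torch.cat([outs_a, outs_b]), outs_full)
```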
import torch
import torch.nn as nn
import numpy as np
import random
RANDOM_SEED = 0
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
MPS = torch.backends.mps.is_available()
CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda:0") if CUDA else torch.device("mps") if MPS else torch.device("cpu")
class RNNBase(nn.Module):
    """Base class for recurrent units, e.g. RNN, LSTM, GRU, etc."""
    def __init__(self, inputs_dim: int, hidden_dim: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.inputs_dim = inputs_dim

    def init_state(self, x):
        # Construct the default initial state (e.g. zeros) from the input batch.
        raise NotImplementedError

    def compute(self, x, state):
        # Run the recurrence over x given an initial state; return (outs, state).
        raise NotImplementedError

    def forward(self, x, state=None):
        state = self.init_state(x) if state is None else state
        outs, state = self.compute(x, state)
        return outs, state
Implementing the basic RNN cell:
class RNN(RNNBase):
    """Simple RNN unit."""
    def __init__(self, inputs_dim: int, hidden_dim: int):
        super().__init__(inputs_dim, hidden_dim)
        self.U = nn.Parameter(torch.randn(inputs_dim, hidden_dim) / np.sqrt(inputs_dim))
        self.W = nn.Parameter(torch.randn(hidden_dim, hidden_dim) / np.sqrt(hidden_dim))
        self.b = nn.Parameter(torch.zeros(hidden_dim))

    def init_state(self, x):
        # Zero initial state for each sequence in the batch: (B, h).
        B = x.shape[1]
        h = torch.zeros(B, self.hidden_dim, device=x.device)
        return h

    def compute(self, x, state):
        h = state
        T = x.shape[0]
        outs = []
        for t in range(T):
            # Update the state from the current input and the previous state.
            h = torch.tanh(x[t] @ self.U + h @ self.W + self.b)
            outs.append(h)
        # Stack per-step outputs into (T, B, h); also return the final state.
        return torch.stack(outs), h
Remark. It’s important to note that our RNN does not store state outside of the forward pass.
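A quick check of this remark: calling the unit twice on the same input gives identical outputs, since a fresh zero state is created inside each forward call unless one is passed explicitly (the sizes below are arbitrary):

```python
x_check = torch.randn(10, 4, 30)   # (T, B, d)
rnn_check = RNN(30, 5)

out1, _ = rnn_check(x_check)
out2, _ = rnn_check(x_check)
assert torch.equal(out1, out2)     # no hidden state carried over between calls
```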
Shapes test, i.e. \((T, B, d)\) to \((T, B, h)\) for a network with params \((d, h)\):
B, T, d, h = 32, 10, 30, 5
x = torch.randn(T, B, d)
rnn = RNN(d, h)
outs, state = rnn(x)
assert outs.shape == (T, B, h)
assert state.shape == (B, h)
assert torch.abs(outs[-1] - state).max() < 1e-8
Remark. The PyTorch RNN module has a similar API:
B, T, d, h = 32, 10, 30, 5
rnn_torch = nn.RNN(d, h)
outs, state = rnn_torch(x)
assert outs.shape == (T, B, h)
assert state.shape == (1, B, h)
assert torch.abs(outs[-1] - state).max() < 1e-8
Correctness:
# Set every weight to 1 and every bias to 0 in both implementations,
# so that they compute exactly the same function.
for name, p in rnn.named_parameters():
    if name == "b":
        p.data.fill_(0.0)
    else:
        p.data.fill_(1.0)

for name, p in rnn_torch.named_parameters():
    if "bias" in name:
        p.data.fill_(0.0)
    else:
        p.data.fill_(1.0)

error = torch.abs(rnn(x)[0] - rnn_torch(x)[0]).max()
print(error)
assert error < 1e-6
tensor(0., grad_fn=<MaxBackward1>)