RNN cell

Previously, we described various language models where the conditional probability of token \(\boldsymbol{\mathsf{x}}_t\) depends on a fixed context \(\boldsymbol{\mathsf{x}}_{[t - \tau: t-1]}.\) To incorporate the possible effect of tokens earlier than this context, we need to increase the context size \(\tau\). For the n-gram model, this increases the number of parameters exponentially in \(\tau\). For an MLP over token embeddings, the number of parameters grows as \(O(\tau)\). Finally, using convolutions, this decreases to \(O(\log \tau).\) Alternatively, instead of modeling the next token directly in terms of previous tokens, we can use a latent variable that, in principle, stores all information up to the previous time step:

\[ p(\boldsymbol{\mathsf x}_{t} \mid \boldsymbol{\mathsf x}_{1}, \ldots, \boldsymbol{\mathsf x}_{t-1}) \approx p(\boldsymbol{\mathsf x}_{t} \mid \boldsymbol{\mathsf h}_{t-1}) \]

where \(\boldsymbol{\mathsf h}_{t-1}\) is a hidden state that stores information up to the time step \(t - 1.\) The hidden state is updated based on the current input and the previous state:

\[ \boldsymbol{\mathsf h}_{t} = f(\boldsymbol{\mathsf x}_{t}, \boldsymbol{\mathsf h}_{t-1}) \]

so that \(\boldsymbol{\mathsf h}_{t} = F(\boldsymbol{\mathsf x}_{1}, \ldots, \boldsymbol{\mathsf x}_{t}, \boldsymbol{\mathsf h}_{0})\) for some \(\boldsymbol{\mathsf h}_{0}\), where \(F\) recursively applies \(f\) (see Fig. 55). For a sufficiently complex function \(f\), the above latent variable model is not an approximation, since \(\boldsymbol{\mathsf h}_{t}\) can simply store all the inputs \(\boldsymbol{\mathsf x}_{1}, \ldots, \boldsymbol{\mathsf x}_{t}\) observed so far. In our case, we use fully-connected layers whose complexity can be tuned through their width.
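Since \(\boldsymbol{\mathsf h}_{t} = f(\boldsymbol{\mathsf x}_{t}, \boldsymbol{\mathsf h}_{t-1})\), the function \(F\) is a left fold of \(f\) over the inputs, starting from \(\boldsymbol{\mathsf h}_{0}\). The following is a minimal sketch, assuming a hypothetical tanh update `f` with random matrices `U`, `W` (no bias) and using Python's `functools.reduce` to perform the fold:

```python
from functools import reduce

import torch

torch.manual_seed(0)
d, h, T = 3, 4, 6

# Hypothetical parameters for an example update f (bias omitted for brevity)
U = torch.randn(d, h)
W = torch.randn(h, h)

def f(h_prev: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor:
    # reduce passes the accumulator first, so the arguments are (h_{t-1}, x_t)
    return torch.tanh(x_t @ U + h_prev @ W)

xs = [torch.randn(d) for _ in range(T)]   # x_1, ..., x_T
h0 = torch.zeros(h)                       # h_0

# F(x_1, ..., x_T, h_0) as a left fold of f over the sequence
h_T = reduce(f, xs, h0)

# The same computation written as an explicit loop
h_loop = h0
for x_t in xs:
    h_loop = f(h_loop, x_t)

print(torch.allclose(h_T, h_loop))  # True
```

The argument order of `f` is swapped relative to the equation only because `reduce` passes the accumulated state first.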



Fig. 55 RNN unit: (a) cyclic form, and (b) unrolled form (essentially a deep MLP with shared weights).

RNN cell. Let each token be represented by a vector \(\boldsymbol{\mathsf{x}}_t \in \mathbb{R}^{d}\) and let \(\boldsymbol{\mathsf{h}}_0 = \boldsymbol{0}.\) Then,

\[\begin{split} \begin{aligned} \boldsymbol{\mathsf{h}}_t &= \tanh(\boldsymbol{\mathsf{x}}_t \boldsymbol{\mathsf{U}} + \boldsymbol{\mathsf{h}}_{t-1} \boldsymbol{\mathsf{W}} + \boldsymbol{\mathsf{b}}) \\ \boldsymbol{\mathsf{y}}_t &= \boldsymbol{\mathsf{h}}_t \end{aligned} \end{split}\]

where \(\boldsymbol{\mathsf{U}} \in \mathbb{R}^{d \times h}\), \(\boldsymbol{\mathsf{W}} \in \mathbb{R}^{h \times h}\), and \(\boldsymbol{\mathsf{b}} \in \mathbb{R}^{h}.\) Here \(h\) is the dimensionality of the hidden state. For character-level inputs, \(\boldsymbol{\mathsf{x}}_t\) can be a one-hot vector of length \(|\mathcal{V}|\) so that \(\boldsymbol{\mathsf{U}}\) is \(|\mathcal{V}| \times h\), also acting as the embedding matrix for the tokens[1]. Finally, the state vector is also the output at each step, used by downstream layers of the network. The computation is illustrated in Fig. 56.
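To see why \(\boldsymbol{\mathsf{U}}\) acts as an embedding matrix for one-hot inputs, note that multiplying a one-hot row vector by \(\boldsymbol{\mathsf{U}}\) simply selects one row of \(\boldsymbol{\mathsf{U}}\). A small sketch, assuming a hypothetical vocabulary size, hidden size, and token indices:

```python
import torch
from torch.nn.functional import one_hot

torch.manual_seed(0)
vocab_size, h = 5, 3                  # hypothetical |V| and hidden size
U = torch.randn(vocab_size, h)        # input weights, shape (|V|, h)

tokens = torch.tensor([2, 0, 4])      # hypothetical token indices
x = one_hot(tokens, num_classes=vocab_size).float()  # shape (3, |V|)

# One-hot times U selects the corresponding rows of U, i.e. an embedding lookup
print(torch.allclose(x @ U, U[tokens]))  # True
```

In practice, the lookup `U[tokens]` avoids materializing the mostly-zero one-hot vectors.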



Fig. 56 Computational graph of an unrolled simple RNN. Source

Remark. RNNs use the same parameters at every time step, i.e. the dynamics are assumed to be stationary. In practice, this means that the parameter count does not grow with the sequence length, and that the parameters have no time index.
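The parameter count argument can be made concrete. The sketch below compares a hypothetical one-hidden-layer MLP on the concatenated context of \(\tau\) input vectors, whose count grows linearly with \(\tau\), against the simple RNN cell above, whose count depends only on \(d\) and \(h\). It also runs PyTorch's `nn.RNN` on sequences of different lengths with the same fixed parameter set:

```python
import torch
import torch.nn as nn

d, h = 30, 5  # input and hidden sizes

def rnn_param_count(d: int, h: int) -> int:
    # Simple RNN cell: U is d x h, W is h x h, b has h entries
    return d * h + h * h + h

def mlp_param_count(tau: int, d: int, h: int) -> int:
    # Hypothetical MLP layer on tau concatenated vectors: (tau * d) x h weight plus h bias
    return tau * d * h + h

for tau in [5, 50, 500]:
    print(tau, mlp_param_count(tau, d, h), rnn_param_count(d, h))

# One nn.RNN module handles any sequence length with the same parameters.
rnn = nn.RNN(d, h)
n_params = sum(p.numel() for p in rnn.parameters())  # 185: nn.RNN keeps two bias vectors
for T in [5, 500]:
    outs, _ = rnn(torch.randn(T, 1, d))
    print(T, tuple(outs.shape), n_params)
```

The count for `nn.RNN` is slightly larger than `rnn_param_count` because PyTorch uses separate input and hidden biases, which sum to a single effective bias \(\boldsymbol{\mathsf{b}}\).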

Code implementation

First, we implement the recurrent layer. For batch computation, an input \(\boldsymbol{\mathsf{X}}\) has shape \((T, B, d)\), i.e. a batch of \(B\) sequences of length \(T\), each consisting of vectors in \(\mathbb{R}^{d}.\) Elements of a batch are computed independently, ideally in parallel. For example, \(\boldsymbol{\mathsf{X}}_{[0, :, :]}\) consists of the batch of all vectors at \(t = 0\), while \(\boldsymbol{\mathsf{X}}_{[:, 0, :]}\) is a single sequence of vectors. At each step, the layer returns a state vector of shape \((B, h).\) These are stacked into a tensor of shape \((T, B, h)\), consistent with the input. The RNN cell computation can be written as:

outs = []
for t in range(T):
    h = torch.tanh(x[t] @ self.U + h @ self.W + self.b)
    outs.append(h)

Here the state h is updated at each step, and the output at each step is also set to h. The initialization of the state vector is not shown; when it is not specified, we typically set it to zero. This determines the methods required in the base class below. But first, let us define the expected “shape” of an RNN unit.

Base RNN

A recurrent unit is any function that iteratively updates a state based on new sequence input. Since it may have other layers for downstream processing at each time step, we also return an output tensor. We set the following guidelines:

  1. A recurrent unit must have (inputs_dim, hidden_dim, **kwargs) as arguments.

  2. Its forward signature is (x, state=None), where x is “sequence first”, i.e. \((T, B, d)\).

  3. Its forward return format is outs, state, where state has the format expected as input by the forward function and outs has shape \((T, B, h)\).
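The three guidelines can be checked mechanically. Below is a minimal sketch of a hypothetical helper `check_recurrent_unit`, applied to PyTorch's built-in `nn.RNN`, `nn.GRU`, and `nn.LSTM`, which happen to follow the same conventions by default (sequence-first input, `(output, state)` return). One assumption to note: PyTorch names the state argument `hx` rather than `state`, so the helper passes it positionally.

```python
import torch
import torch.nn as nn

def check_recurrent_unit(unit_cls, d: int = 8, h: int = 4, T: int = 6, B: int = 3) -> None:
    """Hypothetical helper: check the three guidelines on random input."""
    unit = unit_cls(d, h)                 # guideline 1: (inputs_dim, hidden_dim)
    x = torch.randn(T, B, d)              # guideline 2: sequence first, (T, B, d)
    outs, state = unit(x)                 # guideline 3: returns (outs, state)
    assert outs.shape == (T, B, h)
    # The returned state must be accepted as the state input of the next call.
    outs_next, _ = unit(x, state)
    assert outs_next.shape == (T, B, h)

for unit_cls in [nn.RNN, nn.GRU, nn.LSTM]:
    check_recurrent_unit(unit_cls)
    print(unit_cls.__name__, "ok")
```

Note that `nn.LSTM` returns its state as a tuple `(h_n, c_n)`, yet it still passes the check, since the guidelines only require the state to be fed back in the same format it was returned.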

The parameters (d, h) of an RNN can be read as transforming each sequence element from \(\mathbb{R}^d\) to \(\mathbb{R}^h.\) Moreover, with inputs of shape \((T, B, d)\), each time slice x[t] has shape \((B, d)\), which is already the input shape a linear layer expects. This can be interpreted as processing an entire batch at each time step, instead of processing each sequence in full. Hence, we choose this over the more intuitive “batch first” shape \((B, T, d)\).
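A short sketch of the shape conventions, assuming random data and a random input matrix `U`. It shows the slices discussed above, that matrix multiplication by `U` works on each (B, d) slice and also broadcasts over the leading T axis, and how a batch-first tensor is converted to sequence-first:

```python
import torch

torch.manual_seed(0)
T, B, d, h = 10, 32, 30, 5
U = torch.randn(d, h)

x = torch.randn(T, B, d)        # sequence first
print(x[0].shape)               # batch of all vectors at t = 0, shape (B, d)
print(x[:, 0].shape)            # one full sequence, shape (T, d)

# Each x[t] is a (B, d) matrix, the usual input of a linear layer.
proj_loop = torch.stack([x[t] @ U for t in range(T)])   # (T, B, h)
# matmul broadcasts over the leading T axis, so all steps can also be projected at once.
proj_all = x @ U                                         # (T, B, h)
print(torch.allclose(proj_all, proj_loop, atol=1e-5))

# Batch-first (B, T, d) to sequence-first (T, B, d): swap the first two axes.
x_batch_first = torch.randn(B, T, d)
x_seq_first = x_batch_first.transpose(0, 1)
print(x_seq_first.shape)
```

Projecting all steps at once is an optional optimization (the input term does not depend on the state); the RNN implemented below keeps the simpler per-step form.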

Finally, while the output must have shape \((T, B, h)\), the format of the state is more flexible. For example, it can be a tuple of tensors (h, c). As such, we can call the unit with either (x) or (x, state=(h0, c0)). The latter is useful for setting up a warmup state, or for continuing inference on another input sequence. The only constraint is that the state is formatted consistently throughout a specific implementation. For example, if state is the state returned by a previous call, then the unit can be called next with (x, state=state) without errors.
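Continuing inference with a returned state means that a long sequence can be processed in chunks. A minimal sketch using PyTorch's `nn.LSTM`, whose state is a tuple `(h, c)`, assuming random weights and inputs: processing the sequence in two chunks, with the state passed from the first call into the second, should match processing it in one call up to floating-point error.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, B, d, h = 10, 4, 8, 6
lstm = nn.LSTM(d, h)            # state is a tuple (h, c), each of shape (1, B, h)
x = torch.randn(T, B, d)

with torch.no_grad():
    # Whole sequence in one call
    outs_full, (h_full, c_full) = lstm(x)

    # Two chunks, feeding the returned state into the second call
    outs_a, state = lstm(x[:4])
    outs_b, (h_b, c_b) = lstm(x[4:], state)

outs_chunked = torch.cat([outs_a, outs_b])
print(torch.allclose(outs_full, outs_chunked, atol=1e-5))
print(torch.allclose(h_full, h_b, atol=1e-5), torch.allclose(c_full, c_b, atol=1e-5))
```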

import torch
import torch.nn as nn
import numpy as np
import random

RANDOM_SEED = 0
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
MPS = torch.backends.mps.is_available()
CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda:0") if CUDA else torch.device("mps") if MPS else torch.device("cpu")
class RNNBase(nn.Module):
    """Base class for recurrent units, e.g. RNN, LSTM, GRU, etc."""
    def __init__(self, inputs_dim: int, hidden_dim: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.inputs_dim = inputs_dim
        
    def init_state(self, x):
        raise NotImplementedError
    
    def compute(self, x, state):
        raise NotImplementedError

    def forward(self, x, state=None):
        state = self.init_state(x) if state is None else state
        outs, state = self.compute(x, state)
        return outs, state

Implementing the basic RNN cell:

class RNN(RNNBase):
    """Simple RNN unit."""
    def __init__(self, inputs_dim: int, hidden_dim: int):
        super().__init__(inputs_dim, hidden_dim)
        self.U = nn.Parameter(torch.randn(inputs_dim, hidden_dim) / np.sqrt(inputs_dim))
        self.W = nn.Parameter(torch.randn(hidden_dim, hidden_dim) / np.sqrt(hidden_dim))
        self.b = nn.Parameter(torch.zeros(hidden_dim))

    def init_state(self, x):
        B = x.shape[1]
        h = torch.zeros(B, self.hidden_dim, device=x.device)
        return h
    
    def compute(self, x, state):
        h = state
        T = x.shape[0]
        outs = []
        for t in range(T):
            h = torch.tanh(x[t] @ self.U + h @ self.W + self.b)
            outs.append(h)
        return torch.stack(outs), h

Remark. Our RNN does not store the state outside of the forward pass. To continue from a previous call, the returned state must be passed back in explicitly.

Shape test, i.e. \((T, B, d)\) to \((T, B, h)\) for a network with parameters \((d, h)\):

B, T, d, h = 32, 10, 30, 5
x = torch.randn(T, B, d)
rnn = RNN(d, h)
outs, state = rnn(x)
assert outs.shape == (T, B, h)
assert state.shape == (B, h)
assert torch.abs(outs[-1] - state).max() < 1e-8

Remark. The PyTorch nn.RNN module has a similar API, except that the returned state has an extra leading layer dimension, here \((1, B, h)\):

B, T, d, h = 32, 10, 30, 5
rnn_torch = nn.RNN(d, h)
outs, state = rnn_torch(x)
assert outs.shape == (T, B, h)
assert state.shape == (1, B, h)
assert torch.abs(outs[-1] - state).max() < 1e-8
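To compare the two modules, it helps to know how `nn.RNN` names and stores its parameters. PyTorch computes \(\boldsymbol{\mathsf{h}}_t = \tanh(\boldsymbol{\mathsf{x}}_t W_{ih}^\top + b_{ih} + \boldsymbol{\mathsf{h}}_{t-1} W_{hh}^\top + b_{hh})\), so its weights are the transposes of \(\boldsymbol{\mathsf{U}}\) and \(\boldsymbol{\mathsf{W}}\), and it keeps two bias vectors. A short sketch listing them:

```python
import torch.nn as nn

d, h = 30, 5
rnn_torch = nn.RNN(d, h)
for name, p in rnn_torch.named_parameters():
    print(name, tuple(p.shape))
# weight_ih_l0 (5, 30)  corresponds to U transposed
# weight_hh_l0 (5, 5)   corresponds to W transposed
# bias_ih_l0 (5,)       bias_ih_l0 + bias_hh_l0 corresponds to b
# bias_hh_l0 (5,)
```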

Correctness check: set all weights to one and all biases to zero in both modules, then compare their outputs:

for name, p in rnn.named_parameters():
    if name == "b":
        p.data.fill_(0.0)
    else:
        p.data.fill_(1.0)

for name, p in rnn_torch.named_parameters():
    if "bias" in name:
        p.data.fill_(0.0)
    else:
        p.data.fill_(1.0)

error = torch.abs(rnn(x)[0] - rnn_torch(x)[0]).max()
print(error)
assert error < 1e-6
tensor(0., grad_fn=<MaxBackward1>)
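With all-ones weights, the pre-activations are sums of many terms, so many outputs sit near the saturated values \(\pm 1\) of tanh, and the check above exercises the arithmetic only weakly. A stronger sketch copies PyTorch's random weights into the update instead. It assumes a standalone function `simple_rnn` (hypothetical, mirroring the loop in `RNN.compute` above) and maps \(\boldsymbol{\mathsf{U}} = W_{ih}^\top\), \(\boldsymbol{\mathsf{W}} = W_{hh}^\top\), \(\boldsymbol{\mathsf{b}} = b_{ih} + b_{hh}\):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, B, d, h = 10, 32, 30, 5
x = torch.randn(T, B, d)
rnn_torch = nn.RNN(d, h)   # randomly initialized weights

def simple_rnn(x, U, W, b):
    # Hypothetical standalone version of the RNN cell loop, with zero initial state
    state = torch.zeros(x.shape[1], U.shape[1])
    outs = []
    for t in range(x.shape[0]):
        state = torch.tanh(x[t] @ U + state @ W + b)
        outs.append(state)
    return torch.stack(outs), state

# PyTorch stores weights transposed and splits the bias into two vectors
U = rnn_torch.weight_ih_l0.detach().T
W = rnn_torch.weight_hh_l0.detach().T
b = (rnn_torch.bias_ih_l0 + rnn_torch.bias_hh_l0).detach()

with torch.no_grad():
    outs_ref, _ = rnn_torch(x)
outs, _ = simple_rnn(x, U, W, b)
error = (outs - outs_ref).abs().max()
print(error)
```

The error is expected to be small but not necessarily exactly zero, since the two implementations may sum terms in a different order.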