Gated Recurrent Unit (GRU)

Recall that the hidden state in the LSTM cell is essentially the internal cell state filtered by the output gate \(\boldsymbol{\Gamma}^o_t\), which is computed from the current input and the previous hidden state. Modifications such as peephole connections have been proposed to include the unfiltered internal cell state in the gate computations. A simpler alternative is the GRU [CvMG+14], which merges the cell state and the hidden state into a single hidden state vector that is appropriately gated. With three gated weight blocks instead of the LSTM's four, it has about 25% fewer parameters while achieving comparable performance on many tasks.
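The parameter saving can be checked with PyTorch's built-in layers. A single-layer `nn.GRU` stores three stacked weight blocks (reset, update, candidate) where `nn.LSTM` stores four, so for the same input and hidden widths the GRU has 3/4 of the parameters. The widths below are illustrative, and `n_params` is a small helper introduced here.

```python
import torch
from torch import nn

d, h = 10, 20  # input width and hidden width (illustrative values)

gru = nn.GRU(input_size=d, hidden_size=h)
lstm = nn.LSTM(input_size=d, hidden_size=h)

def n_params(module: nn.Module) -> int:
    """Total number of trainable scalars in a module."""
    return sum(p.numel() for p in module.parameters())

# GRU: 3h(d + h) weights + 2 * 3h biases; LSTM: same with 4h
print(n_params(gru), n_params(lstm))   # 1920 2560
print(n_params(gru) / n_params(lstm))  # 0.75
```

PyTorch keeps two bias vectors per gate (input-side and hidden-side), which is why each block contributes `2h` bias terms.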

Gating mechanism

Similar to an LSTM cell, the GRU uses gates to control how its memory is updated and reset over time:

Gate      Symbol                       Controls
Reset     \(\boldsymbol{\Gamma}^r\)    How much of the previous state is used to compute the candidate state
Update    \(\boldsymbol{\Gamma}^z\)    How much of the previous state is retained, versus the candidate state

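Compared with the LSTM cell, the update gate plays the role of both the forget gate (it weights the previous state by \(\boldsymbol{\Gamma}^z_t\)) and the input gate (it weights the candidate state by \(1 - \boldsymbol{\Gamma}^z_t\)), while the reset gate controls how much of the previous state enters the candidate state. There is no output gate: the new state is an elementwise weighted average of the previous state and the candidate state.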

GRU equations

The GRU computation is defined as follows. Let \(h\) be the width of the GRU, \(d\) the input dimension, and \(B\) the batch size, so that \(\boldsymbol{\mathsf{X}}_t\) has shape \(B \times d\). Every entry of the gates lies in \((0, 1) \subset \mathbb{R}\), so each gate can be interpreted as a \(B \times h\) smooth switch:

\[\begin{split} \begin{aligned} {\boldsymbol{\Gamma}^r_t} & =\sigma(\boldsymbol{\mathsf{X}}_t \boldsymbol{\mathsf{U}}^{{r}}+\boldsymbol{\mathsf{H}}_{t-1} \boldsymbol{\mathsf{W}}^{{r}}+\boldsymbol{\mathsf{b}}^{{r}}) \\ {\boldsymbol{\Gamma}^z_t} & =\sigma(\boldsymbol{\mathsf{X}}_t \boldsymbol{\mathsf{U}}^{{z}}+\boldsymbol{\mathsf{H}}_{t-1} \boldsymbol{\mathsf{W}}^{{z}}+\boldsymbol{\mathsf{b}}^{{z}}) \end{aligned} \end{split}\]
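The gate equations translate directly into matrix code. This sketch uses small random matrices named after the symbols above (`U_r`, `W_r`, `b_r`, and so on, all illustrative) to confirm that each gate is a \(B \times h\) array with entries strictly between 0 and 1.

```python
import torch

torch.manual_seed(0)
B, d, h = 4, 3, 5  # batch size, input width, GRU width (illustrative values)

X_t = torch.randn(B, d)     # current input, shape B x d
H_prev = torch.randn(B, h)  # previous hidden state, shape B x h

# One (U, W, b) triple per gate, named as in the equations; small random weights
U_r, W_r, b_r = 0.5 * torch.randn(d, h), 0.5 * torch.randn(h, h), torch.zeros(h)
U_z, W_z, b_z = 0.5 * torch.randn(d, h), 0.5 * torch.randn(h, h), torch.zeros(h)

G_r = torch.sigmoid(X_t @ U_r + H_prev @ W_r + b_r)  # reset gate
G_z = torch.sigmoid(X_t @ U_z + H_prev @ W_z + b_z)  # update gate

print(G_r.shape, G_z.shape)  # torch.Size([4, 5]) torch.Size([4, 5])
```

Both gates read the same inputs and differ only in their parameters; that is why the implementation later in this section can compute them with two `nn.Linear` layers over the concatenation of \(\boldsymbol{\mathsf{X}}_t\) and \(\boldsymbol{\mathsf{H}}_{t-1}\).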

The hidden state \(\boldsymbol{\mathsf{H}}_t\) with shape \(B \times h\) is updated as

\[\begin{split} \begin{aligned} \tilde{\boldsymbol{\mathsf{H}}}_t &= \tanh(\boldsymbol{\mathsf{X}}_t \boldsymbol{\mathsf{U}}^{{h}}+(\boldsymbol{\Gamma}^r_t \odot \boldsymbol{\mathsf{H}}_{t-1}) \boldsymbol{\mathsf{W}}^{{h}}+\boldsymbol{\mathsf{b}}^{{h}}) \\ \boldsymbol{\mathsf{H}}_t &= {\boldsymbol{\Gamma}^z_t} \odot \boldsymbol{\mathsf{H}}_{t - 1} + (1 - {\boldsymbol{\Gamma}^z_t}) \odot \tilde{\boldsymbol{\mathsf{H}}}_t \end{aligned} \end{split}\]

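where \(\tilde{\boldsymbol{\mathsf{H}}}_t\) is called the candidate hidden state. Applying \(\tanh\) ensures that the elements of \(\tilde{\boldsymbol{\mathsf{H}}}_t\) lie in \((-1, 1)\). Since \(\boldsymbol{\mathsf{H}}_t\) is an elementwise weighted average of \(\boldsymbol{\mathsf{H}}_{t-1}\) and \(\tilde{\boldsymbol{\mathsf{H}}}_t\), its elements also stay in \((-1, 1)\) when the initial state is zero. The computation is illustrated in Fig. 62: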
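Putting the gate and state equations together gives one full GRU step. The sketch below is a functional version with illustrative random weights stored in dictionaries keyed by `"r"`, `"z"`, `"h"` (a naming choice made here, not taken from the implementation later on). It runs ten steps from a zero state and checks the \((-1, 1)\) bound argued above.

```python
import torch

torch.manual_seed(0)
B, d, h = 4, 3, 5  # batch size, input width, GRU width (illustrative values)

def init(*shape):
    return 0.5 * torch.randn(*shape)  # small random weights, illustrative only

U = {k: init(d, h) for k in "rzh"}      # input-to-hidden weights per gate / candidate
W = {k: init(h, h) for k in "rzh"}      # hidden-to-hidden weights
b = {k: torch.zeros(h) for k in "rzh"}  # biases

def gru_step(X_t, H_prev):
    G_r = torch.sigmoid(X_t @ U["r"] + H_prev @ W["r"] + b["r"])
    G_z = torch.sigmoid(X_t @ U["z"] + H_prev @ W["z"] + b["z"])
    # Reset gate scales H_{t-1} before it is multiplied by W^h
    H_cand = torch.tanh(X_t @ U["h"] + (G_r * H_prev) @ W["h"] + b["h"])
    # Elementwise weighted average of old state and candidate
    return G_z * H_prev + (1 - G_z) * H_cand

H = torch.zeros(B, h)  # zero initial state
for t in range(10):
    H = gru_step(torch.randn(B, d), H)

print(H.abs().max() < 1)  # tensor(True)
```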



Fig. 62 Computing the hidden state in a GRU model. Source

First, consider the reset gate. When \({\boldsymbol{\Gamma}^r_t} = 1\), the entire previous hidden state is used to compute the candidate hidden state. If instead \({\boldsymbol{\Gamma}^r_t} = 0\), the previous state is reset and the candidate state depends only on the current input. Even then, the update gate can ignore the candidate state by setting \({\boldsymbol{\Gamma}^z_t} = 1\), so that \(\boldsymbol{\mathsf{H}}_t \approx \boldsymbol{\mathsf{H}}_{t - 1}\) and past inputs can still affect future outputs. In this way the GRU can carry information across many time steps: as long as the update gate stays close to 1, the state is copied forward almost unchanged, and once it moves toward 0, the state is overwritten by the candidate state. This is how the unit handles long-term dependencies.

The gradient with respect to \(\boldsymbol{\mathsf{H}}_{t-1}\) collects contributions from every node that depends on it. There are four such paths: through \(\boldsymbol{\Gamma}^r_t\), through \(\boldsymbol{\Gamma}^z_t\), through \(\tilde{\boldsymbol{\mathsf{H}}}_t\), and through the direct term \(\boldsymbol{\Gamma}^z_t \odot \boldsymbol{\mathsf{H}}_{t-1}\). Along the direct path, \(\frac{\partial \boldsymbol{\mathsf{H}}_t}{\partial \boldsymbol{\mathsf{H}}_{t-1}}\) only scales the gradient elementwise by \(\boldsymbol{\Gamma}^z_t\) instead of multiplying it by a weight matrix, so when \(\boldsymbol{\Gamma}^z_t \approx 1\) the gradient can flow back over many steps without shrinking along this path.
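Both limiting cases can be checked numerically. This sketch uses a hypothetical functional helper `gru_step` with random float64 weights, saturates the reset gate by setting \(\boldsymbol{\mathsf{b}}^r = -50\) and the update gate by setting \(\boldsymbol{\mathsf{b}}^z = +50\), and then uses `torch.autograd.functional.jacobian` to confirm that \(\partial \boldsymbol{\mathsf{H}}_t / \partial \boldsymbol{\mathsf{H}}_{t-1}\) becomes the identity when \(\boldsymbol{\Gamma}^z_t \approx 1\). The bias values are chosen only to force the gates to their extremes.

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
d, h = 3, 4
f64 = torch.float64
U = {k: torch.randn(d, h, dtype=f64) for k in "rzh"}
W = {k: torch.randn(h, h, dtype=f64) for k in "rzh"}
b = {k: torch.zeros(h, dtype=f64) for k in "rzh"}

def gru_step(x, H_prev):
    """One GRU step for a single (unbatched) example; returns (H_t, candidate)."""
    G_r = torch.sigmoid(x @ U["r"] + H_prev @ W["r"] + b["r"])
    G_z = torch.sigmoid(x @ U["z"] + H_prev @ W["z"] + b["z"])
    H_cand = torch.tanh(x @ U["h"] + (G_r * H_prev) @ W["h"] + b["h"])
    return G_z * H_prev + (1 - G_z) * H_cand, H_cand

x = torch.randn(d, dtype=f64)
H_prev = torch.randn(h, dtype=f64)

# Reset gate forced shut: the candidate depends only on the current input
b["r"] = torch.full((h,), -50.0, dtype=f64)
_, H_cand = gru_step(x, H_prev)
print(torch.allclose(H_cand, torch.tanh(x @ U["h"] + b["h"])))  # True

# Update gate forced open: the state is copied and the Jacobian is the identity
b["z"] = torch.full((h,), 50.0, dtype=f64)
J = jacobian(lambda H: gru_step(x, H)[0], H_prev)
print(torch.allclose(J, torch.eye(h, dtype=f64)))  # True
```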


Code implementation

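import torch
from torch import nn

# RNNBase is the base class from an earlier section (assumed in scope);
# calling gru(x) is expected to run init_state and then compute.
class GRU(RNNBase):
    def __init__(self, inputs_dim: int, hidden_dim: int):
        super().__init__(inputs_dim, hidden_dim)
        self.hidden_dim = hidden_dim
        self.inputs_dim = inputs_dim
        # Each Linear acts on [x_t, h_{t-1}], equivalent to X_t U + H_{t-1} W + b
        self.R = nn.Linear(inputs_dim + hidden_dim, hidden_dim)  # reset gate
        self.Z = nn.Linear(inputs_dim + hidden_dim, hidden_dim)  # update gate
        self.G = nn.Linear(inputs_dim + hidden_dim, hidden_dim)  # candidate state
    
    def init_state(self, x):
        # x has shape (T, B, d); the initial hidden state is zeros of shape (B, h)
        B = x.shape[1]
        return torch.zeros(B, self.hidden_dim, device=x.device)

    def _step(self, x_t, state):
        h = state
        x_gate = torch.cat([x_t, h], dim=1)
        r = torch.sigmoid(self.R(x_gate))                       # reset gate
        z = torch.sigmoid(self.Z(x_gate))                       # update gate
        g = torch.tanh(self.G(torch.cat([x_t, r * h], dim=1)))  # candidate state
        h = z * h + (1 - z) * g                                 # new hidden state
        return h, h  # (output, new state)

    def compute(self, x, state):
        # Unroll over time and stack the per-step outputs into shape (T, B, h)
        T = x.shape[0]
        outs = []
        for t in range(T):
            out, state = self._step(x[t], state)
            outs.append(out)
        return torch.stack(outs), state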
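The class above applies the reset gate to \(\boldsymbol{\mathsf{H}}_{t-1}\) before the product with \(\boldsymbol{\mathsf{W}}^h\), exactly as in the equations. PyTorch's `torch.nn.GRU` and `torch.nn.GRUCell` are documented to apply it after the hidden-to-hidden product instead, \(\tanh(W_{in} x + b_{in} + r \odot (W_{hn} h + b_{hn}))\), so the two match in shapes and interface but not numerically, even with equal weights. The sketch below recomputes one `nn.GRUCell` step by hand from its `weight_ih`, `weight_hh`, `bias_ih`, `bias_hh` tensors, assuming the documented stacking order (reset, update, new).

```python
import torch
from torch import nn

torch.manual_seed(0)
B, d, h = 4, 3, 5
cell = nn.GRUCell(d, h)
x, H_prev = torch.randn(B, d), torch.randn(B, h)

# Per-gate weights are stacked along dim 0 in the order (reset, update, new)
W_ir, W_iz, W_in = cell.weight_ih.chunk(3, dim=0)
W_hr, W_hz, W_hn = cell.weight_hh.chunk(3, dim=0)
b_ir, b_iz, b_in = cell.bias_ih.chunk(3)
b_hr, b_hz, b_hn = cell.bias_hh.chunk(3)

with torch.no_grad():
    r = torch.sigmoid(x @ W_ir.T + b_ir + H_prev @ W_hr.T + b_hr)
    z = torch.sigmoid(x @ W_iz.T + b_iz + H_prev @ W_hz.T + b_hz)
    # Reset gate applied AFTER the hidden-to-hidden product (unlike the GRU class above)
    n = torch.tanh(x @ W_in.T + b_in + r * (H_prev @ W_hn.T + b_hn))
    H_manual = (1 - z) * n + z * H_prev  # same convex update as in this section
    H_torch = cell(x, H_prev)

print(torch.allclose(H_manual, H_torch, atol=1e-6))  # True
```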

Shapes test:

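import torch

B, T, d, h = 32, 5, 10, 20
x = torch.randn(T, B, d)    # time-major (T, B, d), like torch.nn.GRU's default
gru = GRU(d, h)
outs, H = gru(x)    # interface mirrors PyTorch (whose h_n has an extra leading num_layers dim):
                    # https://pytorch.org/docs/stable/generated/torch.nn.GRU.html

assert outs.shape == (T, B, h)
assert H.shape == (B, h)
assert torch.equal(H, outs[-1])  # the final state is the last output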
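For comparison, the same shape test against `torch.nn.GRU` itself. The only difference is the final state: PyTorch returns `h_n` with a leading `num_layers * num_directions` dimension, which is 1 for a single unidirectional layer. The initial state defaults to zeros when none is passed.

```python
import torch
from torch import nn

B, T, d, h = 32, 5, 10, 20
x = torch.randn(T, B, d)    # nn.GRU defaults to batch_first=False, i.e. (T, B, d)
gru_ref = nn.GRU(d, h)
outs_ref, h_n = gru_ref(x)  # initial hidden state defaults to zeros

assert outs_ref.shape == (T, B, h)
assert h_n.shape == (1, B, h)                # leading dim = num_layers * num_directions
assert torch.allclose(h_n[0], outs_ref[-1])  # final state equals the last output
```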

Model training

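from torch.utils.data import DataLoader, random_split
# TimeMachine, SequenceDataset and collate_fn are assumed to come from earlier sections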

data, tokenizer = TimeMachine().build()
T = 30
BATCH_SIZE = 128
VOCAB_SIZE = tokenizer.vocab_size

dataset = SequenceDataset(data, seq_len=T, vocab_size=VOCAB_SIZE)
train_dataset, valid_dataset = random_split(dataset, [0.80, 0.20])
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)    # shuffled, so next(iter(...)) draws a random batch

print("preds per epoch")
print("train:", f"{len(train_loader) * BATCH_SIZE * T: .3e}")
print("valid:", f"{len(valid_loader) * BATCH_SIZE * T: .3e}")
preds per epoch
train:  4.182e+06
valid:  1.048e+06
from tqdm.notebook import tqdm

DEVICE = "cpu"
LR = 0.01
EPOCHS = 5
MAX_NORM = 1.0

model = LanguageModel(GRU)(VOCAB_SIZE, 64, VOCAB_SIZE)
model.to(DEVICE)
optim = torch.optim.Adam(model.parameters(), lr=LR)

train_losses = []
valid_losses = []
for e in tqdm(range(EPOCHS)):
    for t, (x, y) in tqdm(enumerate(train_loader), total=len(train_loader)):
        x, y = x.to(DEVICE), y.to(DEVICE)
        loss = train_step(model, optim, x, y, MAX_NORM)
        train_losses.append(loss)

        if t % 5 == 0:
            xv, yv = next(iter(valid_loader))
            xv, yv = xv.to(DEVICE), yv.to(DEVICE)
            valid_losses.append(valid_step(model, xv, yv))
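`train_step` and `valid_step` come from an earlier section and are not shown here. The sketch below is one plausible implementation, written under stated assumptions: the model returns logits of shape `(T, B, V)` for `(T, B)` token ids, the loss is `torch.nn.functional.cross_entropy`, and the global gradient norm is clipped to `max_norm` with `torch.nn.utils.clip_grad_norm_` before the optimizer step. `ToyLM` is a hypothetical stand-in for `LanguageModel`, used only to exercise the two functions.

```python
import torch
import torch.nn.functional as F
from torch import nn

def train_step(model, optim, x, y, max_norm):
    """Hypothetical: one optimizer step with gradient-norm clipping; returns the loss."""
    model.train()
    logits = model(x)  # assumed shape (T, B, V)
    loss = F.cross_entropy(logits.reshape(-1, logits.shape[-1]), y.reshape(-1))
    optim.zero_grad()
    loss.backward()
    # Rescale all gradients together if their global L2 norm exceeds max_norm
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optim.step()
    return loss.item()

@torch.no_grad()
def valid_step(model, x, y):
    """Hypothetical: mean cross-entropy on a batch, without gradient tracking."""
    model.eval()
    logits = model(x)
    return F.cross_entropy(logits.reshape(-1, logits.shape[-1]), y.reshape(-1)).item()

class ToyLM(nn.Module):
    """Illustrative embedding -> GRU -> linear model (not the LanguageModel used above)."""
    def __init__(self, vocab_size, hidden):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden)
        self.rnn = nn.GRU(hidden, hidden)
        self.out = nn.Linear(hidden, vocab_size)

    def forward(self, x):  # x: (T, B) token ids
        h, _ = self.rnn(self.emb(x))
        return self.out(h)  # (T, B, V)

torch.manual_seed(0)
V, T, B = 12, 8, 4
model = ToyLM(V, 16)
optim = torch.optim.Adam(model.parameters(), lr=0.01)
x, y = torch.randint(V, (T, B)), torch.randint(V, (T, B))

# Repeatedly fitting one batch should drive its loss down
losses = [train_step(model, optim, x, y, max_norm=1.0) for _ in range(30)]
print(f"first {losses[0]:.3f}, last {losses[-1]:.3f}")
```

Clipping the global norm (rather than each gradient separately) keeps the update direction unchanged and only shortens it, which is a common safeguard against exploding gradients in recurrent models.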
import numpy as np
import matplotlib.pyplot as plt
from matplotlib_inline import backend_inline
backend_inline.set_matplotlib_formats("svg")

plt.figure(figsize=(10, 4))
plt.plot(train_losses, label="train")
plt.plot(np.arange(len(valid_losses)) * 5, valid_losses, label="valid")  # valid loss k is logged after train step 5k (approx., t restarts each epoch)
plt.grid(linestyle="dotted", alpha=0.6)
plt.ylabel("loss")
plt.xlabel("step")
plt.legend();
np.array(valid_losses[-50:]).mean()
1.3439158058166505
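A loss value is easier to interpret as a perplexity. Assuming `valid_step` returns the mean per-character cross-entropy in nats (the convention of `torch.nn.functional.cross_entropy`), exponentiating it gives the effective number of equally likely next characters the model is choosing between.

```python
import math

mean_valid_loss = 1.3439158058166505  # value printed above
perplexity = math.exp(mean_valid_loss)
print(f"{perplexity:.2f}")  # 3.83
```

A perplexity around 3.8 is far below the vocabulary size, which a uniform guess would achieve.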

Text generation

textgen = TextGenerator(model, tokenizer, device="cpu")
s = [textgen.predict("thank y", num_preds=2, temperature=0.4) for i in range(20)]
(np.array(s) == "thank you").mean()
0.95
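`TextGenerator.predict` is defined in an earlier section and its internals are not shown here. A common way to implement a `temperature` parameter, sketched below as an assumption, is to divide the logits by the temperature before the softmax and then sample with `torch.multinomial`. As the temperature approaches 0 this becomes greedy decoding, which is why low temperatures reliably complete "thank y" to "thank you". `sample_next` is a hypothetical helper.

```python
import torch

def sample_next(logits: torch.Tensor, temperature: float) -> int:
    """Hypothetical helper: sample one token id from softmax(logits / temperature)."""
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()

logits = torch.tensor([1.0, 3.0, 2.0])
print(sample_next(logits, temperature=1e-3))  # 1 (near-greedy: always the argmax)
print(sample_next(logits, temperature=1.0))   # random: 0, 1 or 2
```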
warmup = "mr williams i underst"
text = []
temperature = []
for i in range(1, 6):
    t = 0.20 * i
    s = textgen.predict(warmup, num_preds=100, temperature=t)
    text.append(s)
    temperature.append(t)
import pandas as pd
from IPython.display import display
pd.set_option("display.max_colwidth", None)
df = pd.DataFrame({"temp": [f"{t:.1f}" for t in temperature], "text": text})
df = df.style.set_properties(**{"text-align": "left"})
display(df)
  temp text
0 0.2 mr williams i understand the time traveller the time traveller the palace of the start of the time traveller the palace o
1 0.4 mr williams i understould find the lamp to the little people were a passed the bright race of face the force and said the
2 0.6 mr williams i understards me fiesh in the only to the sun seen and sure me out to the time traveller as i carks of the th
3 0.8 mr williams i understood my diront of the rose they reed perhaps the ancessments was s the behind madzes to the time trav
4 1.0 mr williams i understail his vate and they nights began to make what continuar when it feeckean there were too absender t
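The table shows the usual trade-off: low temperatures give well-formed but repetitive text ("the time traveller the time traveller"), while higher ones give more varied text with more misspelled words. Assuming the temperature scales the logits as softmax(logits / temperature), this follows from the entropy of the sampling distribution growing with the temperature. A quick numerical check with arbitrary example logits:

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.5, 0.1, -1.0])  # arbitrary example logits

def entropy(temperature: float) -> float:
    """Entropy (in nats) of softmax(logits / temperature)."""
    p = torch.softmax(logits / temperature, dim=-1)
    return -(p * p.log()).sum().item()

for t in [0.2, 0.4, 0.6, 0.8, 1.0]:
    print(f"temperature {t:.1f}: entropy {entropy(t):.3f} nats")
```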