RNN language model

Our goal in this section is to train a character-level RNN language model that predicts the next token at each step, using a context of varying length. Hence, during training, the model makes a prediction at every time step (Fig. 57). The language model below is simply an RNN cell with a logits layer attached at each step.


Fig. 57 Character-level RNN language model for predicting the next character at each step. Source

To implement a language model, we simply attach a linear layer to the RNN unit to compute logits. The linear layer performs a matrix multiplication on the rightmost dimension of outs, which contains the value of the state vector at each time step. Thus, as shown in Fig. 57, we have \(T\) predictions with increasing context sizes[1] \(1, 2, \ldots, T.\)

import torch
import torch.nn as nn
from typing import Type
from functools import partial


class RNNLanguageModel(nn.Module):
    def __init__(self, 
        cell: Type[RNNBase],
        inputs_dim: int,
        hidden_dim: int,
        vocab_size: int,
        **kwargs
    ):
        super().__init__()
        self.cell = cell(inputs_dim, hidden_dim, **kwargs)
        self.linear = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, state=None, return_state=False):
        outs, state = self.cell(x, state)
        outs = self.linear(outs)    # (T, B, H) -> (T, B, C)
        return outs if not return_state else (outs, state)


LanguageModel = lambda cell: partial(RNNLanguageModel, cell)   # e.g. LanguageModel(RNN)(inputs_dim, hidden_dim, vocab_size)
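The cell class passed to LanguageModel comes from the previous section; the language model only assumes that calling the cell on inputs of shape (T, B, inputs_dim) returns outputs of shape (T, B, hidden_dim) together with the final state. As a sketch, here is a hypothetical minimal cell (SimpleRNN, not the RNN used later) that satisfies this interface, followed by a shape check:

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleRNN(nn.Module):
    """Minimal tanh recurrence satisfying the assumed interface:
    forward(x, state) -> (outs, state), with outs of shape (T, B, hidden_dim)."""
    def __init__(self, inputs_dim: int, hidden_dim: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.W_xh = nn.Parameter(torch.randn(inputs_dim, hidden_dim) * 0.01)
        self.W_hh = nn.Parameter(torch.randn(hidden_dim, hidden_dim) * 0.01)
        self.b_h = nn.Parameter(torch.zeros(hidden_dim))

    def forward(self, x, state=None):
        T, B, _ = x.shape
        h = torch.zeros(B, self.hidden_dim) if state is None else state
        outs = []
        for t in range(T):
            h = torch.tanh(x[t] @ self.W_xh + h @ self.W_hh + self.b_h)
            outs.append(h)
        return torch.stack(outs), h


model = LanguageModel(SimpleRNN)(inputs_dim=28, hidden_dim=64, vocab_size=28)
x = F.one_hot(torch.randint(0, 28, (10, 32)), num_classes=28).float()   # (T, B, V)
print(model(x).shape)   # torch.Size([10, 32, 28]) -- logits at every step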

Character sequences dataset

Our dataset consists of input-output pairs of character sequences of length \(T\), where the targets are the inputs shifted by one time step:

import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

class SequenceDataset(Dataset):
    def __init__(self, data: torch.Tensor, seq_len: int, vocab_size: int):
        super().__init__()
        self.data = data
        self.seq_len = seq_len
        self.vocab_size = vocab_size

    def __getitem__(self, i):
        c = self.data[i: i + self.seq_len + 1]    # seq_len + 1 consecutive characters
        x, y = c[:-1], c[1:]                      # targets are the inputs shifted by one step
        x = F.one_hot(x, num_classes=self.vocab_size).float()
        return x, y

    def __len__(self):
        return len(self.data) - self.seq_len      # number of valid starting positions
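As a quick check on a toy encoded tensor (purely illustrative), item \(i\) pairs characters \(i, \ldots, i + T - 1\) with characters \(i + 1, \ldots, i + T\):

toy = torch.arange(8)                              # stand-in for an encoded text
ds = SequenceDataset(toy, seq_len=3, vocab_size=8)
x0, y0 = ds[0]
print(torch.argmax(x0, dim=-1))   # tensor([0, 1, 2])  -- inputs, decoded from one-hot
print(y0)                         # tensor([1, 2, 3])  -- targets, shifted by one step
print(len(ds))                    # 5 valid starting positions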

Training on the Time Machine text:

import re
import os
import torch
import requests
from collections import Counter
from typing import Union, Optional, TypeVar, List

from pathlib import Path

DATA_DIR = Path("./data")
DATA_DIR.mkdir(exist_ok=True)


T = TypeVar("T")
ScalarOrList = Union[T, List[T]]


class Vocab:
    def __init__(self, 
        text: str, 
        min_freq: int = 0, 
        reserved_tokens: Optional[List[str]] = None,
        preprocess: bool = True
    ):
        text = self.preprocess(text) if preprocess else text
        tokens = list(text)
        counter = Counter(tokens)
        reserved_tokens = reserved_tokens or []
        self.token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        self.itos = [self.unk_token] + reserved_tokens + [
            tok for tok, f in self.token_freqs if f >= min_freq
        ]
        self.stoi = {tok: idx for idx, tok in enumerate(self.itos)}

    def __len__(self):
        return len(self.itos)
    
    def __getitem__(self, tokens: ScalarOrList[str]) -> ScalarOrList[int]:
        if isinstance(tokens, str):
            return self.stoi.get(tokens, self.unk)
        else:
            return [self.__getitem__(tok) for tok in tokens]

    def to_tokens(self, indices: ScalarOrList[int]) -> ScalarOrList[str]:
        if isinstance(indices, int):
            return self.itos[indices]
        else:
            return [self.itos[int(index)] for index in indices]
            
    def preprocess(self, text: str):
        return re.sub("[^A-Za-z]+", " ", text).lower().strip()

    @property
    def unk_token(self) -> str:
        return "▮"

    @property
    def unk(self) -> int:
        return self.stoi[self.unk_token]

    @property
    def tokens(self) -> List[str]:
        return self.itos


class Tokenizer:
    def __init__(self, vocab: Vocab):
        self.vocab = vocab

    def tokenize(self, text: str) -> List[str]:
        UNK = self.vocab.unk_token
        tokens = self.vocab.stoi.keys()
        return [c if c in tokens else UNK for c in list(text)]

    def encode(self, text: str) -> torch.Tensor:
        x = self.vocab[self.tokenize(text)]
        return torch.tensor(x, dtype=torch.int64)

    def decode(self, indices: Union[ScalarOrList[int], torch.Tensor]) -> str:
        return "".join(self.vocab.to_tokens(indices))

    @property
    def vocab_size(self) -> int:
        return len(self.vocab)


class TimeMachine:
    def __init__(self, download=False, path=None):
        DEFAULT_PATH = str((DATA_DIR / "time_machine.txt").absolute())
        self.filepath = path or DEFAULT_PATH
        if download or not os.path.exists(self.filepath):
            self._download()
        
    def _download(self):
        url = "https://www.gutenberg.org/cache/epub/35/pg35.txt"
        print(f"Downloading text from {url} ...", end=" ")
        response = requests.get(url, stream=True)
        response.raise_for_status()
        print("OK!")
        with open(self.filepath, "wb") as output:
            output.write(response.content)
        
    def _load_text(self):
        with open(self.filepath, "r") as f:
            text = f.read()
        s = "*** START OF THE PROJECT GUTENBERG EBOOK THE TIME MACHINE ***"
        e = "*** END OF THE PROJECT GUTENBERG EBOOK THE TIME MACHINE ***"
        return text[text.find(s) + len(s): text.find(e)]
    
    def build(self, vocab: Optional[Vocab] = None):
        self.text = self._load_text()
        vocab = vocab or Vocab(self.text)
        tokenizer = Tokenizer(vocab)
        encoded_text = tokenizer.encode(vocab.preprocess(self.text))
        return encoded_text, tokenizer
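A quick sanity check of the tokenizer built on the Time Machine text (the first call downloads it from Project Gutenberg):

_, tokenizer = TimeMachine().build()
codes = tokenizer.encode("the time machine")
print(codes)                    # integer codes for each character
print(tokenizer.decode(codes))  # 'the time machine'
print(tokenizer.vocab_size)     # 26 letters + space + unknown token = 28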

We now split the encoded text into training and validation sequences and wrap them in data loaders:

from torch.utils.data import random_split

def collate_fn(batch):
    """Transforming the data to sequence-first format."""
    x, y = zip(*batch)
    x = torch.stack(x, 1)      # (T, B, vocab_size)
    y = torch.stack(y, 1)      # (T, B)
    return x, y


data, tokenizer = TimeMachine().build()
VOCAB_SIZE = tokenizer.vocab_size
dataset = SequenceDataset(data, seq_len=10, vocab_size=VOCAB_SIZE)
train_dataset, valid_dataset = random_split(dataset, [0.80, 0.20])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

The batch indices (i.e. the starting points of the sequences) are shuffled, but the ordering within each sequence is intact:

x, y = next(iter(train_loader))

a, T = 1, dataset.seq_len
x_chars = tokenizer.decode(torch.argmax(x[:, a], dim=1))     # inputs are one-hot
y_chars = tokenizer.decode(y[:, a])
for i in range(T):
    print(f"{x_chars[i]} --> {y_chars[i]}")
y -->  
  --> o
o --> w
w --> n
n -->  
  --> e
e --> x
x --> p
p --> e
e --> n
print(x.shape, y.shape)
print("inputs:", torch.argmax(x[:, 0], dim=-1))
print("target:", y[:, 0])
torch.Size([10, 32, 28]) torch.Size([10, 32])
inputs: tensor([ 4,  6, 11,  1, 17,  9,  2,  6,  1,  5])
target: tensor([ 6, 11,  1, 17,  9,  2,  6,  1,  5,  1])

For sequence outputs, PyTorch's F.cross_entropy expects inputs of shape (B, C, T) and targets of shape (B, T), hence the permute and transpose below:

import torch.nn.functional as F

x, y = next(iter(train_loader))
model = LanguageModel(RNN)(VOCAB_SIZE, 5, VOCAB_SIZE)
loss = F.cross_entropy(model(x).permute(1, 2, 0), y.transpose(0, 1))
loss
tensor(3.3744, grad_fn=<NllLoss2DBackward0>)
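As a sanity check, the loss of an untrained model should be roughly the cross-entropy of a uniform prediction over the vocabulary, i.e. \(\log 28 \approx 3.33\):

import math
print(math.log(VOCAB_SIZE))   # 3.3322... -- uniform prediction over 28 characters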