RNN language model
Our goal in this section is to train a character-level RNN language model that predicts the next token at each step, given contexts of varying length. Hence, during training, the model makes a prediction at every time step (Fig. 57). The language model below is simply an RNN cell with an attached logits layer applied at each step.
Fig. 57 Character-level RNN language model for predicting the next character at each step. Source
To implement a language model, we simply attach a linear layer to the RNN unit to compute logits. The linear layer performs a matrix multiplication on the rightmost dimension of outs, which contains the value of the state vector at each time step. Thus, as shown in Fig. 57, we have \(T\) predictions with increasing context sizes[1] \(1, 2, \ldots, T.\)
import torch
import torch.nn as nn
from typing import Type
from functools import partial


class RNNLanguageModel(nn.Module):
    def __init__(self,
        cell: Type[RNNBase],
        inputs_dim: int,
        hidden_dim: int,
        vocab_size: int,
        **kwargs
    ):
        super().__init__()
        self.cell = cell(inputs_dim, hidden_dim, **kwargs)
        self.linear = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, state=None, return_state=False):
        outs, state = self.cell(x, state)
        outs = self.linear(outs)    # (T, B, H) -> (T, B, C)
        return outs if not return_state else (outs, state)


LanguageModel = lambda cell: partial(RNNLanguageModel, cell)
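As a quick sanity check on shapes, the following sketch runs a forward pass. Since the RNN cell class from the previous section is not repeated here, DummyCell below is only a hypothetical stand-in that assumes the same interface: it maps a (T, B, inputs_dim) input to (outs, state) with outs of shape (T, B, hidden_dim).

# Hypothetical stand-in for the RNN cell of the previous section (assumed interface)
class DummyCell(nn.Module):
    def __init__(self, inputs_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.W = nn.Linear(inputs_dim + hidden_dim, hidden_dim)

    def forward(self, x, state=None):
        T, B, _ = x.shape
        h = torch.zeros(B, self.hidden_dim) if state is None else state
        outs = []
        for t in range(T):
            h = torch.tanh(self.W(torch.cat([x[t], h], dim=-1)))   # new state at step t
            outs.append(h)
        return torch.stack(outs), h                                # (T, B, H), (B, H)

model = LanguageModel(DummyCell)(28, 5, 28)   # inputs_dim = vocab_size for one-hot inputs
x = torch.randn(10, 32, 28)                   # (T, B, vocab_size)
print(model(x).shape)                         # torch.Size([10, 32, 28]) -- logits per step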
Character sequences dataset
Our dataset consists of samples of \(T\) input-output character pairs, where the targets are the inputs shifted by one time step:
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


class SequenceDataset(Dataset):
    def __init__(self, data: torch.Tensor, seq_len: int, vocab_size: int):
        super().__init__()
        self.data = data
        self.seq_len = seq_len
        self.vocab_size = vocab_size

    def __getitem__(self, i):
        c = self.data[i: i + self.seq_len + 1]
        x, y = c[:-1], c[1:]
        x = F.one_hot(x, num_classes=self.vocab_size).float()
        return x, y

    def __len__(self):
        return len(self.data) - self.seq_len
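A toy example (the five-token tensor below is made up just for illustration) shows the one-step shift between inputs and targets:

toy = torch.tensor([0, 1, 2, 3, 4])
ds = SequenceDataset(toy, seq_len=3, vocab_size=5)
x0, y0 = ds[0]
print(x0.shape)                         # torch.Size([3, 5]) -- one-hot inputs
print(torch.argmax(x0, dim=-1), y0)     # tensor([0, 1, 2]) tensor([1, 2, 3])
print(len(ds))                          # 2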
Training on the Time Machine text:
import re
import os
import torch
import requests
from collections import Counter
from typing import Union, Optional, TypeVar, List
from pathlib import Path

DATA_DIR = Path("./data")
DATA_DIR.mkdir(exist_ok=True)

T = TypeVar("T")
ScalarOrList = Union[T, List[T]]


class Vocab:
    def __init__(self,
        text: str,
        min_freq: int = 0,
        reserved_tokens: Optional[List[str]] = None,
        preprocess: bool = True
    ):
        text = self.preprocess(text) if preprocess else text
        tokens = list(text)
        counter = Counter(tokens)
        reserved_tokens = reserved_tokens or []
        self.token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        self.itos = [self.unk_token] + reserved_tokens + [tok for tok, f in filter(lambda tokf: tokf[1] >= min_freq, self.token_freqs)]
        self.stoi = {tok: idx for idx, tok in enumerate(self.itos)}

    def __len__(self):
        return len(self.itos)

    def __getitem__(self, tokens: ScalarOrList[str]) -> ScalarOrList[int]:
        if isinstance(tokens, str):
            return self.stoi.get(tokens, self.unk)
        else:
            return [self.__getitem__(tok) for tok in tokens]

    def to_tokens(self, indices: ScalarOrList[int]) -> ScalarOrList[str]:
        if isinstance(indices, int):
            return self.itos[indices]
        else:
            return [self.itos[int(index)] for index in indices]

    def preprocess(self, text: str):
        return re.sub("[^A-Za-z]+", " ", text).lower().strip()

    @property
    def unk_token(self) -> str:
        return "▮"

    @property
    def unk(self) -> int:
        return self.stoi[self.unk_token]

    @property
    def tokens(self) -> List[int]:
        return self.itos


class Tokenizer:
    def __init__(self, vocab: Vocab):
        self.vocab = vocab

    def tokenize(self, text: str) -> List[str]:
        UNK = self.vocab.unk_token
        tokens = self.vocab.stoi.keys()
        return [c if c in tokens else UNK for c in list(text)]

    def encode(self, text: str) -> torch.Tensor:
        x = self.vocab[self.tokenize(text)]
        return torch.tensor(x, dtype=torch.int64)

    def decode(self, indices: Union[ScalarOrList[int], torch.Tensor]) -> str:
        return "".join(self.vocab.to_tokens(indices))

    @property
    def vocab_size(self) -> int:
        return len(self.vocab)


class TimeMachine:
    def __init__(self, download=False, path=None):
        DEFAULT_PATH = str((DATA_DIR / "time_machine.txt").absolute())
        self.filepath = path or DEFAULT_PATH
        if download or not os.path.exists(self.filepath):
            self._download()

    def _download(self):
        url = "https://www.gutenberg.org/cache/epub/35/pg35.txt"
        print(f"Downloading text from {url} ...", end=" ")
        response = requests.get(url, stream=True)
        response.raise_for_status()
        print("OK!")
        with open(self.filepath, "wb") as output:
            output.write(response.content)

    def _load_text(self):
        with open(self.filepath, "r") as f:
            text = f.read()
        s = "*** START OF THE PROJECT GUTENBERG EBOOK THE TIME MACHINE ***"
        e = "*** END OF THE PROJECT GUTENBERG EBOOK THE TIME MACHINE ***"
        return text[text.find(s) + len(s): text.find(e)]

    def build(self, vocab: Optional[Vocab] = None):
        self.text = self._load_text()
        vocab = vocab or Vocab(self.text)
        tokenizer = Tokenizer(vocab)
        encoded_text = tokenizer.encode(vocab.preprocess(self.text))
        return encoded_text, tokenizer
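Before building the full pipeline, a small round trip (the strings below are arbitrary examples) shows how characters outside the vocabulary fall back to the unknown token:

_vocab = Vocab("the time machine by h g wells")
_tok = Tokenizer(_vocab)
ids = _tok.encode("time!")     # '!' is not in the vocabulary
print(ids)                     # indices depend on character frequencies; '!' maps to index 0 (unk)
print(_tok.decode(ids))        # time▮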
from torch.utils.data import random_split


def collate_fn(batch):
    """Transforming the data to sequence-first format."""
    x, y = zip(*batch)
    x = torch.stack(x, 1)   # (T, B, vocab_size)
    y = torch.stack(y, 1)   # (T, B)
    return x, y
data, tokenizer = TimeMachine().build()
VOCAB_SIZE = tokenizer.vocab_size
dataset = SequenceDataset(data, seq_len=10, vocab_size=VOCAB_SIZE)
train_dataset, valid_dataset = random_split(dataset, [0.80, 0.20])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
The batch indices (i.e. the starting points) are shuffled, but the ordering within each sequence is intact:
x, y = next(iter(train_loader))
a, T = 1, dataset.seq_len
x_chars = tokenizer.decode(torch.argmax(x[:, a], dim=1)) # inputs are one-hot
y_chars = tokenizer.decode(y[:, a])
for i in range(T):
    print(f"{x_chars[i]} --> {y_chars[i]}")
y -->
--> o
o --> w
w --> n
n -->
--> e
e --> x
x --> p
p --> e
e --> n
print(x.shape, y.shape)
print("inputs:", torch.argmax(x[:, 0], dim=-1))
print("target:", y[:, 0])
torch.Size([10, 32, 28]) torch.Size([10, 32])
inputs: tensor([ 4, 6, 11, 1, 17, 9, 2, 6, 1, 5])
target: tensor([ 6, 11, 1, 17, 9, 2, 6, 1, 5, 1])
PyTorch's F.cross_entropy expects input of shape (B, C, T) and target of shape (B, T):
import torch.nn.functional as F

x, y = next(iter(train_loader))
model = LanguageModel(RNN)(VOCAB_SIZE, 5, VOCAB_SIZE)   # RNN cell from the previous section
loss = F.cross_entropy(model(x).permute(1, 2, 0), y.transpose(0, 1))   # (T, B, C) -> (B, C, T)
loss
tensor(3.3744, grad_fn=<NllLoss2DBackward0>)
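Equivalently, the time and batch dimensions can be flattened before calling F.cross_entropy; under the default mean reduction this yields the same loss (a sketch, not part of the original cell):

logits = model(x)                                                  # (T, B, C)
loss_flat = F.cross_entropy(logits.reshape(-1, VOCAB_SIZE), y.reshape(-1))
print(torch.isclose(loss, loss_flat))                              # should print tensor(True)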